CUTLASS 3.0.0 (#786)

* CUTLASS 3.0.0
Vijay Thakkar
2023-01-23 17:55:28 -08:00
committed by GitHub
parent 66d9cddc83
commit 277bd6e537
377 changed files with 76396 additions and 1186 deletions

View File

@ -1,5 +1,18 @@
# NVIDIA CUTLASS Changelog
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
* [A new conceptual operation hierarchy](media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](/media/docs/gemm_api_3x.md).
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](media/docs/cutlass_3x_backwards_compatibility.md).
* Updates to [Functionality](media/docs/functionality.md) documenting which kernels are supported via the CUTLASS 2.x and CUTLASS 3.x APIs.
* Updates to the [Compatibility](/README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, hardware architectures, and [Target Architecture](/README.md#Target-Architecture).
* New warp-specialized GEMM [kernel schedules](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
* Extensions to the CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
* [CUTLASS library integration](/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling the CUTLASS profiler.
* Support for [Hopper GEMMs](examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](examples/48_hopper_warp_specialized_gemm), [49](examples/49_hopper_gemm_schedules_with_collective_builder), and [50](examples/50_hopper_gemm_with_epilogue_swizzle).
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
* [Stream-K](/examples/47_ampere_gemm_universal_streamk), a new general way to do split-K. It can not only improve performance, but also significantly reduce the number of tile sizes that need to be profiled to find the best one.
* [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for fixed sequence lengths, and the other uses group GEMM for variable sequence lengths. Both variants need only one kernel.

View File

@ -5,33 +5,61 @@ message: >-
following metadata.
type: software
authors:
- given-names: Andrew
email: akerr@nvidia.com
family-names: Kerr
- given-names: Vijay
family-names: Thakkar
email: vithakkar@nvidia.com
affiliation: NVIDIA
- given-names: Pradeep
family-names: Ramani
email: prramani@nvidia.com
affiliation: NVIDIA
- given-names: Cris
family-names: Cecka
email: ccecka@nvidia.com
affiliation: NVIDIA
- given-names: Aniket
family-names: Shivam
email: ashivam@nvidia.com
affiliation: NVIDIA
- given-names: Honghao
family-names: Lu
email: honghaol@nvidia.com
affiliation: NVIDIA
- given-names: Ethan
family-names: Yan
email: etyan@nvidia.com
affiliation: NVIDIA
- given-names: Jack
family-names: Kosaian
email: jkosaian@nvidia.com
affiliation: NVIDIA
- given-names: Mark
family-names: Hoemmen
email: mhoemmen@nvidia.com
affiliation: NVIDIA
- given-names: Haicheng
family-names: Wu
affiliation: NVIDIA
email: haichengw@nvidia.com
- given-names: Manish
family-names: Gupta
affiliation: Google
email: manigupta@google.com
- given-names: Dustyn
family-names: Blasig
email: dblasig@nvidia.com
affiliation: NVIDIA
- given-names: Pradeep
family-names: Ramani
email: prramani@nvidia.com
- given-names: Andrew
family-names: Kerr
email: akerr@nvidia.com
affiliation: NVIDIA
- given-names: Matt
family-names: Nicely
email: mnicely@nvidia.com
affiliation: NVIDIA
- given-names: Duane
family-names: Merrill
email: dumerrill@nvidia.com
affiliation: NVIDIA
- given-names: Aniket
family-names: Shivam
email: ashivam@nvidia.com
- given-names: Dustyn
family-names: Blasig
email: dblasig@nvidia.com
affiliation: NVIDIA
- given-names: Fengqi
family-names: Qiao
email: fqiao@nvidia.com
affiliation: NVIDIA
- given-names: Piotr
family-names: Majcher
@ -49,10 +77,12 @@ authors:
family-names: Wang
email: jinw@nvidia.com
affiliation: NVIDIA
- given-names: Matt
family-names: Nicely
email: mnicely@nvidia.com
affiliation: NVIDIA
- given-names: Manish
family-names: Gupta
affiliation: Google
email: manigupta@google.com
repository-code: 'https://github.com/NVIDIA/cutlass'
abstract: >-
CUTLASS is a collection of CUDA C++ template
@ -71,12 +101,12 @@ abstract: >-
flexibility simplifies their use as building blocks
within custom kernels and applications.
keywords:
- 'cutlass, tensor cores, cuda'
- 'cutlass, tensor cores, cuda, cute, nvidia, gpu, linear algebra, matrix computations'
license: BSD-3-Clause
license-url: https://github.com/NVIDIA/cutlass/blob/v2.11.0/LICENSE.txt
version: '2.11.0'
date-released: '2022-11-19'
license-url: https://github.com/NVIDIA/cutlass/blob/v3.0.0/LICENSE.txt
version: '3.0.0'
date-released: '2023-01-23'
identifiers:
- type: url
value: "https://github.com/NVIDIA/cutlass/tree/v2.11.0"
description: The GitHub release URL of tag 2.11.0
value: "https://github.com/NVIDIA/cutlass/tree/v3.0.0"
description: The GitHub release URL of tag 3.0.0

View File

@ -26,7 +26,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
if(cutlass_LOADED)
# If CUTLASS has been previously fetched and loaded, don't do it again.
@ -39,35 +39,40 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
project(CUTLASS VERSION 2.11.0 LANGUAGES CXX)
project(CUTLASS VERSION 3.0.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 10.2)
message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.")
elseif (CUDA_VERSION VERSION_LESS 11.0)
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.")
if (CUDA_VERSION VERSION_LESS 11.3)
message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.")
elseif (CUDA_VERSION VERSION_LESS 11.4)
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.")
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.5)
message(FATAL_ERROR "GCC version must be at least 7.5!")
endif()
if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
endif()
find_package(Doxygen QUIET)
#
# CUTLASS 2.x requires C++11
# CUTLASS 3.x requires C++17
#
if (NOT IMPLICIT_CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
if(CUTLASS_NATIVE_CUDA)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
else()
if (NOT IMPLICIT_CMAKE_CXX_STANDARD)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11)
endif()
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17)
endif()
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
endif()
@ -107,29 +112,14 @@ if (CUTLASS_ENABLE_TESTS)
endif()
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (NOT CUDA_VERSION VERSION_LESS 7.5)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87)
endif()
if (NOT CUDA_VERSION VERSION_LESS 8.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90)
endif()
if (NOT CUDA_VERSION VERSION_LESS 9.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70)
endif()
if (NOT CUDA_VERSION VERSION_LESS 9.2)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 72)
endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90)
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
@ -271,6 +261,7 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
endif()
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
@ -288,6 +279,15 @@ if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
endif()
endif()
if (CUTLASS_ENABLE_OPENMP_TESTS)
find_package(OpenMP)
if(OpenMP_CXX_FOUND)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS})
else()
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
endif()
endif()
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
@ -313,10 +313,6 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
endif()
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
endif()
# There are numerous Clang versions that can work with each CUDA toolkit and the
# the checks are not very useful so we are turning them off and using testing to
# ensure the various combinations work properly.
@ -341,6 +337,7 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags)
link_libraries(nvidia::cudart)
link_libraries(nvidia::cuda_driver)
endif()
# Support for 128-bit integers if using NVIDIA C++ compiler
@ -530,6 +527,8 @@ target_include_directories(
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/examples>
)
install(

View File

@ -7,63 +7,77 @@
This is the official list of CUTLASS developers and contributors.
## DEVELOPERS
Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Cris Cecka
Vijay Thakkar
Aniket Shivam
Honghao Lu
Ethan Yan
Zhaodong Chen
Jack Kosaian
Yujia Zhai
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner
Manikandan Ananth
Vijay Thakkar<br />
Pradeep Ramani<br />
Cris Cecka<br />
Aniket Shivam<br />
Jack Kosaian<br />
Mark Hoemmen<br />
Honghao Lu<br />
Ethan Yan<br />
Haicheng Wu<br />
Andrew Kerr<br />
Dustyn Blasig<br />
Fengqi Qiao<br />
Duane Merrill<br />
Yujia Zhai<br />
Shang Zhang<br />
Piotr Majcher<br />
Paul Springer<br />
Markus Hohnerbach<br />
Jin Wang<br />
Aditya Atluri<br />
## CuTe
Cris Cecka<br />
Vijay Thakkar<br />
## CUTLASS Product Manager
Matthew Nicely
Matthew Nicely<br />
## Former CUTLASS Developers
Manish Gupta<br />
Naila Farooqui<br />
David Tanner<br />
Manikandan Ananth<br />
Zhaodong Chen<br />
Chinmay Talegaonkar<br />
## CONTRIBUTORS
Timothy Costa
Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Duane Merrill
Kevin Siu
Markus Tavenrath
John Tran
Vicki Wang
Junkai Wu
Fung Xie
Albert Xu
Jack Yang
Xiuxia Zhang
Nick Zhao
Timothy Costa<br />
Julien Demouth<br />
Brian Fahs<br />
Michael Garland<br />
Michael Goldfarb<br />
Mostafa Hagog<br />
Fei Hu<br />
Alan Kaatz<br />
Tina Li<br />
Timmy Liu<br />
Wei Liu<br />
Duane Merrill<br />
Kevin Siu<br />
Markus Tavenrath<br />
John Tran<br />
Vicki Wang<br />
Junkai Wu<br />
Fung Xie<br />
Albert Xu<br />
Yang Xu<br />
Jack Yang<br />
Scott Yokim<br />
Xiuxia Zhang<br />
Nick Zhao<br />
## ACKNOWLEDGEMENTS
Girish Bharambe
Luke Durant
Olivier Giroux
Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Joel McCormack
Kyrylo Perelygin
Girish Bharambe<br />
Luke Durant<br />
Carter Edwards<br />
Olivier Giroux<br />
Stephen Jones<br />
Rishkul Kulkarni<br />
Bryce Lelbach<br />
Joel McCormack<br />
Kyrylo Perelygin<br />
Sean Treichler<br />

README.md
View File

@ -1,18 +1,18 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 2.11
# CUTLASS 3.0
_CUTLASS 2.11 - November 2022_
_CUTLASS 3.0 - January 2023_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) and related computations at all levels
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
these "moving parts" into reusable, modular software components abstracted by C++ template
classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
resulting flexibility simplifies their use as building blocks within custom kernels
and applications.
classes. Primitives for different levels of a conceptual parallelization hierarchy
can be specialized and tuned via custom tiling sizes, data types,
and other algorithmic policy. The resulting flexibility simplifies their use
as building blocks within custom kernels and applications.
To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
@ -21,60 +21,75 @@ point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32),
[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, and Ampere architectures.
CUTLASS implements high-performance Convolution via the implicit GEMM algorithm.
Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of
CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below.
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# What's New in CUTLASS 2.11
CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data.
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
CUTLASS 2.11 is an update to CUTLASS adding:
- [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
- [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths.
- [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel.
- Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
- [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
- [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm).
- [Optimized Group Conv](/examples/42_ampere_tensorop_group_conv).
- [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop).
- [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs.
- [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
- Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
- **Deprecation announcement:** CUTLASS plans to deprecate the following in the next major release:
- Maxwell and Pascal GPU architectures
- Ubuntu 16.04
- CUDA 10.2
- C++ 11
- **Future requirement announcement:** CUTLASS plans to add the following requirements in the next major release:
- Minimum C++ standard - C++17
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
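As a minimal illustration of this vocabulary (a sketch only, not code from the repository; it assumes the CuTe headers under `/include/cute` and an nvcc C++17 build, and the [CuTe quickstart](/media/docs/cute/00_quickstart.md) remains the authoritative introduction):
```c++
// A minimal host-side sketch of the CuTe Layout/Tensor vocabulary described above.
#include <vector>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // A 4x8 row-major layout: shape (4,8) with strides (8,1).
  auto layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
                            make_stride(Int<8>{}, Int<1>{}));

  // A layout maps a logical coordinate to a linear offset.
  int offset = layout(2, 3);                    // row 2, column 3 -> 2*8 + 3 = 19

  // A Tensor pairs a layout with data and performs the indexing for the user.
  std::vector<float> data(size(layout), 0.0f);  // size(layout) == 32 elements
  auto t = make_tensor(data.data(), layout);
  t(2, 3) = 1.0f;                               // writes data[19]

  print(layout);                                // prints the shape:stride description
  return (offset == 19 && data[19] == 1.0f) ? 0 : 1;
}
```
The same `Layout` objects compose with the tiling and partitioning operations mentioned above to describe how threads and data are mapped onto each other.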
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
# What's New in CUTLASS 3.0
CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogeneous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture.
- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md)
- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md).
- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe.
- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe.
- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs.
- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA.
- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality.
- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour of metaprogramming-based kernel generator functionality (see the sketch after this list). See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu).
- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format
for kernel procedural names.
- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release.
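The `CollectiveBuilder` flow referenced above is shown end to end in [example 48](examples/48_hopper_warp_specialized_gemm); condensed from that example (a TF32 GEMM with a 128x128x32 threadblock tile and a 1x2x1 cluster), the type assembly looks roughly as follows:
```c++
// Condensed from example 48 in this release: a Hopper TF32 GEMM assembled through the
// CUTLASS 3.0 CollectiveBuilder and exposed through device::GemmUniversalAdapter.
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

using namespace cute;

// Mainloop: the builder picks the stage count and kernel schedule automatically.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    float, cutlass::layout::RowMajor,    4,        // A: element, layout, alignment (elements)
    float, cutlass::layout::ColumnMajor, 4,        // B: element, layout, alignment (elements)
    float,                                         // accumulator element
    Shape<_128,_128,_32>,                          // threadblock tile shape
    Shape<_1,_2,_1>,                               // threadblock cluster shape
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

// Epilogue: a simple linear combination writing column-major C/D.
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    cutlass::gemm::TagToStrideC_t<cutlass::layout::ColumnMajor>,
    cutlass::gemm::TagToStrideC_t<cutlass::layout::ColumnMajor>,
    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>>;

// Kernel-layer and device-layer types; the same adapter also wraps 2.x kernels.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int>,                            // runtime problem shape (M, N, K)
    CollectiveMainloop,
    CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```
Arguments are then passed to `Gemm::initialize()` and the kernel launched with `Gemm::run()`, exactly as example 48 does.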
## New architecture, compiler, and CUDA Toolkit requirements
Minimum requirements:
- Architecture: Volta
- Compiler: Must support at least C++17
- CUDA Toolkit version: 11.4
CUTLASS 3.0 *removes support* for the following:
- Maxwell and Pascal GPU architectures
- Ubuntu 16.04
- CUDA 10.2
- C++ language versions less than 17.
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
# Performance
<p align="center"><img src=/media/images/cutlass-2.8-gemm-performance.png></p>
<p align="center"><img src=media/images/cutlass-3.0-gemm-peak-performance.png></p>
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit performance comparable to cuBLAS for scalar GEMM
they exhibit peak performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
for large matrix dimensions on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/),
an [NVIDIA A2](https://www.nvidia.com/en-us/data-center/products/a2/),
an [NVIDIA TitanV](https://www.nvidia.com/en-us/titan/titan-v/),
and an [NVIDIA GeForce 2080 Ti](https://www.nvidia.com/en-us/geforce/graphics-cards/rtx-2080-ti/)
compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's
for large matrix dimensions on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture),
an [NVIDIA L40](https://www.nvidia.com/en-us/data-center/l40/) (NVIDIA Ada architecture),
an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere architecture),
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
<p align="center"><img src=/media/images/cutlass-2.9-implicit-gemm-performance.png></p>
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
@ -83,39 +98,48 @@ as shown in the above figure. Tensor Core operations are still implemented usin
# Compatibility
CUTLASS requires a C++11 host compiler and performs best when built with the [**CUDA 11.8 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.x.
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8.
## Operating Systems
We have tested the following environments.
|**Operating System** | **Compiler** |
|-----------------|----------|
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| | Microsoft Visual Studio 2019|
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 22.04 | GCC 11.2.0 |
Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.
Note: We plan to add Windows (MSVC) & Clang compiler support soon.
## Hardware
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA A10 |8.6|11.1|11.1|
|NVIDIA GeForce 3090|8.6|11.1|11.1|
|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8; Mixed precision: 12.0|
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
|---|---|---|
|NVIDIA V100 Tensor Core GPU |7.0|11.4|
|NVIDIA TitanV |7.0|11.4|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4|
|NVIDIA T4 |7.5|11.4|
|NVIDIA A100 Tensor Core GPU |8.0|11.4|
|NVIDIA A10 |8.6|11.4|
|NVIDIA GeForce RTX 3090 |8.6|11.4|
|NVIDIA GeForce RTX 4090 |8.9|11.8|
|NVIDIA L40 |8.9|11.8|
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
## Target Architecture
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
The target architecture information is passed on to CUTLASS via the CMake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel that uses SM90a features (e.g. Hopper Tensor Core instructions) with the SM90 target (note the lack of "a"), on either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error.
```
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
```
Please refer to the [functionality documentation](media/docs/functionality.md) for details on which kernels require which target architectures.
# Documentation
@ -125,7 +149,9 @@ CUTLASS is described in the following documents and the accompanying
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts
- [CUTLASS 3.x Design](media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](media/docs/terminology.md) - describes terms used in the code
@ -161,7 +187,8 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, and 9.0.
To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.
@ -224,6 +251,23 @@ include/ # client applications should target this directory
transform/ # code specialized for layout, type, and domain transformations
* # core vocabulary types, containers, and basic numeric operations
cute/ # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy
algorithm/ # Definitions of core operations such as copy, gemm, and operations on cute::tuples
arch/ # Bare bones PTX wrapper structs for copy and math instructions
atom/ # Meta-information either linked to or built from arch/ operators
mma_atom.hpp # cute::Mma_Atom and cute::TiledMma
copy_atom.hpp # cute::Copy_Atom and cute::TiledCopy
*sm*.hpp # Arch specific meta-information for copy and math operations
* # Core library types such as Shape, Stride, Layout, Tensor, and associated operations
```
### CUTLASS SDK Examples
@ -269,7 +313,7 @@ By default, only one tile size is instantiated for each data type, math instruct
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
Beware, this results in *thousands* of kernels and long build times.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j16
```

View File

@ -40,7 +40,7 @@ elseif(NOT TARGET cublas)
find_path(
_CUBLAS_INCLUDE_DIR
NAMES cublas.h
NAMES cublas_v2.h
HINTS
${CUBLAS_INCLUDE_PATH}
ENV CUBLAS_INCLUDE_PATH

View File

@ -45,5 +45,6 @@ target_link_libraries(
PRIVATE
cutlass_lib
cutlass_tools_util_includes
cuda
)

View File

@ -45,5 +45,6 @@ target_link_libraries(
PRIVATE
cutlass_lib
cutlass_tools_util_includes
cuda
)

View File

@ -35,7 +35,7 @@
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
+ lightweight full reduction kernel (ApplyFinalReduction)
+ GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
*/
#pragma once
@ -77,7 +77,7 @@ template <
typename ElementLayernormCompute_,
typename ElementOutput,
typename ThreadblockShape_,
bool IsShiftedVariance_ = false
bool IsShiftedVariance_ = false
>
class ApplyFinalReduction {
public:
@ -91,7 +91,7 @@ public:
using Layout = cutlass::layout::RowMajor;
using TensorVariance = TensorRef<ElementVariance, Layout>;
using TensorMean = TensorRef<ElementMean, Layout>;
using TensorMean = TensorRef<ElementMean, Layout>;
static bool const kIsShiftedVariance = IsShiftedVariance_;
@ -463,7 +463,7 @@ public:
for (int rid = 0; rid < kRowIterations; ++rid) {
int row_step_offset = rid * kDeltaRow;
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
bool is_load = (row_offset < extent_.row());
bool is_load = (row_offset < extent_.row());
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
}
@ -504,9 +504,9 @@ public:
using Minus = cutlass::minus<ElementLayernormCompute>;
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
Minus minus;
Mul mul;
Exp exponential;
[[maybe_unused]] Minus minus;
[[maybe_unused]] Mul mul;
[[maybe_unused]] Exp exponential;
LayernormFragment result;
@ -605,7 +605,7 @@ private:
CUTLASS_DEVICE
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
ConvertShiftK convert_shift_k;
ConvertShiftK convert_shift_k;
ElementOutput shift_k_val;
// Computes the address to load shift_k element
@ -614,7 +614,7 @@ private:
arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
// Converts data type to return
ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);
return converted_shift_k_val;
}
@ -689,7 +689,7 @@ public:
//
// Type definitions
//
static bool const kInternalTranspose = cutlass::platform::is_same<LayoutOutput_, cutlass::layout::ColumnMajor>::value;
static bool const kIsShiftedVariance = IsShiftedVariance_;
@ -704,14 +704,14 @@ public:
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;
// These are mandatory layouts and data types
// These are mandatory layouts and data types
// that are inheritated from pre-defined params
using LayoutSumSqr = LayoutInputScaleBias;
using LayoutSum = LayoutInputScaleBias;
using ElementMean = ElementInputScaleBias;
using ElementVariance = ElementInputScaleBias;
using ElementVariance = ElementInputScaleBias;
///////////////////////////////////////////////////////////////////////////////////////////////
@ -720,7 +720,7 @@ public:
using LayoutInputA1 = LayoutOutput_;
using LayoutInputB1 = LayoutOutput_;
using LayoutOutputC0 = LayoutOutput_;
using LayoutOutputC1 = LayoutOutput_;
using LayoutOutputC1 = LayoutOutput_;
using ElementInputA0 = ElementInputA0_;
using ElementInputB0 = ElementInputB0_;
@ -747,7 +747,7 @@ public:
static int const kStages1 = Stages1;
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
///////////////////////////////////////////////////////////////////////////////////////////////
using MapArguments = cutlass::gemm::kernel::detail::MapArguments<

View File

@ -180,9 +180,9 @@ public:
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
Arguments():
problem_count(0),
threadblock_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
@ -201,7 +201,7 @@ public:
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(
Arguments(
GemmCoord *problem_sizes0,
GemmCoord *problem_sizes1,
int problem_count,
@ -219,7 +219,7 @@ public:
typename LayoutO::Stride::LongIndex *ldo,
bool causal,
GemmCoord *host_problem_sizes=nullptr
):
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
@ -311,7 +311,7 @@ public:
ldv(args.ldv),
ldo(args.ldo),
causal(args.causal)
{
{
}
@ -464,7 +464,7 @@ public:
void operator()(Params const &params, SharedStorage &shared_storage) {
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& si = shared_storage.after_mm0.si;
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
ProblemVisitor problem_visitor(

View File

@ -481,7 +481,7 @@ struct AttentionKernel {
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& si = shared_storage.after_mm0.si;
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");

View File

@ -384,7 +384,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
// but not supported as it worsens perf: older gpus < sm80 don't
support async transfers and have to waste registers
CUTLASS_DEVICE
bool set_prologue_done(bool value) {}
void set_prologue_done(bool value) {}
CUTLASS_DEVICE
static void prologue(
typename Base::SharedStorage& shared_storage,
@ -695,7 +695,7 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
}
CUTLASS_DEVICE
bool set_prologue_done(bool value) {
void set_prologue_done(bool value) {
prologue_done_ = value;
}

View File

@ -34,7 +34,7 @@
"classic data-parallel" and "Split-K" decompositions.
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
Requires NVIDIA Ampere or newer device (SM80+).

View File

@ -0,0 +1,463 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Simple Hopper GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
This example demonstrates a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0
APIs on the NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes a new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster. Another advantage is that TMA can load FP32 data and
convert it implicitly to TF32.
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
Examples:
$ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = float; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = float; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TilesShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(5120), n(4096), k(4096),
alpha(1.f), beta(0.f),
iterations(1000)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "48_hopper_warp_specialized_gemm\n\n"
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "48_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
block_A.reset(options.m * options.k);
block_B.reset(options.k * options.n);
block_C.reset(options.m * options.n);
block_D.reset(options.m * options.n);
block_ref_D.reset(options.m * options.n);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k},
block_A.get(),
stride_A,
block_B.get(),
stride_B,
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}
};
return arguments;
}
bool verify(const Options &options) {
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k}));
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{options.m, options.n, options.k},
ElementAccumulator(options.alpha),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ref_C,
ref_D);
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older toolkits; in that case the example is a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,35 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
48_hopper_warp_specialized_gemm
48_hopper_warp_specialized_gemm.cu
)

View File

@ -0,0 +1,522 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper GEMM example leveraging collective operation builders.
This example showcases the use of CUTLASS's CollectiveBuilder to easily construct performant kernels
targeting the NVIDIA Hopper architecture.
Background and motivation
-------------------------
CUTLASS kernels are highly parameterizable via template parameters. To ease the selection of template
parameters, CUTLASS 2 leveraged DefaultGemmConfigurations. Given a small set of parameters, such as
the data types of operands and the compute capability of the GPU, DefaultGemmConfigurations defined sensible
defaults for the many other parameters to the kernel (e.g., warp shape, stage count).
However, DefaultGemmConfigurations leave multiple opportunities for improvement, which are addressed
in CUTLASS 3:
(1) DefaultGemmConfigurations do not allow one to use a more-performant set of parameters without
specifying every parameter. For example, the DefaultGemmConfigurations for GEMMs targeting
Ampere specify that three pipeline stages should be used regardless of the sizes of operands.
If one wished to increase this value, one would also need to specify all other template parameters.
This leaves a gap between a high-level ease-of-use interface and a lower-level detailed interface.
(2) A new DefaultGemmConfiguration was required for each combination of operand types, GPU architecture,
and operation type (e.g., Tensor Core or SIMT). This led to increased code size to cover each unique
configuration and a lack of extensibility from one DefaultGemmConfiguration to another.
Alongside these opportunities for improvement, the Hopper architecture offers new features that increase
the number of valid configurations of a kernel. In addition to the many template parameters already available
in CUTLASS 2 kernels, CUTLASS 3 kernels targeting Hopper also have various scheduling modes to select from that control:
(1) how data is to be loaded (e.g., using the Hopper TMA feature or Ampere cp.async)
(2) how work is to be divided among warps in a thread block (e.g., whether to use "warp specialization")
(3) whether persistent thread blocks should be used
This increased configuration space further motivates rethinking DefaultGemmConfigurations.
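In this example, these scheduling choices surface as the kernel schedule tags passed to the
ExampleRunner defined below, for instance (names taken from the code later in this file):

    cutlass::gemm::KernelTma                           // TMA loads without warp specialization
    cutlass::gemm::KernelTmaWarpSpecialized            // TMA loads with warp specialization
    cutlass::gemm::KernelTmaWarpSpecializedPersistent  // warp specialization with persistent thread blocks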
Introduction to the CollectiveBuilder
-------------------------------------
CUTLASS 3 introduces the CollectiveBuilder to further ease the process of selecting template parameters
for kernels targeting Hopper. Similar to the DefaultGemmConfigurations used in CUTLASS 2, the CollectiveBuilder
takes in a small set of template parameters (e.g., the data types of operands A and B). It then automatically
determines the data loading strategy to use depending on whether the Hopper TMA feature can be used with the provided
parameters. If one does not indicate a particular scheduling policy or stage count to use (by using `Auto` template
parameters), the CollectiveBuilder will also automatically select these.
Unlike DefaultGemmConfigurations, a partial specialization of the CollectiveBuilder is not needed for many
configurations of operand types. Instead, the CollectiveBuilder "builds" a configuration based on generic
properties of the specified operands, layouts, and other parameters. For example, when the stage count
is set to `Auto`, the CollectiveBuilder may automatically calculate the maximum number of stages that
will fit in shared memory given the types of operands and the thread block shape, rather than simply using
a single default value.
Note that one does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide
every template parameter to the gemm::collective::CollectiveMma. Specifying every template parameter in this
manner remains the primary API for using CUTLASS 3 kernels. The CollectiveBuilder is simply meant to be
a convenience interface.
Note also that, while the selections made by CollectiveBuilder attempt to maximize performance, this is not
a guarantee. Furthermore, the behavior of the CollectiveBuilder when `Auto` parameters are provided is subject
to change in future CUTLASS releases -- do not rely on `Auto` if you require a specific scheduling policy and/or
stage count to be used.
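As a concrete illustration, the mainloop declared later in this example can be built with every
choice left to the builder (a minimal sketch; only the alias name `Mainloop` is illustrative):

    using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,   // arch and operation class
        cutlass::half_t, cutlass::layout::RowMajor,    8,      // A: element, layout, alignment
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,      // B: element, layout, alignment
        float,                                                 // accumulator element
        Shape<_128,_128,_64>, Shape<_2,_1,_1>,                 // tile shape and cluster shape
        cutlass::gemm::collective::StageCountAuto,             // let the builder pick the stage count
        cutlass::gemm::collective::KernelScheduleAuto          // let the builder pick the schedule
      >::CollectiveOp;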
Details of this example
-----------------------
This example walks through the use of the CollectiveBuilder with various schedules and stage counts specified.
It also illustrates how CUTLASS 3 GEMMs targeting Hopper automatically support batched GEMMs by simply
extending the problem size with an additional tensor rank.
Example usage:
$ ./examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder \
--m=2048 --n=2048 --k=2048 --l=2
*/
#include <iostream>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help;
bool error;
int m, n, k, l;
float alpha, beta;
Options():
help(false),
error(false),
m(2048), n(2048), k(2048), l(1),
alpha(1.f), beta(0.f)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 2048);
cmd.get_cmd_line_argument("n", n, 2048);
cmd.get_cmd_line_argument("k", k, 2048);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "49_hopper_gemm_schedules_with_collective_builder\n\n"
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
<< " performant kernels targetting NVIDIA's Hopper architecture.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
// operation builders by specializing the GEMM only on the kernel schedule it will use and the
// number of pipeline stages.
//
// For either option, one can use a special `Auto` type that tells the CollectiveBuilder
// to select an appropriate value on its own. The CollectiveBuilder will attempt to select
// values that will result in the most-performant kernel, but this is not a guarantee. Furthermore,
// the behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases
// -- do not rely on `Auto` if you require a specific scheduling policy.
template <
// Type of kernel schedule to generate
class KernelScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
// Number of pipeline stages to use
class StageCountType = cutlass::gemm::collective::StageCountAuto
>
struct ExampleRunner {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;
static constexpr int kAlignmentA = 8;
static constexpr int kAlignmentB = 8;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, kAlignmentA,
cutlass::half_t, LayoutB, kAlignmentB,
float,
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
StageCountType,
KernelScheduleType
>::CollectiveOp;
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideA>());
using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B<StrideB>());
using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideC>());
using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideD>());
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed = 0;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
//
// Methods
//
bool verify(const ProblemShapeType& problem_size, float alpha, float beta) {
auto [M, N, K, L] = problem_size;
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));
cutlass::reference::device::GemmComplex(
{M, N, K},
typename Gemm::EpilogueOutputOp::ElementCompute(alpha),
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
typename Gemm::EpilogueOutputOp::ElementCompute(beta),
ref_C,
ref_D,
typename Gemm::EpilogueOutputOp::ElementAccumulator(0.f),
L, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
M * N, // batch_stride_C
M * N // batch_stride_D
);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
return passed;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const ProblemShapeType& problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
initialize(problem_size);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
block_A.get(),
stride_A,
block_B.get(),
stride_B,
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
hw_info
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
// Verify that the result is correct
bool passed = verify(problem_size, options.alpha, options.beta);
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
}
return passed;
}
};
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, bool passed) {
std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) {
std::cout
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
bool passed;
// This first example constructs a GEMM using the default schedule and stage count provided by
// the CollectiveBuilder. The scheduling policy that is expected to be most performant and the
// maximum number of stages that can fit in shared memory will be selected automatically.
//
// This example is equivalent to declaring
// ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, cutlass::gemm::collective::StageCountAuto>
// Each of the `Auto` types indicates that the CollectiveBuilder should determine the scheduling policy and
// stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change
// -- do not rely on `Auto` if you require a specific scheduling policy.
ExampleRunner<> auto_schedule_auto_stage_runner;
passed = auto_schedule_auto_stage_runner.run(options, hw_info);
print_result("Automatically-selected schedule and stage count", passed);
// One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto
// with the number of stages to use (5 in this case).
ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, _5> auto_schedule_5_stage_runner;
passed = auto_schedule_5_stage_runner.run(options, hw_info);
print_result("Automatically-selected schedule with 5 stages", passed);
// One can also override the scheduling policy to use. In this case, use the KernelTma scheduling
// policy, which specifies that the Hopper TMA feature should be used.
ExampleRunner<cutlass::gemm::KernelTma> tma_schedule_auto_stage_runner;
passed = tma_schedule_auto_stage_runner.run(options, hw_info);
print_result("TMA schedule with automatically-selected stage count", passed);
// Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized
// scheduling policy.
//
// Note that, as of the CUTLASS 3.0 release, this is the default scheduling policy
// used by the CollectiveBuilder, so this declaration is equivalent to ExampleRunner<> and
// ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto>. However, this default is subject to
// change in future releases -- do not rely on `Auto` if you require a specific scheduling policy.
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized> ws_schedule_auto_stage_runner;
passed = ws_schedule_auto_stage_runner.run(options, hw_info);
print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed);
// Finally, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized
// scheduling policy, leveraging persistent thread blocks.
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecializedPersistent> ws_persistent_schedule_auto_stage_runner;
passed = ws_persistent_schedule_auto_stage_runner.run(options, hw_info);
print_result("Persistent warp-specialized TMA schedule with automatically-selected stage count", passed);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
49_hopper_gemm_schedules_with_collective_builder
49_hopper_gemm_schedules_with_collective_builder.cu
)

View File

@@ -0,0 +1,529 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper GEMM example to create a GEMM kernel with custom Collectives
The following example shows how to assemble a custom GEMM kernel that spells out the Collectives
directly instead of using a builder and, in the process, instantiate a more efficient Epilogue
(from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue.
The GemmUniversal API takes 3 main template arguments:
(1) the problem shape / extents
(2) the collective mainloop type
(3) the collective epilogue type
While the collective mainloop can be stamped out using a CollectiveBuilder interface, it is
possible to build a custom collective mainloop directly as well. Furthermore, since epilogues
do not yet have a builder interface, this example shows how to instantiate a more-efficient
epilogue alongside the collective mainloop.
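Put together, the kernel assembled below has the following top-level structure (a condensed
sketch; the type names match those defined in main() later in this file):

    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
        Shape<int,int,int,int>,   // (1) problem shape / extents (M, N, K, batch L)
        CollectiveMainloop,       // (2) hand-assembled CollectiveMma (TMA + WGMMA, warp-specialized)
        Epilogue                  // (3) custom shared-memory-swizzled epilogue
      >;
    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;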
Note: there are several ways to implement the GEMM epilogue in Hopper - each with its own set
of trade-offs. So it is recommended that users look at the options available under
cutlass/epilogue/collective and evaluate them for their particular scenario.
Please refer to examples 48 and 49 to learn more about kernel schedules, and to the CuTe examples
present in `test/unit/cute` to familiarize yourself with the basics of CuTe.
Examples:
$ ./examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool error;
int m, n, k, l;
int alpha, beta;
Options():
help(false),
error(false),
m(2048), n(2048), k(2048), l(1),
alpha(1), beta(0)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 2048);
cmd.get_cmd_line_argument("n", n, 2048);
cmd.get_cmd_line_argument("k", k, 2048);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("alpha", alpha, 1);
cmd.get_cmd_line_argument("beta", beta, 0);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "50_hopper_gemm_with_vectorized_epilogue\n\n"
<< "Hopper GEMM Example with Epilogue Swizzle.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --alpha=<s32> Epilogue scalar alpha\n"
<< " --beta=<s32> Epilogue scalar beta\n\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
// Wrapper to run and verify a GEMM.
template <
class Gemm
>
struct ExampleRunner {
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using LayoutD = typename Gemm::LayoutD;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementAcc = typename Gemm::ElementAccumulator;
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed = 0;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementOutput> block_D;
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
//
// Methods
//
bool verify(const ProblemShapeType& problem_size, int32_t alpha, int32_t beta) {
auto [M, N, K, L] = problem_size;
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
cutlass::reference::device::GemmComplex(
{M, N, K},
ElementCompute(alpha),
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
ElementCompute(beta),
ref_C,
ref_D,
ElementAccumulator(0),
L, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
M * N, // batch_stride_C
M * N // batch_stride_D
);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
return passed;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const ProblemShapeType& problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
initialize(problem_size);
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
block_A.get(),
stride_A,
block_B.get(),
stride_B,
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
hw_info
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
// Verify that the result is correct
bool passed = verify(problem_size, options.alpha, options.beta);
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
}
return passed;
}
};
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) {
std::cout
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
bool passed;
// Problem configuration
using ElementA = int8_t;
using ElementB = int8_t;
using ElementAcc = int32_t;
using ElementOutput = int8_t;
// Note: only the TN WGMMA GEMM is currently supported in CUTLASS 3.0
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;
// Tiling configuration selection
using TileShape = Shape<_128,_64,_128>;
// Choosing a thread block cluster larger than 1 allows us to multicast data across thread blocks
using ClusterShape = Shape<_1,_2,_1>;
//
// Assembling the CollectiveMainloop type
//
// Pipeline depth to be used, i.e., the number of A and B buffers in shared memory
constexpr int PipelineStages = 8;
// Let's choose a warp-specialized mainloop implementation which uses TMA
// Note: this requires/assumes the tensors to be 16B aligned
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape,
cutlass::gemm::KernelTmaWarpSpecialized>;
// TN => K Major for both A & B
static constexpr cute::GMMA::Major GmmaMajorA = cute::GMMA::Major::K;
static constexpr cute::GMMA::Major GmmaMajorB = cute::GMMA::Major::K;
// We use the SS op selector as both A and B operands are read directly from SMEM (for TN WGMMA)
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
ElementA, ElementB, ElementAcc, TileShape, GmmaMajorA, GmmaMajorB>()));
// A loads can be optimized with multicast if cluster-n > 1
using GmemTiledCopyA = std::conditional< cute::size(shape<1>(ClusterShape{})) == 1,
cute::SM90_TMA_LOAD,
cute::SM90_TMA_LOAD_MULTICAST>::type;
// B loads can be optimized with multicast if cluster-m > 1
using GmemTiledCopyB = std::conditional< cute::size(shape<0>(ClusterShape{})) == 1,
cute::SM90_TMA_LOAD,
cute::SM90_TMA_LOAD_MULTICAST>::type;
using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector<
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{}))
>());
using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{}))
>());
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape,
ElementA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementB,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
void, // Does not need a SmemCopyAtom, since A is read directly from SMEM
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void, // Does not need a SmemCopyAtom, since B is read directly from SMEM
cute::identity
>;
//
// Assembling the Collective Epilogue Type
//
// Break the 128 along TILE_M into chunks of 32, to get a 128B leading dimension (32 x 4B accumulators)
using PreSwizzleLayout = Layout< Shape< Shape <_32,_4 >,_64>,
Stride<Stride< _1,_2048>,_32>>;
// 128 threads loading 16 elements each (to get vectorized global stores)
using TileShapeS2R = Shape<_128,_16>;
// Layout to ensure bank-conflict free loads & stores
using SmemLayout = ComposedLayout<
Swizzle<3,4,3>,
smem_ptr_flag_bits<sizeof_bits<ElementAcc>::value>,
PreSwizzleLayout>;
// Tiled copy from Smem to Registers
// Note: CuTe will vectorize this copy if the tiling and swizzling above are right
using TiledCopyS2R = TiledCopy<
Copy_Atom<DefaultCopy, ElementAcc>,
Layout< Shape<_128,_16>,
Stride<_16,_1>>,
TileShapeS2R>;
using Epilogue = cutlass::epilogue::collective::Epilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
SmemLayout,
Copy_Atom<DefaultCopy, ElementAcc>,
TiledCopyS2R,
Copy_Atom<DefaultCopy, ElementOutput>>;
//
// Assembling the GemmKernel
//
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
Epilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
ExampleRunner<Gemm> runner;
passed = runner.run(options, hw_info);
std::cout << "WGMMA GEMM with Epilogue Swizzle : " << (passed ? "Passed" : "Failed") << std::endl;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
50_hopper_gemm_with_epilogue_swizzle
50_hopper_gemm_with_epilogue_swizzle.cu
)

View File

@@ -54,12 +54,14 @@ function(cutlass_example_add_executable NAME)
CUTLASS
cutlass_tools_util_includes
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
cuda
)
target_include_directories(
${NAME}
PRIVATE
${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR}
${CUTLASS_EXAMPLES_UTILS_DIR}
)
install(
@@ -118,6 +120,7 @@ foreach(EXAMPLE
36_gather_scatter_fusion
37_gemm_layernorm_gemm_fusion
38_syr2k_grouped
cute
39_gemm_permute
41_fused_multi_head_attention
42_ampere_tensorop_group_conv
@@ -125,6 +128,9 @@ foreach(EXAMPLE
45_dual_gemm
46_depthwise_simt_conv2dfprop
47_ampere_gemm_universal_streamk
48_hopper_warp_specialized_gemm
49_hopper_gemm_schedules_with_collective_builder
50_hopper_gemm_with_epilogue_swizzle
)
add_subdirectory(${EXAMPLE})

View File

@@ -0,0 +1,30 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
add_subdirectory(tutorial)

View File

@@ -0,0 +1,34 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
sgemm_nt_1
sgemm_nt_1.cu
)

View File

@@ -0,0 +1,426 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"
template <class MShape, class NShape, class KShape,
class TA, class AStride, class ABlockLayout, class AThreadLayout,
class TB, class BStride, class BBlockLayout, class BThreadLayout,
class TC, class CStride, class CBlockLayout, class CThreadLayout,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
TC * C, CStride dC, CBlockLayout , CThreadLayout tC,
Alpha alpha, Beta beta)
{
using namespace cute;
using X = Underscore;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
//CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M
//CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N
CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K
// Shared memory buffers
__shared__ TA smemA[cosize_v<ABlockLayout>];
__shared__ TB smemB[cosize_v<BBlockLayout>];
auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K)
auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K)
// Represent the full tensors
auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K)
auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K)
auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N)
// Get the appropriate blocks for this thread block --
// potential for thread block locality
auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K)
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
// The default is a raked partition, but this can be changed with the Step<X,Y> parameter
auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k)
auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K)
auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k)
auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K)
//
// Define C accumulators and A/B partitioning
//
// TUTORIAL: Example of partitioning via projections of tC
// Partition sA (M,K) by the rows of tC
auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
// Partition sB (N,K) by the cols of tC
auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
// Partition gC (M,N) by the tile of tC
auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N)
// Allocate the accumulators -- same size as the projected data
auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N)
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print("mA\n");
print(mA.shape()); print("\n"); print(mA.stride());
print("\n\ngA\n");
print(gA.shape()); print("\n"); print(gA.stride());
print("\n\ntAgA\n");
print(tAgA.shape()); print("\n"); print(tAgA.stride());
print("\n\nsA\n");
print(sA.shape()); print("\n"); print(sA.stride());
print("\n\ntAsA\n");
print(tAsA.shape()); print("\n"); print(tAsA.stride());
print("\n\n");
}
#endif
#if 0
if(thread0()) {
print("mB\n");
print(mB.shape()); print("\n"); print(mB.stride());
print("\n\ngB\n");
print(gB.shape()); print("\n"); print(gB.stride());
print("\n\ntBgB\n");
print(tBgB.shape()); print("\n"); print(tBgB.stride());
print("\n\nsB\n");
print(sB.shape()); print("\n"); print(sB.stride());
print("\n\ntBsB\n");
print(tBsB.shape()); print("\n"); print(tBsB.stride());
print("\n\n");
}
#endif
#if 0
if(thread0()) {
print("mC\n");
print(mC.shape()); print("\n"); print(mC.stride());
print("\n\ngC\n");
print(gC.shape()); print("\n"); print(gC.stride());
print("\n\ntCsA\n");
print(tCsA.shape()); print("\n"); print(tCsA.stride());
print("\n\ntCsB\n");
print(tCsB.shape()); print("\n"); print(tCsB.stride());
print("\n\ntCgC\n");
print(tCgC.shape()); print("\n"); print(tCgC.stride());
print("\n\ntCrC\n");
print(tCrC.shape()); print("\n"); print(tCrC.stride());
print("\n\n");
}
#endif
#if 1
// TUTORIAL: Example of a very simple compute loop
// Data is read from global to shared memory via the tA|tB partitioning
// gemm(.) operates on the shared memory directly via the tC partitioning
auto k_max = size<2>(tAgA);
for (int k = 0; k < k_max; ++k)
{
// Copy gmem to smem
copy(tAgA(_,_,k), tAsA);
copy(tBgB(_,_,k), tBsB);
// In case copy uses cp.async, make sure that the cp.async
// instructions are ordered with respect to other cp.async
// instructions (fence), then wait on all the outstanding copy
// operations (wait<0>()). __syncthreads() alone does not do
// this.
//
// NOTE: cp_async_wait<0>() currently issues cp.async.wait_all.
// This is equivalent to cp.async.commit_group followed by
// cp.async.wait_group 0. This should make the first
// cp_async_fence() (which also issues cp.async.commit_group)
// redundant. The tutorial works as-is, so we'll leave the
// redundant fence in for now and study its removal later.
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
// Compute gemm on smem
gemm(tCsA, tCsB, tCrC);
__syncthreads();
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
template <typename TA, typename TB, typename TC,
typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, ldA);
auto dB = make_stride(Int<1>{}, ldB);
auto dC = make_stride(Int<1>{}, ldC);
// Define block sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
// Define the block layouts (static)
auto sA = make_layout(make_shape(bM,bK));
auto sB = make_layout(make_shape(bN,bK));
auto sC = make_layout(make_shape(bM,bN));
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
dim3 dimBlock(size(tC));
dim3 dimGrid(ceil_div(size(M), size(bM)),
ceil_div(size(N), size(bN)));
gemm_device
<<< dimGrid, dimBlock, 0, stream >>>
(M, N, K,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
#include <cstdlib>
#include <cstdio>
#include <cassert>
void test_gemm(int m, int n, int k)
{
cute::device_init(0);
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
using TA = float;
using TB = float;
using TC = float;
using TI = float;
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
TI alpha = 1.0;
TI beta = 0.0;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
//
// cuBLas
//
cublasHandle_t handle;
cublasCreate(&handle);
// Run once
d_C = h_C;
blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
m, n, k,
&alpha,
d_A.data().get(), m,
d_B.data().get(), n,
&beta,
d_C.data().get(), m);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cublas_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
m, n, k,
&alpha,
d_A.data().get(), m,
d_B.data().get(), n,
&beta,
d_C.data().get(), m);
}
double cublas_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUBLAS_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000);
#else
std::cout << "Verification by comparison with cuBLAS is disabled, "
"either because the CMake option CUTLASS_ENABLE_CUBLAS "
"was explicitly set to OFF, or because CMake could not find cuBLAS. "
"If you would like to enable verification with cuBLAS, "
"please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, "
"rerun CMake, and recompile this example.\n";
#endif // CUTLASS_ENABLE_CUBLAS
//
// CuTe
//
// Run once (and check)
d_C = h_C;
gemm(m, n, k,
alpha,
d_A.data().get(), m,
d_B.data().get(), n,
beta,
d_C.data().get(), m);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(m, n, k,
alpha,
d_A.data().get(), m,
d_B.data().get(), n,
beta,
d_C.data().get(), m);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100);
auto host_matrix_to_const_column_major_cute_tensor =
[](const auto& X, int num_rows, int num_cols, int LDX) {
const auto shape = cute::Shape<int, int>{num_rows, num_cols};
const auto strides = cute::Stride<int, int>{1, LDX};
return cute::make_tensor(X.data(), cute::make_layout(shape, strides));
};
const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m);
// B^T is k x n, so B is n x k.
const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n);
const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m);
const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m);
print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view);
#endif // CUTLASS_ENABLE_CUBLAS
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
test_gemm(m, n, k);
return 0;
}

View File

@@ -0,0 +1,79 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
namespace cute
{
//
// Accept mutable temporaries
//
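// (Note) The rvalue-reference overload below lets callers pass a temporary view
// (e.g. a sliced tensor) as the destination; it simply forwards to the lvalue
// overload further down.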
template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
Tensor<XEngine, XLayout> const& x,
Beta const& beta,
Tensor<YEngine, YLayout> && y)
{
return axpby(alpha, x, beta, y);
}
//
// AXPBY
//
template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
Tensor<XEngine, XLayout> const& x,
Beta const& beta,
Tensor<YEngine, YLayout> & y)
{
auto isBetaZero = (beta == Int<0>{});
CUTE_UNROLL
for (int i = 0; i < size(x); ++i) {
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
}
}
} // end namespace cute


@@ -0,0 +1,66 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/algorithm/fill.hpp>
namespace cute
{
//
// Accept mutable temporaries
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE
void
clear(Tensor<Engine, Layout>&& tensor)
{
return clear(tensor);
}
//
// Set elements to zero
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE
void
clear(Tensor<Engine, Layout>& tensor)
{
using T = typename Tensor<Engine,Layout>::value_type;
fill(tensor, T{});
}
} // end namespace cute


@@ -0,0 +1,262 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/tensor_predicate.hpp>
#include <cute/atom/copy_atom.hpp>
namespace cute
{
//
// Accept mutable temporaries
//
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_if(pred, src, dst);
}
template <class... CopyArgs,
class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_if(copy_atom, pred, src, dst);
}
template <class VecType,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_vec<VecType>(src, dst);
}
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy(src, dst);
}
template <class... CopyArgs,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy(copy_atom, src, dst);
}
//
// copy_if -- Predicated Copy
//
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
auto copy_op = select_elementwise_copy(src, dst);
CUTE_UNROLL
for (int i = 0; i < size(src); ++i) {
if (pred(i)) {
copy_op.copy(src(i), dst(i));
}
}
}
//
// copy_if -- Predicated CopyAtom
//
template <class... CopyArgs,
class PredTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
PredTensor const& pred, // (Rest...)
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
copy_atom.call(src, dst);
} else { // Loop over all but the first mode
constexpr int R = SrcLayout::rank;
auto src_v = group_modes<1,R>(src);
auto dst_v = group_modes<1,R>(dst);
CUTE_UNROLL
for (int i = 0; i < size<1>(src_v); ++i) {
if (pred(i)) {
copy_atom.call(src_v(_,i), dst_v(_,i));
}
}
}
}
//
// copy_vec -- attempt vectorized copy with VecType
//
template <class VecType,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
{
/* @pre is_aligned<N>(src.data()) &&
* is_aligned<N>(dst.data())
*/
auto src_v = recast<VecType const>(src);
auto dst_v = recast<VecType >(dst);
#if 0
if (thread0()) {
print("copy_vec -- vectorizing copy from %3db to %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(VecType)));
print(" "); print(layout(src)); print(" => "); print(layout(src_v)); print("\n");
print(" "); print(layout(dst)); print(" => "); print(layout(dst_v)); print("\n");
}
#endif
return copy_if(TrivialPredTensor{}, src_v, dst_v);
} else {
#if 0
if (thread0()) {
print("copy_vec -- not vectorizing, copy with %3db and %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(DstType)));
print(" "); print(layout(src)); print("\n");
print(" "); print(layout(dst)); print("\n");
}
#endif
return copy_if(TrivialPredTensor{}, src, dst);
}
}
//
// copy -- auto-vectorizing copy
//
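// (Note) max_common_vector finds, at compile time, the largest contiguous vector
// common to src and dst; when it spans more than one element, the copy is recast
// into chunks of at most 128 bits and forwarded to copy_vec.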
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
constexpr int N = decltype(max_common_vector(src, dst))::value;
#if 0
if (thread0()) {
print("copy -- found a max_common_vector of %d\n", N);
print(" "); print(src.data()); print(" o "); print(layout(src)); print("\n");
print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n");
}
#endif
if constexpr (N <= 1) {
return copy_if(TrivialPredTensor{}, src, dst);
} else {
constexpr int vec_bits = N * sizeof_bits<typename SrcEngine::value_type>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
return copy_vec<VecType>(src, dst);
}
}
//
// copy -- CopyAtom
//
template <class... CopyArgs,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy_if(copy_atom, TrivialPredTensor{}, src, dst);
}
template <class... CopyArgs,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<DefaultCopy, CopyArgs...> const&,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy(src, dst);
}
} // end namespace cute


@@ -0,0 +1,87 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/algorithm/prefer.hpp>
namespace cute
{
//
// Accept mutable temporaries
//
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
void
fill(Tensor<Engine, Layout>&& tensor, T const& value)
{
return fill(tensor, value);
}
namespace detail
{
// Prefer fill(tensor.data(), value), if possible
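// (Note) This overload participates only when fill(tensor.data(), value) is
// well-formed (SFINAE via the trailing return type); prefer<1> gives it priority
// over the generic element-wise loop below.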
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
auto
fill(Tensor<Engine, Layout>& tensor, T const& value, prefer<1>)
-> decltype(fill(tensor.data(), value))
{
fill(tensor.data(), value);
}
// Default implementation
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
void
fill(Tensor<Engine, Layout>& tensor, T const& value, prefer<0>)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = value;
}
}
} // end namespace detail
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
void
fill(Tensor<Engine, Layout>& tensor, T const& value)
{
return detail::fill(tensor, value, prefer<1>{});
}
} // end namespace cute


@@ -0,0 +1,198 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <utility>
#include <cute/config.hpp>
/** C++14 <functional> extensions */
namespace cute {
/**************/
/** Identity **/
/**************/
struct identity {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg);
}
};
template <class R>
struct constant_fn {
template <class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&&...) const {
return r_;
}
R r_;
};
/***********/
/** Unary **/
/***********/
#define CUTE_LEFT_UNARY_OP(NAME,OP) \
struct NAME { \
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return OP std::forward<T>(arg); \
} \
}
#define CUTE_RIGHT_UNARY_OP(NAME,OP) \
struct NAME { \
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return std::forward<T>(arg) OP ; \
} \
}
#define CUTE_NAMED_UNARY_OP(NAME,OP) \
struct NAME { \
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return OP (std::forward<T>(arg)); \
} \
}
CUTE_LEFT_UNARY_OP(unary_plus, +);
CUTE_LEFT_UNARY_OP(negate, -);
CUTE_LEFT_UNARY_OP(bit_not, ~);
CUTE_LEFT_UNARY_OP(logical_not, !);
CUTE_LEFT_UNARY_OP(dereference, *);
CUTE_LEFT_UNARY_OP(address_of, &);
CUTE_LEFT_UNARY_OP(pre_increment, ++);
CUTE_LEFT_UNARY_OP(pre_decrement, --);
CUTE_RIGHT_UNARY_OP(post_increment, ++);
CUTE_RIGHT_UNARY_OP(post_decrement, --);
CUTE_NAMED_UNARY_OP(abs_fn, abs);
CUTE_NAMED_UNARY_OP(conjugate, cute::conj);
#undef CUTE_LEFT_UNARY_OP
#undef CUTE_RIGHT_UNARY_OP
#undef CUTE_NAMED_UNARY_OP
/************/
/** Binary **/
/************/
#define CUTE_BINARY_OP(NAME,OP) \
struct NAME { \
template <class T, class U> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
return std::forward<T>(lhs) OP std::forward<U>(rhs); \
} \
}
#define CUTE_NAMED_BINARY_OP(NAME,OP) \
struct NAME { \
template <class T, class U> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
return OP (std::forward<T>(lhs), std::forward<U>(rhs)); \
} \
}
CUTE_BINARY_OP(plus, +);
CUTE_BINARY_OP(minus, -);
CUTE_BINARY_OP(multiplies, *);
CUTE_BINARY_OP(divides, /);
CUTE_BINARY_OP(modulus, %);
CUTE_BINARY_OP(plus_assign, +=);
CUTE_BINARY_OP(minus_assign, -=);
CUTE_BINARY_OP(multiplies_assign, *=);
CUTE_BINARY_OP(divides_assign, /=);
CUTE_BINARY_OP(modulus_assign, %=);
CUTE_BINARY_OP(bit_and, &);
CUTE_BINARY_OP(bit_or, |);
CUTE_BINARY_OP(bit_xor, ^);
CUTE_BINARY_OP(left_shift, <<);
CUTE_BINARY_OP(right_shift, >>);
CUTE_BINARY_OP(bit_and_assign, &=);
CUTE_BINARY_OP(bit_or_assign, |=);
CUTE_BINARY_OP(bit_xor_assign, ^=);
CUTE_BINARY_OP(left_shift_assign, <<=);
CUTE_BINARY_OP(right_shift_assign, >>=);
CUTE_BINARY_OP(logical_and, &&);
CUTE_BINARY_OP(logical_or, ||);
CUTE_BINARY_OP(equal_to, ==);
CUTE_BINARY_OP(not_equal_to, !=);
CUTE_BINARY_OP(greater, >);
CUTE_BINARY_OP(less, <);
CUTE_BINARY_OP(greater_equal, >=);
CUTE_BINARY_OP(less_equal, <=);
CUTE_NAMED_BINARY_OP(max_fn, cute::max);
CUTE_NAMED_BINARY_OP(min_fn, cute::min);
#undef CUTE_BINARY_OP
#undef CUTE_NAMED_BINARY_OP
/**********/
/** Meta **/
/**********/
template <class Fn, class Arg>
struct bound_fn {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(T&& arg) {
return fn_(arg_, std::forward<T>(arg));
}
Fn fn_;
Arg arg_;
};
template <class Fn, class Arg>
CUTE_HOST_DEVICE constexpr
auto
bind(Fn const& fn, Arg const& arg) {
return bound_fn<Fn,Arg>{fn, arg};
}
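// Illustrative sketch (not part of this header): bind(plus{}, 1) yields a unary
// functor that adds 1 to its argument, e.g. bind(plus{}, 1)(41) yields 42.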
} // end namespace cute


@@ -0,0 +1,718 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/util/type_traits.hpp>
/** The gemm algorithm takes four (or three) tensors and computes
 *   D = A * B + C    (for the three-argument form, D is C, i.e. C = A * B + C)
* It dispatches based on the number of modes each tensor has:
*
* 1. `(V) x (V) => (V)`.
* The element-wise product of vectors. Dispatches to FMA or MMA.
* 2. `(M) x (N) => (M,N)`.
* The outer product of vectors. Dispatches to [3] with new mode K=(1).
* 3. `(M,K) x (N,K) => (M,N)`.
* The product of matrices. Dispatches to [5] with MMA vector-mode V.
* 4. `(V,M) x (V,N) => (V,M,N)`.
* The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n).
* 5. `(V,M,K) x (V,N,K) => (V,M,N)`.
* The batched product of matrices. Dispatches to [4] for each (k).
*/
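/* Illustrative sketch (not part of this header): with register fragments
 * partitioned by an MMA atom, a single k-block is accumulated via
 *   gemm(mma, tCrA(_,_,k), tCrB(_,_,k), tCrC);
 * which slices (V,M,K) x (V,N,K) fragments down to (V,M) x (V,N) and lands in
 * dispatch [4]. The fragment names mirror those used later in this header and
 * are illustrative only.
 */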
namespace cute
{
//
// Three arguments to four
//
template <class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> & C)
{
return gemm(C, A, B, C);
}
template <class MMA,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> & C)
{
return gemm(mma, C, A, B, C);
}
//
// Accept mutable temporaries
//
template <class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> && C)
{
return gemm(C, A, B, C);
}
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TD, DLayout> && D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
return gemm(D, A, B, C);
}
template <class MMA,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> && C)
{
return gemm(mma, C, A, B, C);
}
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> && D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
return gemm(mma, D, A, B, C);
}
//
// Default MMA is UniversalFMA
//
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
using MMA = MMA_Atom<UniversalFMA<typename Tensor<TD,DLayout>::value_type,
typename Tensor<TA,ALayout>::value_type,
typename Tensor<TB,BLayout>::value_type,
typename Tensor<TC,CLayout>::value_type>>;
return gemm(MMA{}, D, A, B, C);
}
//
// Thread-Local Register-Memory GEMMs
//
// Dispatch [1]: (V) x (V) => (V)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 1 && is_rmem<TD>::value &&
ALayout::rank == 1 && is_rmem<TA>::value &&
BLayout::rank == 1 && is_rmem<TB>::value &&
CLayout::rank == 1 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (V) Logical data
Tensor<TA, ALayout> const& A, // (V) Logical data
Tensor<TB, BLayout> const& B, // (V) Logical data
Tensor<TC, CLayout> const& C) // (V) Logical data
{
// No static assertions on (V), MMA checks compatibility
mma.call(D, A, B, C);
}
// Dispatch [2]: (M) x (N) => (M,N)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
ALayout::rank == 1 && is_rmem<TA>::value &&
BLayout::rank == 1 && is_rmem<TB>::value &&
CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (M,N) Logical data
Tensor<TA, ALayout> const& A, // (M) Logical data
Tensor<TB, BLayout> const& B, // (N) Logical data
Tensor<TC, CLayout> const& C) // (M,N) Logical data
{
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
gemm(mma,
D, // (M,N)
make_tensor(A.data(), append<2>(A.layout())), // (M,1)
make_tensor(B.data(), append<2>(B.layout())), // (N,1)
C); // (M,N)
}
// Dispatch [3]: (M,K) x (N,K) => (M,N)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
ALayout::rank == 2 && is_rmem<TA>::value &&
BLayout::rank == 2 && is_rmem<TB>::value &&
CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (M,N) Logical data
Tensor<TA, ALayout> const& A, // (M,K) Logical data
Tensor<TB, BLayout> const& B, // (N,K) Logical data
Tensor<TC, CLayout> const& C) // (M,N) Logical data
{
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
// Assert this is a 1-value MMA
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K)
make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N)
}
// Dispatch [4]: (V,M) x (V,N) => (V,M,N)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
ALayout::rank == 2 && is_rmem<TA>::value &&
BLayout::rank == 2 && is_rmem<TB>::value &&
CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
Tensor<TA, ALayout> const& A, // (V,M) Logical data
Tensor<TB, BLayout> const& B, // (V,N) Logical data
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
{
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
// REGISTER .reuse OPTIMIZATIONS
auto M = size<1>(A);
auto N = size<1>(B);
// 64-bit traversal specialization -- serpentine path
if (size<0>(A) * sizeof(typename Tensor<TA,ALayout>::value_type) == 8 &&
size<0>(B) * sizeof(typename Tensor<TB,BLayout>::value_type) == 8)
{
#if 1 // NOTE: Must depend on the C-matrix order... (which we can test)
// Row-major iteration
CUTE_UNROLL
for (int m = 0; m < M; ++m) {
CUTE_UNROLL
for (int n = 0; n < N; ++n) {
int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate
gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns));
}
}
#else
// Col-major iteration
CUTE_UNROLL
for (int n = 0; n < N; ++n) {
CUTE_UNROLL
for (int m = 0; m < M; ++m) {
int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate
gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
}
}
#endif
} else
// 32-bit traversal specialization -- kinked serpentine path
if (size<0>(A) * sizeof(typename Tensor<TA,ALayout>::value_type) == 4 &&
size<0>(B) * sizeof(typename Tensor<TB,BLayout>::value_type) == 4)
{
#if 1 // NOTE: Must depend on the C-matrix order... (which we can test)
// Row-major iteration
CUTE_UNROLL
for (int m = 0; m < M; m += 2) {
CUTE_UNROLL
for (int n = 0; n < N; ++n) {
int ns = (m & 2) ? N-1-n : n;
gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns));
if (m+1 < M) {
gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns));
}
}
}
#else
// Col-major iteration
CUTE_UNROLL
for (int n = 0; n < N; n += 2) {
CUTE_UNROLL
for (int m = 0; m < M; ++m) {
// Kinked serpentine traversal for maximum register reuse
int ms = (n & 2) ? M-1-m : m;
gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0));
if (n+1 < N) {
gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1));
}
}
}
#endif
} else {
// Fallback to serpentine loop
// Col-major iteration
CUTE_UNROLL
for (int n = 0; n < N; ++n) {
CUTE_UNROLL
for (int m = 0; m < M; ++m) {
int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate
gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
}
}
}
}
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
ALayout::rank == 3 && is_rmem<TA>::value &&
BLayout::rank == 3 && is_rmem<TB>::value &&
CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
Tensor<TA, ALayout> const& A, // (V,M,K) Logical data
Tensor<TB, BLayout> const& B, // (V,N,K) Logical data
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
{
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
auto K = size<2>(A);
CUTE_UNROLL
for (int k = 0; k < K; ++k) {
gemm(mma, D, A(_,_,k), B(_,_,k), C);
}
}
//
// Thread-Local Shared-Memory GEMMs
//
// Dispatch [1]: (V) x (V) => (V)
// Dispatch [2]: (M) x (N) => (M,N)
// Dispatch [3]: (M,K) x (N,K) => (M,N)
// Dispatch [4]: (V,M) x (V,N) => (V,M,N)
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
// Dispatch [3]: (M,K) x (N,K) => (M,N)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (M,N) Logical data
Tensor<TA, ALayout> const& A, // (M,K) Logical data
Tensor<TB, BLayout> const& B, // (N,K) Logical data
Tensor<TC, CLayout> const& C) // (M,N) Logical data
{
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
// Assert this is a 1-value MMA
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K)
make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N)
}
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
template <class MMA,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout,
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
ALayout::rank == 3 && is_smem<TA>::value &&
BLayout::rank == 3 && is_smem<TB>::value &&
CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
Tensor<TA, ALayout> const& A, // (V,M,K) Logical data
Tensor<TB, BLayout> const& B, // (V,N,K) Logical data
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
{
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
auto rA = MMA_Atom<MMA>::make_fragment_A(A);
auto rB = MMA_Atom<MMA>::make_fragment_B(B);
auto K = size<2>(A);
CUTE_UNROLL
for (int k = 0; k < K; ++k)
{
copy(A(_,_,k), rA(_,_,k));
copy(B(_,_,k), rB(_,_,k));
// Thread-level register gemm for k
gemm(mma, D, rA(_,_,k), rB(_,_,k), C);
}
}
//
// Collective Shared-Memory GEMMs
//
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op /* transforms A values before they are used in the GEMM */,
BLoadTransformOp const& sB_load_op /* transforms B values before they are used in the GEMM */)
{
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
static_assert(std::is_same_v<std::decay_t<std::invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
"ALoadTransformOp functor must accept and return value of type TA::value_type");
static_assert(std::is_same_v<std::decay_t<std::invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
"BLoadTransformOp functor must accept and return value of type TB::value_type");
// Original, static size of the problem
auto M = size<0>(sC);
auto N = size<1>(sC);
auto K = size<1>(sA);
// Block size of the compute tile
auto BLK_M = tile_size<0>(thr_mma);
auto BLK_N = tile_size<1>(thr_mma);
auto BLK_K = tile_size<2>(thr_mma);
// Compute the "residues"
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
// Shift the origin so k_residue is zeroth tile
sA.data() = &sA(0,k_residue);
sB.data() = &sB(0,k_residue);
#if 0
if (thread0()) {
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
}
#endif
//
// MMA Partitioning
//
// Round the layout extents up to BLK_X
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
#if 0
if (thread0()) {
print(rounded_sA.layout()); print("\n");
print(rounded_sB.layout()); print("\n");
print(rounded_sC.layout()); print("\n");
}
#endif
// Partition the sA and sB tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
#if 0
if (thread0()) {
print(tCsA.layout()); print("\n");
print(tCsB.layout()); print("\n");
print(tCsC.layout()); print("\n");
print(tCrA.layout()); print("\n");
print(tCrB.layout()); print("\n");
print(tCrC.layout()); print("\n");
}
#endif
//
// PREDICATION
//
// Allocate the preds for only the MMA-mode of tCsA and tCsB
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
// Create coordinate tensors on a single compute block for predication
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat partitioning with thr_mma
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
// Populate the m and n predicates
CUTE_UNROLL
for (int i = 0; i < size(tCpA); ++i) {
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
}
CUTE_UNROLL
for (int i = 0; i < size(tCpB); ++i) {
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
}
#if 0
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
threadIdx.x,
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
#endif
//
// PREFETCH k_block = 0 (with k-predication)
//
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
}
}
}
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
}
}
}
//
// MAINLOOP
//
// Clear accumulators
clear(tCrC);
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
// static-if load the next k_block. No k-predication required on these loads.
if (k_block < K_BLOCK_MAX-1)
{
// Load the next k_block
int k_next = k_block + 1;
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
}
}
}
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
const bool isBetaZero = (beta == Beta{});
// Custom axpby_if for now
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsC); ++m)
{
CUTE_UNROLL
for (int n = 0; n < size<2>(tCsC); ++n)
{
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsC); ++i)
{
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
{
tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n);
}
}
}
}
}
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC)
{
gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
}
} // end namespace cute


@@ -0,0 +1,46 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstddef>   // std::size_t used by prefer<N>
namespace cute
{
// A recursive family of tag types: each prefer<N> inherits from prefer<N-1>
template <std::size_t N>
struct prefer : prefer<N-1> {};
template <>
struct prefer<0> {};
// Can be used to preferentially overload implementations:
//   overloads taking prefer<N> with higher N have higher priority.
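// Illustrative sketch (hypothetical names, not part of this header):
//   template <class T> auto impl(T& t, prefer<1>) -> decltype(t.fast_path()) { return t.fast_path(); }
//   template <class T> void impl(T& t, prefer<0>) { /* generic fallback */ }
// Calling impl(t, prefer<1>{}) selects the fast path whenever t.fast_path() is
// well-formed, and otherwise falls back to the generic overload (cf. cute::fill).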
} // end namespace cute


@@ -0,0 +1,102 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/** Common algorithms on (hierarchical) tensors */
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
namespace cute
{
//
// for_each
//
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout> const& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
static_cast<UnaryOp&&>(op)(tensor(i));
}
}
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
static_cast<UnaryOp&&>(op)(tensor(i));
}
}
// Accept mutable temporaries
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
return for_each(tensor, static_cast<UnaryOp&&>(op));
}
//
// transform
//
// Similar to std::transform, but applies op in place and does not return an output iterator
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = static_cast<UnaryOp&&>(op)(tensor(i));
}
}
// Accept mutable temporaries
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
return transform(tensor, std::forward<UnaryOp>(op));
}
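// Illustrative sketch (t is a hypothetical tensor):
//   for_each(t, [](auto const& v) { print(v); });     // visit every element
//   transform(t, [](auto const& v) { return -v; });   // negate in place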
} // end namespace cute


@@ -0,0 +1,846 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/numeric/integer_sequence.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/util/type_traits.hpp>
/** Common algorithms on (hierarchical) tuples */
/** Style choice:
* Forward params [using static_cast<T&&>(.)] for const/non-const/ref/non-ref args
* but don't bother forwarding functions as ref-qualified member fns are extremely rare
*/
namespace cute
{
//
// Apply (Unpack)
// (t, f) => f(t_0,t_1,...,t_n)
//
namespace detail {
template <class T, class F, int... I>
CUTE_HOST_DEVICE constexpr
auto
apply(T&& t, F&& f, seq<I...>)
{
return f(get<I>(static_cast<T&&>(t))...);
}
} // end namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
apply(T&& t, F&& f)
{
return detail::apply(static_cast<T&&>(t), f, tuple_seq<T>{});
}
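// e.g. (illustrative): apply(make_tuple(1,2,3), [](auto... a) { return (a + ... + 0); }) yields 6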
//
// Transform Apply
// (t, f, g) => g(f(t_0),f(t_1),...)
//
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr
auto
tapply(T&& t, F&& f, G&& g, seq<I...>)
{
return g(f(get<I>(static_cast<T&&>(t)))...);
}
template <class T0, class T1, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr
auto
tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq<I...>)
{
return g(f(get<I>(static_cast<T0&&>(t0)),
get<I>(static_cast<T1&&>(t1)))...);
}
template <class T0, class T1, class T2, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr
auto
tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq<I...>)
{
return g(f(get<I>(static_cast<T0&&>(t0)),
get<I>(static_cast<T1&&>(t1)),
get<I>(static_cast<T2&&>(t2)))...);
}
} // end namespace detail
template <class T, class F, class G>
CUTE_HOST_DEVICE constexpr
auto
transform_apply(T&& t, F&& f, G&& g)
{
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
}
template <class T0, class T1, class F, class G>
CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
{
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
}
template <class T0, class T1, class T2, class F, class G>
CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
{
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
}
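// e.g. (illustrative):
//   transform_apply(make_tuple(1,2,3),
//                   [](auto a) { return a * a; },
//                   [](auto... a) { return (a + ... + 0); })   yields 14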
//
// For Each
// (t, f) => f(t_0),f(t_1),...,f(t_n)
//
template <class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_each(T&& t, F&& f)
{
detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
}
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
for_each_leaf(T&& t, F&& f)
{
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
return detail::apply(static_cast<T&&>(t), [&](auto&&... a){ return (for_each_leaf(static_cast<decltype(a)&&>(a), f), ...); }, tuple_seq<T>{});
} else {
return f(static_cast<T&&>(t));
}
CUTE_GCC_UNREACHABLE;
}
//
// Transform
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
//
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
transform(T const& t, F&& f)
{
return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
}
template <class T0, class T1, class F>
CUTE_HOST_DEVICE constexpr
auto
transform(T0 const& t0, T1 const& t1, F&& f)
{
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
}
template <class T0, class T1, class T2, class F>
CUTE_HOST_DEVICE constexpr
auto
transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
{
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
}
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
transform_leaf(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return transform(t, [&](auto const& a) { return transform_leaf(a, f); });
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
//
// find and find_if
//
namespace detail {
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<>)
{
return cute::integral_constant<int, tuple_size<T>::value>{};
}
template <class T, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<I,Is...>)
{
if constexpr (decltype(f(get<I>(t)))::value) {
return cute::integral_constant<int, I>{};
} else {
return find_if(t, f, seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return detail::find_if(t, f, tuple_seq<T>{});
} else {
return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
find(T const& t, X const& x)
{
return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false
}
template <class T, class F>
auto
none_of(T const& t, F&& f)
{
return cute::integral_constant<bool, decltype(find_if(t, f))::value == std::tuple_size<T>::value>{};
}
template <class T, class F>
auto
all_of(T const& t, F&& f)
{
auto not_f = [&](auto const& a) { return !f(a); };
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == std::tuple_size<T>::value>{};
}
template <class T, class F>
auto
any_of(T const& t, F&& f)
{
return cute::integral_constant<bool, !decltype(none_of(t, f))::value>{};
}
//
// Filter
// (t, f) => <f(t_0),f(t_1),...,f(t_n)>
//
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
filter_tuple(T const& t, F&& f)
{
return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}
template <class T0, class T1, class F>
CUTE_HOST_DEVICE constexpr
auto
filter_tuple(T0 const& t0, T1 const& t1, F&& f)
{
return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}
//
// Fold (Reduce, Accumulate)
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
//
namespace detail {
// This impl compiles much faster than cute::apply and variadic args
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold(T&& t, V&& v, F&& f, seq<>)
{
return static_cast<V&&>(v);
}
template <class T, class V, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold(T&& t, V&& v, F&& f, seq<I,Is...>)
{
if constexpr (sizeof...(Is) == 0) {
return f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t)));
} else {
return fold(static_cast<T&&>(t),
f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t))),
f,
seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
auto
fold(T&& t, V&& v, F&& f)
{
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
return detail::fold(static_cast<T&&>(t),
static_cast<V&&>(v),
f,
tuple_seq<T>{});
} else {
return f(static_cast<V&&>(v), static_cast<T&&>(t));
}
CUTE_GCC_UNREACHABLE;
}
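// e.g. (illustrative): fold(make_tuple(1,2,3), 0, plus{}) yields 6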
template <class T, class F>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold_first(T&& t, F&& f)
{
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
return detail::fold(static_cast<T&&>(t),
get<0>(static_cast<T&&>(t)),
f,
make_range<1,std::tuple_size<std::remove_reference_t<T>>::value>{});
} else {
return static_cast<T&&>(t);
}
CUTE_GCC_UNREACHABLE;
}
//
// front, back, take, unwrap
//
// Get the first non-tuple element in a hierarchical tuple
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
front(T&& t)
{
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
return front(get<0>(static_cast<T&&>(t)));
} else {
return static_cast<T&&>(t);
}
CUTE_GCC_UNREACHABLE;
}
// Get the last non-tuple element in a hierarchical tuple
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
back(T&& t)
{
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
constexpr int N = tuple_size<remove_cvref_t<T>>::value;
return back(get<N-1>(static_cast<T&&>(t)));
} else {
return static_cast<T&&>(t);
}
CUTE_GCC_UNREACHABLE;
}
// Takes the elements in the range [B,E)
template <int B, int E, class T>
CUTE_HOST_DEVICE constexpr
auto
take(T const& t)
{
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
}
// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple
template <class T>
CUTE_HOST_DEVICE constexpr
auto
unwrap(T const& t)
{
if constexpr (is_tuple<T>::value) {
if constexpr (tuple_size<T>::value == 1) {
return unwrap(get<0>(t));
} else {
return t;
}
} else {
return t;
}
CUTE_GCC_UNREACHABLE;
}
//
// Flatten a hierarchical tuple to a tuple of depth one.
//
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten_to_tuple(T const& t)
{
if constexpr (is_tuple<T>::value) {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
} else {
return cute::make_tuple(t);
}
CUTE_GCC_UNREACHABLE;
}
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten(T const& t)
{
if constexpr (is_tuple<T>::value) {
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
} else {
return t;
}
CUTE_GCC_UNREACHABLE;
}
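// Illustrative usage sketch: flatten recurses through every level of nesting, so only
// leaf (non-tuple) elements remain.
//
//   flatten(make_tuple(make_tuple(1, 2), 3, make_tuple(make_tuple(4), 5)))
//     // => (1,2,3,4,5)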
//
// insert and remove and replace
//
namespace detail {
// Shortcut around tuple_cat for common insert/remove/repeat cases
template <class T, class X, int... I, int... J, int... K>
CUTE_HOST_DEVICE constexpr
auto
construct(T const& t, X const& x, seq<I...>, seq<J...>, seq<K...>)
{
return cute::make_tuple(get<I>(t)..., (void(J),x)..., get<K>(t)...);
}
} // end namespace detail
// Insert x into the Nth position of the tuple
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
insert(T const& t, X const& x)
{
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N,tuple_size<T>::value>{});
}
// Remove the Nth element of the tuple
template <int N, class T>
CUTE_HOST_DEVICE constexpr
auto
remove(T const& t)
{
return detail::construct(t, 0, make_seq<N>{}, seq<>{}, make_range<N+1,tuple_size<T>::value>{});
}
// Replace the Nth element of the tuple with x
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
replace(T const& t, X const& x)
{
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
}
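// Illustrative usage sketch, with a..d and x standing in for arbitrary tuple elements:
//
//   insert<1>((a,c,d), b)    // => (a,b,c,d)
//   remove<2>((a,b,x,c))     // => (a,b,c)
//   replace<0>((x,b,c), a)   // => (a,b,c)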
// Replace the first element of the tuple with x
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
replace_front(T const& t, X const& x)
{
if constexpr (is_tuple<T>::value) {
return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size<T>::value>{});
} else {
return x;
}
CUTE_GCC_UNREACHABLE;
}
// Replace the last element of the tuple with x
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
replace_back(T const& t, X const& x)
{
if constexpr (is_tuple<T>::value) {
return detail::construct(t, x, make_seq<tuple_size<T>::value-1>{}, seq<0>{}, seq<>{});
} else {
return x;
}
CUTE_GCC_UNREACHABLE;
}
//
// Make a tuple of Xs of tuple_size N
//
template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
repeat(X const& x)
{
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
}
//
// Make a tuple of Xs with the same profile as tuple t
//
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
repeat_like(T const& t, X const& x)
{
if constexpr (is_tuple<T>::value) {
return transform(t, [&](auto const& a) { return repeat_like(a,x); });
} else {
return x;
}
CUTE_GCC_UNREACHABLE;
}
// Group the elements [B,E) of a T into a single element
// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{})
// => T<_1,_2,T<_3,_4>,_5,_6>{}
template <int B, int E, class T>
CUTE_HOST_DEVICE constexpr
auto
group(T const& t)
{
return detail::construct(t, take<B,E>(t), make_seq<B>{}, seq<0>{}, make_range<E,tuple_size<T>::value>{});
}
//
// Extend a T to rank N by appending/prepending an element
//
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
append(T const& a, X const& x)
{
if constexpr (is_tuple<T>::value) {
if constexpr (N == tuple_size<T>::value) {
return a;
} else {
static_assert(N > tuple_size<T>::value);
return detail::construct(a, x, make_seq<tuple_size<T>::value>{}, make_seq<N-tuple_size<T>::value>{}, seq<>{});
}
} else {
if constexpr (N == 1) {
return a;
} else {
return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq<N-1>{}, seq<>{});
}
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
append(T const& a, X const& x)
{
if constexpr (is_tuple<T>::value) {
return detail::construct(a, x, make_seq<tuple_size<T>::value>{}, seq<0>{}, seq<>{});
} else {
return cute::make_tuple(a, x);
}
CUTE_GCC_UNREACHABLE;
}
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
prepend(T const& a, X const& x)
{
if constexpr (is_tuple<T>::value) {
if constexpr (N == tuple_size<T>::value) {
return a;
} else {
static_assert(N > tuple_size<T>::value);
return detail::construct(a, x, seq<>{}, make_seq<N-tuple_size<T>::value>{}, make_seq<tuple_size<T>::value>{});
}
} else {
if constexpr (N == 1) {
return a;
} else {
static_assert(N > 1);
return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq<N-1>{}, seq<0>{});
}
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
prepend(T const& a, X const& x)
{
if constexpr (is_tuple<T>::value) {
return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq<tuple_size<T>::value>{});
} else {
return cute::make_tuple(x, a);
}
CUTE_GCC_UNREACHABLE;
}
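// Illustrative usage sketch (a, b, x are placeholder values): the rank-N overloads pad with
// copies of x, while the binary overloads grow the rank by exactly one.
//
//   append(make_tuple(a, b), x)      // => (a,b,x)
//   append<4>(make_tuple(a, b), x)   // => (a,b,x,x)
//   prepend<3>(a, x)                 // => (x,x,a)   (non-tuple a is wrapped first)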
//
// Inclusive scan (prefix sum)
//
namespace detail {
template <class T, class V, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
iscan(T const& t, V const& v, F&& f, seq<I,Is...>)
{
// Apply the function to v and the element at I
auto v_next = f(v, get<I>(t));
// Replace I with v_next
auto t_next = replace<I>(t, v_next);
#if 0
std::cout << "ISCAN i" << I << std::endl;
std::cout << " t " << t << std::endl;
std::cout << " i " << v << std::endl;
std::cout << " f(i,t) " << v_next << std::endl;
std::cout << " t_n " << t_next << std::endl;
#endif
if constexpr (sizeof...(Is) == 0) {
return t_next;
} else {
return iscan(t_next, v_next, f, seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
auto
iscan(T const& t, V const& v, F&& f)
{
return detail::iscan(t, v, f, tuple_seq<T>{});
}
//
// Exclusive scan (prefix sum)
//
namespace detail {
template <class T, class V, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
escan(T const& t, V const& v, F&& f, seq<I,Is...>)
{
if constexpr (sizeof...(Is) == 0) {
// Replace I with v
return replace<I>(t, v);
} else {
// Apply the function to v and the element at I
auto v_next = f(v, get<I>(t));
// Replace I with v
auto t_next = replace<I>(t, v);
#if 0
std::cout << "ESCAN i" << I << std::endl;
std::cout << " t " << t << std::endl;
std::cout << " i " << v << std::endl;
std::cout << " f(i,t) " << v_next << std::endl;
std::cout << " t_n " << t_next << std::endl;
#endif
// Recurse
return escan(t_next, v_next, f, seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
auto
escan(T const& t, V const& v, F&& f)
{
return detail::escan(t, v, f, tuple_seq<T>{});
}
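// Illustrative usage sketch, with f adding its two arguments:
//
//   iscan(make_tuple(1, 2, 3, 4), 0, f)   // => (1,3,6,10)  inclusive prefix sum
//   escan(make_tuple(1, 2, 3, 4), 0, f)   // => (0,1,3,6)   exclusive prefix sum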
//
// Zip (Transpose)
//
// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input
// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output
namespace detail {
template <int J, class T, int... Is>
CUTE_HOST_DEVICE constexpr
auto
zip_(T const& t, seq<Is...>)
{
return cute::make_tuple(get<J>(get<Is>(t))...);
}
template <class T, int... Is, int... Js>
CUTE_HOST_DEVICE constexpr
auto
zip(T const& t, seq<Is...>, seq<Js...>)
{
static_assert(conjunction<bool_constant<tuple_size<tuple_element_t<0,T>>::value == tuple_size<tuple_element_t<Is,T>>::value>...>::value, "Mismatched Ranks");
return cute::make_tuple(detail::zip_<Js>(t, seq<Is...>{})...);
}
} // end namespace detail
template <class T>
CUTE_HOST_DEVICE constexpr
auto
zip(T const& t)
{
if constexpr (is_tuple<T>::value) {
if constexpr (is_tuple<tuple_element_t<0,T>>::value) {
return detail::zip(t, tuple_seq<T>{}, tuple_seq<tuple_element_t<0,T>>{});
} else {
return cute::make_tuple(t);
}
} else {
return t;
}
CUTE_GCC_UNREACHABLE;
}
// Convenience overload to pass the tuples in separately
template <class T0, class T1, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
zip(T0 const& t0, T1 const& t1, Ts const&... ts)
{
return zip(cute::make_tuple(t0, t1, ts...));
}
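// Illustrative usage sketch (a..c, x..z are placeholder values): zip is a transpose of a
// tuple of tuples.
//
//   zip(make_tuple(a, b, c), make_tuple(x, y, z))   // => ((a,x),(b,y),(c,z))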
//
// zip2_by -- A guided zip for rank-2 tuples
// Take a tuple like ((A,a),((B,b),(C,c)),d)
// and produce a tuple ((A,(B,C)),(a,(b,c),d))
// where the rank-2 modes are selected by the terminals of the guide (X,(X,X))
//
namespace detail {
template <class T, class TG, int... Is, int... Js>
CUTE_HOST_DEVICE constexpr
auto
zip2_by(T const& t, TG const& guide, seq<Is...>, seq<Js...>)
{
// zip2_by produces the modes like ((A,a),(B,b),...)
auto split = cute::make_tuple(zip2_by(get<Is>(t), get<Is>(guide))...);
// Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y))
return cute::make_tuple(cute::make_tuple(get<Is,0>(split)...),
cute::make_tuple(get<Is,1>(split)..., get<Js>(t)...));
}
} // end namespace detail
template <class T, class TG>
CUTE_HOST_DEVICE constexpr
auto
zip2_by(T const& t, TG const& guide)
{
if constexpr (is_tuple<TG>::value) {
constexpr int TR = tuple_size<T>::value;
constexpr int GR = tuple_size<TG>::value;
static_assert(TR >= GR, "Mismatched ranks");
return detail::zip2_by(t, guide,
make_range< 0, GR>{},
make_range<GR, TR>{});
} else {
static_assert(tuple_size<T>::value == 2, "Mismatched ranks");
return t;
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace cute

View File

@@ -0,0 +1,190 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))))
# define CUTE_ARCH_CLUSTER_SM90_ENABLED
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED
#endif
namespace cute {
CUTE_DEVICE void cluster_arrive_relaxed()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : );
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTE_DEVICE void cluster_arrive()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
asm volatile("barrier.cluster.arrive.aligned;\n" : : );
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTE_DEVICE void cluster_wait()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
asm volatile("barrier.cluster.wait.aligned;\n" : : );
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTE_DEVICE void cluster_sync()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
cluster_arrive();
cluster_wait();
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Returns the dim3 grid size in terms of number of clusters.
CUTE_DEVICE dim3 cluster_grid_dims()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
return {x, y, z};
#else
return gridDim;
#endif
}
// Returns the dim3 cluster rank in the grid.
CUTE_DEVICE dim3 cluster_id_in_grid()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
return {x, y, z};
#else
return blockIdx;
#endif
}
// Returns the relative dim3 block rank local to the cluster.
CUTE_DEVICE dim3 block_id_in_cluster()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
return {x, y, z};
#else
return {0,0,0};
#endif
}
// Returns the dim3 cluster shape.
CUTE_DEVICE dim3 cluster_shape()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
return {x, y, z};
#else
return {1,1,1};
#endif
}
// Get 1D ctaid in a cluster.
CUTE_DEVICE uint32_t block_rank_in_cluster()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t rank;
asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
return rank;
#else
return 0;
#endif
}
// Set the destination block-ID in cluster for a given SMEM Address
CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank)
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t result;
asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n"
: "=r"(result)
: "r"(smemAddr), "r"(rank));
return result;
#else
return smemAddr;
#endif
}
// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false.
CUTE_HOST_DEVICE uint32_t elect_one_sync()
{
#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED)
uint32_t pred = 0;
uint32_t laneid = 0;
asm volatile(
"{\n"
".reg .b32 %rx;\n"
".reg .pred %px;\n"
" elect.sync %rx|%px, %2;\n"
"@%px mov.s32 %1, 1;\n"
" mov.s32 %0, %rx;\n"
"}\n"
: "+r"(laneid), "+r"(pred)
: "r"(0xFFFFFFFF));
return pred;
#elif defined(__CUDA_ARCH__)
return (threadIdx.x % 32) == 0;
#else
return true;
#endif
}
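// Illustrative usage sketch (my_smem_addr and peer_rank are placeholder names): a kernel
// typically arrives early, overlaps independent work, then waits before touching a peer
// CTA's shared memory.
//
//   cluster_arrive_relaxed();
//   /* ... work that does not read remote smem ... */
//   cluster_wait();
//   uint32_t remote_addr = set_block_rank(my_smem_addr, peer_rank);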
} // end namespace cute

View File

@@ -0,0 +1,71 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/numeric/uint128.hpp>
namespace cute
{
//
// Direct Copy for any type
//
template <class S, class D = S>
struct UniversalCopy
{
using SRegisters = S[1];
using DRegisters = D[1];
CUTE_HOST_DEVICE static constexpr void
copy(S const& src,
D & dst)
{
dst = src;
}
};
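// Illustrative usage sketch: UniversalCopy is the fallback "copy instruction" for plain
// assignments between registers.
//
//   float s = 1.0f, d = 0.0f;
//   UniversalCopy<float>::copy(s, d);   // d == 1.0f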
//
// Placeholder for the copy algorithm's default, auto-vectorizing behavior
//
struct DefaultCopy
{
using SRegisters = uint128_t[1];
using DRegisters = uint128_t[1];
};
using AutoVectorizingCopy = DefaultCopy;
} // end namespace cute

View File

@@ -0,0 +1,215 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
# define CUTE_ARCH_LDSM_SM75_ENABLED
#endif
namespace cute
{
struct SM75_U32x1_LDSM_N
{
using SRegisters = uint128_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
copy(uint128_t const& smem_src,
uint32_t& dst)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
}
};
struct SM75_U32x2_LDSM_N
{
using SRegisters = uint128_t[1];
using DRegisters = uint32_t[2];
CUTE_HOST_DEVICE static void
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst0), "=r"(dst1)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
}
};
struct SM75_U32x4_LDSM_N
{
using SRegisters = uint128_t[1];
using DRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
}
};
struct SM75_U16x2_LDSM_T
{
using SRegisters = uint128_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
copy(uint128_t const& smem_src,
uint32_t& dst)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
}
};
struct SM75_U16x4_LDSM_T
{
using SRegisters = uint128_t[1];
using DRegisters = uint32_t[2];
CUTE_HOST_DEVICE static void
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst0), "=r"(dst1)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
}
};
struct SM75_U16x8_LDSM_T
{
using SRegisters = uint128_t[1];
using DRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
}
};
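// Illustrative usage sketch (lane_smem_line and frag are placeholder names): for the x4
// variant every thread in the warp supplies a shared-memory line address and receives four
// 32-bit fragments.
//
//   uint32_t frag[4];
//   SM75_U32x4_LDSM_N::copy(lane_smem_line, frag[0], frag[1], frag[2], frag[3]);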
//
// Legacy LDSM interfaces that aren't very useful
//
template <class T>
CUTE_HOST_DEVICE
void
copy_ldsm(uint128_t const* const smem_ptr,
T* rmem_ptr)
{
uint32_t* reg_ptr = reinterpret_cast<uint32_t*>(rmem_ptr);
// Dispatch on the register fragment size at compile time
if constexpr (sizeof(T) == 4) {
SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]);
}
else if constexpr (sizeof(T) == 8) {
SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]);
}
else if constexpr (sizeof(T) == 16) {
SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]);
}
else {
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
}
}
template <class T>
CUTE_HOST_DEVICE
void
copy_ldsm_trans(uint128_t const* const smem_ptr,
T* rmem_ptr)
{
uint32_t* reg_ptr = reinterpret_cast<uint32_t*>(rmem_ptr);
// Dispatch on the register fragment size at compile time
if constexpr (sizeof(T) == 4) {
SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]);
}
else if constexpr (sizeof(T) == 8) {
SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]);
}
else if constexpr (sizeof(T) == 16) {
SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]);
}
else {
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
}
}
} // end namespace cute

View File

@@ -0,0 +1,138 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED
#endif
namespace cute
{
/// Copy via cp.async with caching at all levels
template <class TS, class TD = TS>
struct SM80_CP_ASYNC_CACHEALWAYS
{
using SRegisters = TS[1];
using DRegisters = TD[1];
static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");
CUTE_HOST_DEVICE static void
copy(TS const& gmem_src,
TD & smem_dst)
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
TS const* gmem_ptr = &gmem_src;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
:: "r"(smem_int_ptr),
"l"(gmem_ptr),
"n"(sizeof(TS)));
#else
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
#endif
}
};
/// Copy via cp.async with caching at global level
template <class TS, class TD = TS>
struct SM80_CP_ASYNC_CACHEGLOBAL
{
using SRegisters = TS[1];
using DRegisters = TD[1];
static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");
CUTE_HOST_DEVICE static void
copy(TS const& gmem_src,
TD & smem_dst)
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
TS const* gmem_ptr = &gmem_src;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
:: "r"(smem_int_ptr),
"l"(gmem_ptr),
"n"(sizeof(TS)));
#else
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Establishes an ordering w.r.t. previously issued cp.async instructions. Does not block.
CUTE_HOST_DEVICE
void
cp_async_fence()
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Blocks until all but N previous cp.async.commit_group operations have committed.
template <int N>
CUTE_HOST_DEVICE
void
cp_async_wait()
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::);
} else {
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
}
#endif
}
template <int N>
CUTE_HOST_DEVICE
void
cp_async_wait(Int<N>)
{
return cp_async_wait<N>();
}
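// Illustrative usage sketch (gmem_val and smem_val are placeholder names): copies are
// committed into groups by cp_async_fence() and consumed only after cp_async_wait().
//
//   SM80_CP_ASYNC_CACHEALWAYS<uint128_t>::copy(gmem_val, smem_val);
//   cp_async_fence();      // close the current commit group
//   /* ... issue further groups or do independent work ... */
//   cp_async_wait<0>();    // block until all outstanding groups complete
//   __syncthreads();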
/////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace cute

View File

@@ -0,0 +1,225 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
# define CUTE_ARCH_STSM_SM90_ENABLED
# define CUTE_ARCH_TMA_SM90_ENABLED
#endif
namespace cute
{
struct SM90_U32x1_STSM_N
{
using SRegisters = uint32_t[1];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void
copy(uint32_t const& src,
uint128_t & smem_dst)
{
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
:: "r"(smem_int_ptr),
"r"(src));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
struct SM90_U32x2_STSM_N
{
using SRegisters = uint32_t[2];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void
copy(uint32_t const& src0, uint32_t const& src1,
uint128_t& smem_dst)
{
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
struct SM90_U32x4_STSM_N
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void
copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3,
uint128_t& smem_dst)
{
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1), "r"(src2), "r"(src3));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
struct SM90_U16x2_STSM_T
{
using SRegisters = uint32_t[1];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void
copy(uint32_t const& src,
uint128_t& smem_dst)
{
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n"
:: "r"(smem_int_ptr),
"r"(src));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
struct SM90_U16x4_STSM_T
{
using SRegisters = uint32_t[2];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void
copy(uint32_t const& src0, uint32_t const& src1,
uint128_t& smem_dst)
{
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
struct SM90_U16x8_STSM_T
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void
copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3,
uint128_t& smem_dst)
{
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1), "r"(src2), "r"(src3));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
//
// Legacy STSM interfaces that aren't very useful
//
template <class T>
CUTE_HOST_DEVICE
void
copy_stsm(T const* const rmem_ptr,
uint128_t* const smem_ptr)
{
uint32_t const* reg_ptr = reinterpret_cast<uint32_t const*>(rmem_ptr);
// Dispatch on the register fragment size at compile time
if constexpr (sizeof(T) == 4) {
SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]);
}
else if constexpr (sizeof(T) == 8) {
SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]);
}
else if constexpr (sizeof(T) == 16) {
SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]);
}
else {
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
}
}
template <class T>
CUTE_HOST_DEVICE
void
copy_stsm_trans(T const* const rmem_ptr,
uint128_t* const smem_ptr)
{
uint32_t const* reg_ptr = reinterpret_cast<uint32_t const*>(rmem_ptr);
// Dispatch on the register fragment size at compile time
if constexpr (sizeof(T) == 4) {
SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]);
}
else if constexpr (sizeof(T) == 8) {
SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]);
}
else if constexpr (sizeof(T) == 16) {
SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]);
}
else {
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,194 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cuda.h>
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
#include <cute/arch/copy_sm90.hpp>
#include <cute/container/alignment.hpp>
#include <cute/container/bit_field.hpp>
#include <cute/numeric/int.hpp> // to_Format<[u]intX>
#include <cute/numeric/half.hpp> // to_Format<half_t>
namespace cute
{
//////////////////////////////////////////////////////////////////////////////////////////////////////
/// Barriers are 64 bits of user-managed information used in broadly two types of synchronization patterns
/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels)
/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction)
//////////////////////////////////////////////////////////////////////////////////////////////////////
// Initialize barrier present in shared memory
CUTE_HOST_DEVICE
void
initialize_barrier(uint64_t& smem_barrier, // 64-bit user-managed barrier in smem
int thread_count = 1) // Thread count expected to arrive/wait on this barrier
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n"
:: "r"(smem_int_ptr),
"r"(thread_count));
#endif
}
// Set the number of bytes transferred per transaction
CUTE_HOST_DEVICE
void
set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64-bit user-managed barrier in smem
uint32_t bytes) // Number of bytes transferred per TMA transaction
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n"
:: "r"(smem_int_ptr),
"r"(bytes));
#endif
}
// Barrier wait
CUTE_HOST_DEVICE
void
wait_barrier(uint64_t& smem_barrier, // 64-bit user-managed barrier in smem
int phase_bit) // Current phase bit that the barrier is waiting on to flip
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n"
:: "r"(smem_int_ptr),
"r"(phase_bit));
#endif
}
// Barrier arrive
CUTE_HOST_DEVICE
void
arrive_barrier(uint64_t& smem_barrier) // 64-bit user-managed barrier in smem
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .b64 state; \n"
"mbarrier.arrive.shared.b64 state, [%0];\n"
"}\n"
:: "r"(smem_int_ptr));
#endif
}
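// Illustrative usage sketch of the transaction-based pattern (smem_bar, tma_bytes, and
// phase are placeholder names): one elected thread initializes the barrier and posts the
// expected byte count, the TMA unit completes the transaction, and consumers spin on the
// phase bit.
//
//   initialize_barrier(smem_bar, /*thread_count=*/1);
//   set_barrier_transaction_bytes(smem_bar, tma_bytes);
//   /* ... issue SM90_TMA_LOAD with &smem_bar ... */
//   wait_barrier(smem_bar, phase);   // phase alternates 0/1 across barrier reuses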
////////////////////////////////////////////////////////////////////////////////////////////////////
// TMA Descriptor and utilities
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace TMA {
enum class SmemSwizzleBits : uint8_t {
DISABLE = 0,
B32 = 1,
B64 = 2,
B128 = 3,
};
#if (__CUDACC_VER_MAJOR__ >= 12)
template <class T>
inline CUtensorMapDataType to_CUtensorMapDataType() {
if constexpr (std::is_same<T, int8_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (std::is_same<T, uint8_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (std::is_same<T, uint16_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else
if constexpr (std::is_same<T, uint32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else
if constexpr (std::is_same<T, uint64_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else
if constexpr (std::is_same<T, int32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else
if constexpr (std::is_same<T, int64_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else
if constexpr (std::is_same<T, half_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else
if constexpr (std::is_same<T, float>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else
if constexpr (std::is_same<T, double>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else
if constexpr (std::is_same<T, bfloat16_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
if constexpr (std::is_same<T, tfloat32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
{ static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
}
inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) {
switch (t) {
default: assert(false && "Unknown SmemSwizzleBits!");
case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE;
case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B;
case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B;
case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B;
}
}
#endif // (__CUDACC_VER_MAJOR__ >= 12)
} // end namespace TMA
#if (__CUDACC_VER_MAJOR__ >= 12)
using TmaDescriptor = CUtensorMap;
#else
using TmaDescriptor = struct { char bytes[128]; };
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Initiates a TensorMap Prefetch
////////////////////////////////////////////////////////////////////////////////////////////////////
CUTE_HOST_DEVICE
void
prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param)
asm volatile (
"prefetch.tensormap [%0];"
:
: "l"(gmem_int_desc)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
///////////////////////////////////////////////////////////////////////////////
} // end namespace cute

View File

@@ -0,0 +1,552 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
#include <cute/arch/copy_sm90.hpp>
namespace cute
{
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_1D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_2D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_3D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_4D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_5D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0)
{
return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
return SM90_TMA_LOAD_2D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
return SM90_TMA_LOAD_3D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
return SM90_TMA_LOAD_4D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
return SM90_TMA_LOAD_5D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3, crd4);
}
};
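// Illustrative usage sketch (tma_desc, smem_bar, smem_dst, and crd0/crd1 are placeholder
// names): a single elected thread per CTA issues the bulk-tensor copy and completion is
// reported through the mbarrier passed alongside the descriptor.
//
//   if (elect_one_sync()) {
//     SM90_TMA_LOAD::copy(&tma_desc, smem_bar, smem_dst, crd0, crd1);   // 2D tile
//   }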
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_1D_MULTICAST
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%4}], [%2], %3;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"h"(multicast_mask),
"r"(crd0)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_2D_MULTICAST
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%4, %5}], [%2], %3;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"h"(multicast_mask),
"r"(crd0), "r"(crd1)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_3D_MULTICAST
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%4, %5, %6}], [%2], %3;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"h"(multicast_mask),
"r"(crd0), "r"(crd1), "r"(crd2)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_4D_MULTICAST
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%4, %5, %6, %7}], [%2], %3;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"h"(multicast_mask),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_5D_MULTICAST
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"h"(multicast_mask),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_LOAD_MULTICAST
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0)
{
return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_STORE : Initiates a TMA copy from shared memory to global memory
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_STORE_1D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(crd0)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_2D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(crd0), "r"(crd1)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_3D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(crd0), "r"(crd1), "r"(crd2)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_4D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE_5D
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
};
struct SM90_TMA_STORE
{
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0)
{
return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3);
}
CUTE_HOST_DEVICE static void
copy(void const* const desc_ptr,
void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4);
}
};
// Indicate arrival of warp issuing TMA_STORE
CUTE_HOST_DEVICE static void
tma_store_arrive() {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
asm volatile("cp.async.bulk.commit_group;");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
// Wait on prior N (Count) TMA_STORE instructions to complete
template<int Count>
CUTE_HOST_DEVICE static void
tma_store_wait() {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
asm volatile(
"cp.async.bulk.wait_group.read %0;"
:
: "n"(Count)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
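// Usage sketch (added commentary; names are hypothetical): a warp issuing TMA stores typically
// commits them as a bulk-async group and waits for completion before reusing its staging buffer.
//
//   SM90_TMA_STORE::copy(&tma_desc, smem_buf, crd0, crd1);  // issue a 2-D store
//   tma_store_arrive();                                     // commit the outstanding stores
//   tma_store_wait<0>();                                    // wait until all committed groups finish reading smem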
////////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace cute

include/cute/arch/mma.hpp

@@ -0,0 +1,64 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
namespace cute
{
//
// Direct FMA for any type
//
template <class D, class A = D, class B = A, class C = D>
struct UniversalFMA
{
using DRegisters = D[1];
using ARegisters = A[1];
using BRegisters = B[1];
using CRegisters = C[1];
CUTE_HOST_DEVICE static constexpr void
fma(D & d,
A const& a,
B const& b,
C const& c)
{
// Forward to an ADL/cute free function for these types
using cute::fma;
fma(d, a, b, c);
}
};
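// Illustrative note (added commentary): for plain arithmetic types the ADL-found fma is expected
// to reduce to d = a * b + c, so a minimal sketch of using UniversalFMA would be
//
//   float d = 0.0f;
//   cute::UniversalFMA<float>::fma(d, 2.0f, 3.0f, 1.0f);   // d == 7.0f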
} // end namespace cute


@@ -0,0 +1,87 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
# define CUTE_ARCH_MMA_SM61_ENABLED
#endif
namespace cute
{
struct SM61_DP4A
{
using DRegisters = int32_t[1];
using ARegisters = uint32_t[1];
using BRegisters = uint32_t[1];
using CRegisters = int32_t[1];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c)
{
#if defined(CUTE_ARCH_MMA_SM61_ENABLED)
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED");
#endif
}
};
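// Semantics note (added commentary): dp4a treats a and b as four packed signed 8-bit lanes and
// accumulates their dot product into the 32-bit result, i.e.
// d = c + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3].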
struct SM61_DP2A
{
using DRegisters = int32_t[1];
using ARegisters = uint32_t[1];
using BRegisters = uint32_t[1];
using CRegisters = int32_t[1];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c)
{
#if defined(CUTE_ARCH_MMA_SM61_ENABLED)
asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED");
#endif
}
};
} // namespace cute


@@ -0,0 +1,329 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
// Config
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
# define CUTE_ARCH_MMA_SM70_SUPPORTED
# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
# define CUTE_ARCH_MMA_SM70_ENABLED
# endif
#endif
namespace cute
{
//
// SM70 MMA 884 F16F16F16
//
struct SM70_8x8x4_F16F16F16F16_TN
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16"
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6, %7},"
"{%8, %9, %10, %11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
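// Note on operand packing (added commentary): each uint32_t above carries two packed f16 values,
// so the two A and two B registers hold four f16 elements apiece per thread and the four
// accumulator registers hold eight f16 results. The _NT/_NN/_TT variants below differ only in
// the row/col layout qualifiers of the PTX instruction.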
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM70_8x8x4_F16F16F16F16_NT
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16"
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6, %7},"
"{%8, %9, %10, %11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM70_8x8x4_F16F16F16F16_NN
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16"
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6, %7},"
"{%8, %9, %10, %11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM70_8x8x4_F16F16F16F16_TT
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16"
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6, %7},"
"{%8, %9, %10, %11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// SM70 MMA 884 F16F16F32
//
struct SM70_8x8x4_F32F16F16F32_TN
{
using DRegisters = float[8];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = float[8];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(float & d0, float & d1, float & d2, float & d3,
float & d4, float & d5, float & d6, float & d7,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
float const& c0, float const& c1, float const& c2, float const& c3,
float const& c4, float const& c5, float const& c6, float const& c7)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11},"
"{%12, %13, %14, %15, %16, %17, %18, %19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM70_8x8x4_F32F16F16F32_NT
{
using DRegisters = float[8];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = float[8];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(float & d0, float & d1, float & d2, float & d3,
float & d4, float & d5, float & d6, float & d7,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
float const& c0, float const& c1, float const& c2, float const& c3,
float const& c4, float const& c5, float const& c6, float const& c7)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11},"
"{%12, %13, %14, %15, %16, %17, %18, %19};"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM70_8x8x4_F32F16F16F32_NN
{
using DRegisters = float[8];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = float[8];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(float & d0, float & d1, float & d2, float & d3,
float & d4, float & d5, float & d6, float & d7,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
float const& c0, float const& c1, float const& c2, float const& c3,
float const& c4, float const& c5, float const& c6, float const& c7)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11},"
"{%12, %13, %14, %15, %16, %17, %18, %19};"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM70_8x8x4_F32F16F16F32_TT
{
using DRegisters = float[8];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = float[8];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(float & d0, float & d1, float & d2, float & d3,
float & d4, float & d5, float & d6, float & d7,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0, uint32_t const& b1,
float const& c0, float const& c1, float const& c2, float const& c3,
float const& c4, float const& c5, float const& c6, float const& c7)
{
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11},"
"{%12, %13, %14, %15, %16, %17, %18, %19};"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1),
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace cute


@@ -0,0 +1,120 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
// Config
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
# define CUTE_ARCH_MMA_SM75_SUPPORTED
# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
# define CUTE_ARCH_MMA_SM75_ENABLED
# endif
#endif
namespace cute
{
//
// SM75 MMA 1688 F16F16F32
//
struct SM75_16x8x8_F32F16F16F32_TN
{
using DRegisters = float[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[1];
using CRegisters = float[4];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(float & d0, float & d1, float & d2, float & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0,
float const& c0, float const& c1, float const& c2, float const& c3)
{
#if defined(CUTE_ARCH_MMA_SM75_ENABLED)
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6},"
"{%7, %8, %9, %10};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
: "r"(a0), "r"(a1),
"r"(b0),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// SM75 MMA 8816 S8S8S32
//
struct SM75_8x8x16_S32S8S8S32_TN
{
using DRegisters = uint32_t[2];
using ARegisters = uint32_t[1];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[2];
// Register asm fma
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1,
uint32_t const& a0,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1)
{
#if defined(CUTE_ARCH_MMA_SM75_ENABLED)
asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32"
"{%0, %1},"
"{%2},"
"{%3},"
"{%4, %5};\n"
: "=r"(d0), "=r"(d1)
: "r"(a0),
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace cute

File diff suppressed because it is too large.


@@ -0,0 +1,961 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
# define CUTE_ARCH_MMA_SM90_ENABLED
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x4 TN
struct SM90_16x8x4_F64F64F64F64_TN
{
using DRegisters = double[4];
using ARegisters = double[2];
using BRegisters = double[1];
using CRegisters = double[4];
CUTE_HOST_DEVICE static void
fma(double & d0, double & d1, double & d2, double & d3,
double const& a0, double const& a1,
double const& b0,
double const& c0, double const& c1, double const& c2, double const& c3)
{
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64"
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6},"
"{%7, %8, %9, %10};\n"
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
: "d"(a0), "d"(a1),
"d"(b0),
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x8 TN
struct SM90_16x8x8_F64F64F64F64_TN
{
using DRegisters = double[4];
using ARegisters = double[4];
using BRegisters = double[2];
using CRegisters = double[4];
CUTE_HOST_DEVICE static void
fma(double & d0, double & d1, double & d2, double & d3,
double const& a0, double const& a1, double const& a2, double const& a3,
double const& b0, double const& b1,
double const& c0, double const& c1, double const& c2, double const& c3)
{
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
: "d"(a0), "d"(a1), "d"(a2), "d"(a3),
"d"(b0), "d"(b1),
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x16 TN
struct SM90_16x8x16_F64F64F64F64_TN
{
using DRegisters = double[4];
using ARegisters = double[8];
using BRegisters = double[4];
using CRegisters = double[4];
CUTE_HOST_DEVICE static void
fma(double & d0, double & d1, double & d2, double & d3,
double const& a0, double const& a1, double const& a2, double const& a3,
double const& a4, double const& a5, double const& a6, double const& a7,
double const& b0, double const& b1, double const& b2, double const& b3,
double const& c0, double const& c1, double const& c2, double const& c3)
{
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7, %8, %9, %10, %11},"
"{%12, %13, %14, %15},"
"{%16, %17, %18, %19};\n"
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
: "d"(a0), "d"(a1), "d"(a2), "d"(a3),
"d"(a4), "d"(a5), "d"(a6), "d"(a7),
"d"(b0), "d"(b1), "d"(b2), "d"(b3),
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x4 TN
struct SM90_16x8x4_C64C64C64C64_TN
{
using DRegisters = complex<double>[4];
using ARegisters = complex<double>[2];
using BRegisters = complex<double>[1];
using CRegisters = complex<double>[4];
CUTE_HOST_DEVICE static void
fma(complex<double> & d0, complex<double> & d1,
complex<double> & d2, complex<double> & d3,
complex<double> const& a0, complex<double> const& a1,
complex<double> const& b0,
complex<double> const& c0, complex<double> const& c1,
complex<double> const& c2, complex<double> const& c3)
{
// Because thrust::complex does not provide a mutable ref
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
// d.real() = a.real() * b.real() + c.real();
SM90_16x8x4_F64F64F64F64_TN::fma(
rd0, rd1, rd2, rd3,
a0.real(), a1.real(),
b0.real(),
c0.real(), c1.real(), c2.real(), c3.real());
// d.imag() = a.imag() * b.real() + c.imag();
SM90_16x8x4_F64F64F64F64_TN::fma(
id0, id1, id2, id3,
a0.imag(), a1.imag(),
b0.real(),
c0.imag(), c1.imag(), c2.imag(), c3.imag());
// d.real() = -a.imag() * b.imag() + d.real();
SM90_16x8x4_F64F64F64F64_TN::fma(
rd0, rd1, rd2, rd3,
-a0.imag(), -a1.imag(),
b0.imag(),
d0.real(), d1.real(), d2.real(), d3.real());
// d.imag() = a.real() * b.imag() + d.imag();
SM90_16x8x4_F64F64F64F64_TN::fma(
id0, id1, id2, id3,
a0.real(), a1.real(),
b0.imag(),
d0.imag(), d1.imag(), d2.imag(), d3.imag());
}
};
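// Note (added commentary): the complex MMA above is composed of four real f64 MMAs computing
// d = a*b + c componentwise, i.e. d.real = a.real*b.real - a.imag*b.imag + c.real and
// d.imag = a.imag*b.real + a.real*b.imag + c.imag. The 16x8x8 and 16x8x16 complex variants
// below reuse the same decomposition with wider A/B fragments.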
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x8 TN
struct SM90_16x8x8_C64C64C64C64_TN
{
using DRegisters = complex<double>[4];
using ARegisters = complex<double>[4];
using BRegisters = complex<double>[2];
using CRegisters = complex<double>[4];
CUTE_HOST_DEVICE static void
fma(complex<double> & d0, complex<double> & d1,
complex<double> & d2, complex<double> & d3,
complex<double> const& a0, complex<double> const& a1,
complex<double> const& a2, complex<double> const& a3,
complex<double> const& b0, complex<double> const& b1,
complex<double> const& c0, complex<double> const& c1,
complex<double> const& c2, complex<double> const& c3)
{
// Because thrust::complex does not provide a mutable ref
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
// d.real() = a.real() * b.real() + c.real();
SM90_16x8x8_F64F64F64F64_TN::fma(
rd0, rd1, rd2, rd3,
a0.real(), a1.real(), a2.real(), a3.real(),
b0.real(), b1.real(),
c0.real(), c1.real(), c2.real(), c3.real());
// d.imag() = a.imag() * b.real() + c.imag();
SM90_16x8x8_F64F64F64F64_TN::fma(
id0, id1, id2, id3,
a0.imag(), a1.imag(), a2.imag(), a3.imag(),
b0.real(), b1.real(),
c0.imag(), c1.imag(), c2.imag(), c3.imag());
// d.real() = -a.imag() * b.imag() + d.real();
SM90_16x8x8_F64F64F64F64_TN::fma(
rd0, rd1, rd2, rd3,
-a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
b0.imag(), b1.imag(),
d0.real(), d1.real(), d2.real(), d3.real());
// d.imag() = a.real() * b.imag() + d.imag();
SM90_16x8x8_F64F64F64F64_TN::fma(
id0, id1, id2, id3,
a0.real(), a1.real(), a2.real(), a3.real(),
b0.imag(), b1.imag(),
d0.imag(), d1.imag(), d2.imag(), d3.imag());
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x16 TN
struct SM90_16x8x16_C64C64C64C64_TN
{
using DRegisters = complex<double>[4];
using ARegisters = complex<double>[8];
using BRegisters = complex<double>[4];
using CRegisters = complex<double>[4];
CUTE_HOST_DEVICE static void
fma(complex<double> & d0, complex<double> & d1,
complex<double> & d2, complex<double> & d3,
complex<double> const& a0, complex<double> const& a1,
complex<double> const& a2, complex<double> const& a3,
complex<double> const& a4, complex<double> const& a5,
complex<double> const& a6, complex<double> const& a7,
complex<double> const& b0, complex<double> const& b1,
complex<double> const& b2, complex<double> const& b3,
complex<double> const& c0, complex<double> const& c1,
complex<double> const& c2, complex<double> const& c3)
{
// Because thrust::complex does not provide a mutable ref
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
// d.real() = a.real() * b.real() + c.real();
SM90_16x8x16_F64F64F64F64_TN::fma(
rd0, rd1, rd2, rd3,
a0.real(), a1.real(), a2.real(), a3.real(),
a4.real(), a5.real(), a6.real(), a7.real(),
b0.real(), b1.real(), b2.real(), b3.real(),
c0.real(), c1.real(), c2.real(), c3.real());
// d.imag() = a.imag() * b.real() + c.imag();
SM90_16x8x16_F64F64F64F64_TN::fma(
id0, id1, id2, id3,
a0.imag(), a1.imag(), a2.imag(), a3.imag(),
a4.imag(), a5.imag(), a6.imag(), a7.imag(),
b0.real(), b1.real(), b2.real(), b3.real(),
c0.imag(), c1.imag(), c2.imag(), c3.imag());
// d.real() = -a.imag() * b.imag() + d.real();
SM90_16x8x16_F64F64F64F64_TN::fma(
rd0, rd1, rd2, rd3,
-a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
-a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(),
b0.imag(), b1.imag(), b2.imag(), b3.imag(),
d0.real(), d1.real(), d2.real(), d3.real());
// d.imag() = a.real() * b.imag() + d.imag();
SM90_16x8x16_F64F64F64F64_TN::fma(
id0, id1, id2, id3,
a0.real(), a1.real(), a2.real(), a3.real(),
a4.real(), a5.real(), a6.real(), a7.real(),
b0.imag(), b1.imag(), b2.imag(), b3.imag(),
d0.imag(), d1.imag(), d2.imag(), d3.imag());
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////
#include <cute/arch/mma_sm90_desc.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
namespace GMMA {
template<
class ElementA,
class ElementB,
class ElementC,
class TileShape_MNK,
GMMA::Major MajorA = GMMA::Major::K,
GMMA::Major MajorB = GMMA::Major::K,
auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
// But most commonly leave empty for defaults
>
CUTE_HOST_DEVICE constexpr
auto
ss_op_selector()
{
static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
auto Tile_N = size<1>(TileShape_MNK{});
// FP16 accumulator
if constexpr (std::is_same_v<ElementC, half_t>) {
static_assert(std::is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
static_assert(std::is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
// Dispatch against the Tile N mode size
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// FP32 accumulator
else if constexpr (std::is_same_v<ElementC, float>) {
// FP16 inputs
if constexpr (std::is_same_v<ElementA, half_t>) {
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// BF16 inputs
else if constexpr (std::is_same_v<ElementA, bfloat16_t>) {
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// TF32 inputs
else if constexpr (std::is_same_v<ElementA, tfloat32_t>) {
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x8_F32TF32TF32_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x8_F32TF32TF32_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
else {
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for the requested configuration.");
}
}
// S32 accumulator
else if constexpr (std::is_same_v<ElementC, int32_t>) {
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
// ElementA == int8_t && ElementB == int8_t
if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, int8_t>) {
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32S8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32S8S8_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// ElementA == int8_t && ElementB == uint8_t
else if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, uint8_t>) {
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32S8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32S8U8_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// ElementA == uint8_t && ElementB == int8_t
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, int8_t>) {
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32U8S8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32U8S8_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// ElementA == uint8_t && ElementB == uint8_t
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, uint8_t>) {
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32U8U8_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32U8U8_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
}
// Unknown accumulator type
else {
static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
}
}
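// Usage sketch (added commentary; tile shape is a hypothetical example): selecting an SMEM x SMEM
// GMMA atom for f16 inputs with f32 accumulation and a 128x128x64 tile. With Tile_N == 128 the
// dispatch above yields the 64x128x16 atom.
//
//   using Atom = decltype(GMMA::ss_op_selector<
//       cute::half_t, cute::half_t, float, cute::Shape<cute::_128, cute::_128, cute::_64>>());
//   // Atom is SM90_64x128x16_F32F16F16_SS<GMMA::Major::K, GMMA::Major::K> for this shape.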
template<
class ElementA,
class ElementB,
class ElementC,
class TileShape_MNK,
GMMA::Major MajorA = GMMA::Major::K,
GMMA::Major MajorB = GMMA::Major::K,
auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
// But most commonly leave empty for defaults
>
CUTE_HOST_DEVICE constexpr
auto
rs_op_selector()
{
static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout.");
auto Tile_N = size<1>(TileShape_MNK{});
// FP16 accumulator
if constexpr (std::is_same_v<ElementC, half_t>) {
static_assert(std::is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
static_assert(std::is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
// Dispatch against the Tile N mode size
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// FP32 accumulator
else if constexpr (std::is_same_v<ElementC, float>) {
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
// FP16 inputs
if constexpr (std::is_same_v<ElementA, half_t>) {
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// BF16 inputs
else if constexpr (std::is_same_v<ElementA, bfloat16_t>) {
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// TF32 inputs
else if constexpr (std::is_same_v<ElementA, tfloat32_t>) {
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x8_F32TF32TF32_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x8_F32TF32TF32_RS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
else {
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for the requested configuration.");
}
}
// S32 accumulator
else if constexpr (std::is_same_v<ElementC, int32_t>) {
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
// ElementA == int8_t && ElementB == int8_t
if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, int8_t>) {
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32S8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32S8S8_RS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// ElementA == int8_t && ElementB == uint8_t
else if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, uint8_t>) {
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32S8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32S8U8_RS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// ElementA == uint8_t && ElementB == int8_t
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, int8_t>) {
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32U8S8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32U8S8_RS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// ElementA == uint8_t && ElementB == uint8_t
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, uint8_t>) {
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_S32U8U8_RS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_S32U8U8_RS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
}
// Unknown accumulator type
else {
static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
}
}
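// Usage sketch (added commentary; hypothetical s8 x s8 -> s32 configuration): the register-source
// selector follows the same Tile_N dispatch, e.g. a 64x64x128 tile resolves to the 64x64x32 atom.
//
//   using Atom = decltype(GMMA::rs_op_selector<
//       int8_t, int8_t, int32_t, cute::Shape<cute::_64, cute::_64, cute::_128>>());
//   // Atom is SM90_64x64x32_S32S8S8_RS_TN<> for this shape.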
} // end namespace GMMA
} // end namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,131 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
# define CUTE_ARCH_MMA_SM90_ENABLED
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
////////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA Descriptor and utilities
// GMMA enums and utilities
namespace GMMA
{
enum class LayoutType : uint8_t {
INTERLEAVE = 0,
B128 = 1,
B64 = 2,
B32 = 3,
};
CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) {
switch (t) {
case LayoutType::INTERLEAVE: return "INTERLEAVE";
case LayoutType::B128: return "B128";
case LayoutType::B64: return "B64";
case LayoutType::B32: return "B32";
}
return nullptr;
}
// Output operator for all enums in this namespace
CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) {
char const* s = to_string(t);
if (s) {
std::operator<<(os, s); // Explicit call to avoid ambiguity
} else {
os.setstate(std::ios_base::failbit);
}
return os;
}
} // end namespace GMMA
union GmmaDescriptor
{
uint64_t desc_;
uint32_t reg32_[2];
uint16_t reg16_[4];
// Bitfield implementation avoids the need for shifts in assignment
struct {
// start_address, bit [0,14), 4LSB not included
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// leading dimension byte offset, bit [16,30), 4LSB not included
// For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED
// Unused for all SWIZZLE_* layouts (and assumed to be 1)
// For T: This is the stride from the first 8 rows to the next 8 rows.
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// stride dimension byte offset, bit [32,46), 4LSB not included
// For N: This is the stride from the first 8 rows to the next 8 rows.
// For T: This is the stride from the first 8 cols to the next 8 cols.
uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// base_offset, bit [49,52)
// Valid only for SWIZZLE_128B and SWIZZLE_64B
uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused
// layout type, bit [62,64)
// SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8)
};
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr
operator uint64_t() const noexcept { return desc_; }
// Printer
CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t)
{
printf("GmmaDescriptor: 0x%016lx\n", t.desc_);
printf(" start_addr : 0x%04x\n", t.start_address_);
printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_);
printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_);
printf(" base_offset: 0x%01x\n", t.base_offset_);
printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast<GMMA::LayoutType>(t.layout_type_)));
}
};
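// Illustrative sketch (not part of the original header; the values below are made up):
// how the bitfields above pack into the 64-bit descriptor. All offsets are in units of
// 16 bytes because the 4 LSBs are not stored.
#if 0
GmmaDescriptor desc;
desc.desc_                = 0;                                // clear all 64 bits
desc.start_address_       = 0x0040;                           // smem start address >> 4
desc.leading_byte_offset_ = 0x0008;                           // leading-dim byte offset >> 4
desc.stride_byte_offset_  = 0x0100;                           // stride-dim byte offset >> 4
desc.layout_type_         = uint8_t(GMMA::LayoutType::B128);  // SWIZZLE_128B
print(desc);                                                  // dumps the raw value and each field
#endif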
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large

178
include/cute/arch/util.hpp Normal file
View File

@ -0,0 +1,178 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/integer_sequence.hpp>
#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly.
CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*);
}
#endif
namespace cute
{
/// CUTE helper to cast SMEM pointer to unsigned
CUTE_HOST_DEVICE
uint32_t
cast_smem_ptr_to_uint(void const* const ptr)
{
// We prefer the new CVTA intrinsics when they are available; otherwise we fall back to
// the older internal intrinsic (CUDA 10.2) or to raw inline PTX.
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
//
// This NVVM intrinsic converts an address in shared memory to a plain
// unsigned integer. This is necessary to pass to shared memory instructions
// in inline PTX.
//
// In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
//
//__device__ size_t __cvta_generic_to_shared(void* ptr);
/// CUTE helper to get SMEM pointer
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
return __nvvm_get_smem_pointer(ptr);
#elif defined(__CUDA_ARCH__)
uint32_t smem_ptr;
asm(
"{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_ptr) : "l"(ptr));
return smem_ptr;
#else
(void) ptr;
printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n");
return 0;
#endif
}
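// Hedged usage sketch (illustration only; the kernel below is hypothetical and not part
// of this header): convert an SMEM address into the 32-bit form that shared-memory PTX
// instructions such as ldmatrix / cp.async expect.
#if 0
__global__ void example_kernel()
{
  __shared__ float smem_buf[128];
  uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_buf);
  // smem_addr can now be used as an operand in inline PTX, e.g.
  //   asm volatile("..." :: "r"(smem_addr), ...);
}
#endif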
//
// Utility for pointer interfaces
//
namespace detail {
template <class Fn,
class PtrS, int... Is,
class PtrD, int... Id>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
PtrS&& s, int_sequence<Is...>,
PtrD&& d, int_sequence<Id...>)
{
return fn(s[Is]..., d[Id]...);
}
template <class Fn,
class PtrA, int... Ia,
class PtrB, int... Ib,
class PtrC, int... Ic>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
PtrA&& a, int_sequence<Ia...>,
PtrB&& b, int_sequence<Ib...>,
PtrC&& c, int_sequence<Ic...>)
{
return fn(a[Ia]..., b[Ib]..., c[Ic]...);
}
template <class Fn,
class PtrD, int... Id,
class PtrA, int... Ia,
class PtrB, int... Ib,
class PtrC, int... Ic>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
PtrD&& d, int_sequence<Id...>,
PtrA&& a, int_sequence<Ia...>,
PtrB&& b, int_sequence<Ib...>,
PtrC&& c, int_sequence<Ic...>)
{
return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...);
}
} // end namespace detail
template <int SRegCount, int DRegCount,
class Fn, class PtrS, class PtrD>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn, PtrS&& s, PtrD&& d)
{
return detail::explode(fn,
s, make_int_sequence<SRegCount>{},
d, make_int_sequence<DRegCount>{});
}
template <int ARegCount, int BRegCount, int CRegCount,
class Fn, class PtrA, class PtrB, class PtrC>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c)
{
return detail::explode(fn,
a, make_int_sequence<ARegCount>{},
b, make_int_sequence<BRegCount>{},
c, make_int_sequence<CRegCount>{});
}
template <int DRegCount, int ARegCount, int BRegCount, int CRegCount,
class Fn, class PtrD, class PtrA, class PtrB, class PtrC>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c)
{
return detail::explode(fn,
d, make_int_sequence<DRegCount>{},
a, make_int_sequence<ARegCount>{},
b, make_int_sequence<BRegCount>{},
c, make_int_sequence<CRegCount>{});
}
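// Hedged sketch (not part of the original header): explode() expands each register array
// into individual arguments of the callable. FmaOp and the array sizes below are
// hypothetical stand-ins for an instruction wrapper's ARegisters/BRegisters/CRegisters.
#if 0
struct FmaOp {
  CUTE_HOST_DEVICE static void fma(float const& a, float const& b, float& c) { c += a * b; }
};
float a[1] = {2.0f};
float b[1] = {3.0f};
float c[1] = {1.0f};
explode<1,1,1>(&FmaOp::fma, a, b, c);   // calls FmaOp::fma(a[0], b[0], c[0])
#endif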
} // end namespace cute

View File

@ -0,0 +1,671 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <type_traits>
#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/tensor.hpp>
namespace cute {
// Generic copy_unpack for any Copy_Traits
template <class Operation, class... Args,
class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE constexpr
void
copy_unpack(Copy_Traits<Operation, Args...> const&,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
// Specializations can generalize on these checks
//static_assert(is_smem<TS>::value, "Expected smem for this Copy_Traits<Operation>");
//static_assert(is_rmem<TD>::value, "Expected rmem for this Copy_Traits<Operation>");
using RegistersSrc = typename Operation::SRegisters;
using RegistersDst = typename Operation::DRegisters;
using RegTypeSrc = typename std::remove_extent<RegistersSrc>::type;
using RegTypeDst = typename std::remove_extent<RegistersDst>::type;
constexpr int RegNumSrc = std::extent<RegistersSrc>::value;
constexpr int RegNumDst = std::extent<RegistersDst>::value;
Tensor rS = recast<RegTypeSrc>(src);
Tensor rD = recast<RegTypeDst>(dst);
CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
"In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy.");
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
"In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy.");
detail::explode(Operation::copy,
rS, make_int_sequence<RegNumSrc>{},
rD, make_int_sequence<RegNumDst>{});
}
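// For example, an Operation whose SRegisters/DRegisters are uint128_t[1] requires
// recast<uint128_t>(src) to yield exactly one element per thread; the static asserts
// above check that vectorization before explode() forwards the register arrays to
// Operation::copy.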
template <class... Args>
struct Copy_Atom;
template <class CopyOperation, class T>
struct Copy_Atom<CopyOperation, T> : Copy_Atom<Copy_Traits<CopyOperation>, T>
{};
template <class... Args, class T>
struct Copy_Atom<Copy_Traits<Args...>, T>
: Copy_Traits<Args...>
{
using Traits = Copy_Traits<Args...>;
// Bit and Thr layouts from the Copy_Traits
using ThrID = typename Traits::ThrID;
using BitLayoutSrc = typename Traits::SrcLayout;
using BitLayoutDst = typename Traits::DstLayout;
using BitLayoutRef = typename Traits::RefLayout;
using ValType = T;
using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
using ValLayoutRef = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutRef{}));
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType.");
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType.");
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType.");
static constexpr int NumValSrc = size<1>(ValLayoutSrc{});
static constexpr int NumValDst = size<1>(ValLayoutDst{});
// Additional Trait parameters/transformations
template <class... TraitsArgs>
CUTE_HOST_DEVICE
auto
with(TraitsArgs&&... args) const {
auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
return Copy_Atom<decltype(traits), T>{traits};
}
// Print thread and data layouts for debugging
CUTE_HOST_DEVICE static
void
print_all()
{
print("ThrID: "); print(ThrID{}); print("\n");
print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n");
print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n");
print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n");
print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n");
print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n");
print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n");
print("ValueType: %db", sizeof_bits<ValType>::value); print("\n");
}
//
// Tensor call interfaces
//
// Cast, check, and call
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE
void
call(Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) const
{
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value || is_constant<NumValDst, decltype(size(dst))>::value) {
// Dispatch to unpack for instruction
return copy_unpack(*this, src, dst);
} else {
// Recurse if needed by peeling the tensor mode
return copy(*this, tensor<0>(src), tensor<0>(dst));
}
}
// Accept mutable temporaries
template <class SEngine, class SLayout,
class DEngine, class DLayout>
CUTE_HOST_DEVICE
void
call(Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> && dst) const
{
return call(src, dst);
}
};
//
// A tiling of copy atoms
//
template <class Copy_Atom,
class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
class ShapeTile_MN> // coord space
struct TiledCopy : Copy_Atom
{
// Layout information from the CopyAtom
using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx
using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset
using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset
using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset
using AtomNumThr = decltype(size<0>(AtomLayoutRef{}));
using AtomNumVal = decltype(size<1>(AtomLayoutRef{}));
// Layout information for the TiledCopy
using Tiler_MN = ShapeTile_MN;
using TiledShape_MN = decltype(shape(ShapeTile_MN{}));
using TiledLayout_TV = LayoutCopy_TV;
using TiledNumThr = decltype(size<0>(TiledLayout_TV{}));
using TiledNumVal = decltype(size<1>(TiledLayout_TV{}));
CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom");
CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom");
// Tile a tensor or a layout from shape
// (M,N,...)
// to shape
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
// where
// ThrV: The threads local to a COPY_ATOM Src.
// ThrX: The threads tiled across COPY_ATOMs Src.
// FrgV: The values local to a COPY_ATOM Src.
// RestM: The values tiled in M.
// RestN: The values tiled in N.
template <class STensor>
CUTE_HOST_DEVICE constexpr static
auto
tidfrg_S(STensor&& stensor)
{
return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}));
}
// Tile a tensor or a layout from shape
// (M,N,...)
// to shape
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
// where
// ThrV: The threads local to a COPY_ATOM Dst.
// ThrX: The threads tiled across COPY_ATOMs Dst.
// FrgV: The values local to a COPY_ATOM Dst.
// RestM: The values tiled in M.
// RestN: The values tiled in N.
template <class DTensor>
CUTE_HOST_DEVICE constexpr static
auto
tidfrg_D(DTensor&& dtensor)
{
return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}));
}
template <class Tensor, class Ref2TrgLayout>
CUTE_HOST_DEVICE constexpr static
auto
thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg)
{
constexpr int R = remove_cvref_t<Tensor>::rank;
static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
// Generalize the dimension checks for arbitrary rank
//CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
//CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
// Take the thrs/vals that the atom is interested in
// NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{}));
// ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n)
// Transform to the trg layout
auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _);
// ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n)
// Transform the thrs mode from thrid to thr_idx
// NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{});
// ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n)
/// ==================
// Tile the tensor for TiledLayout
auto t_tensor = zipped_divide(tensor, Tiler_MN{});
// ((TileM,TileN,...),(RestM,RestN,...))
// Transform the tile mode
auto tv_tensor = t_tensor.compose(thrval2mn, _);
// ((thrid,val),(RM,RN,...))
// Unfold and return
return tv_tensor(make_coord(_,_), _);
}
// retile_S and retile_D assume they are working with the reference layout -- they are the same
template <class Tensor>
CUTE_HOST_DEVICE constexpr static
auto
retile(Tensor&& tensor)
{
constexpr int R = remove_cvref_t<Tensor>::rank;
// Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation
// Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV.
// Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV
// and that shape is what we gather from the other modes of tensor
auto V = size<0>(tensor);
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{}));
// (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV
auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{}));
// (atom_vals,rest_vals) -> (v,m,n)
/// =======
// Tile the tensor for TileFrg
auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V));
// ((TileV,TileM,TileN,...),(1,RestM,RestN,...))
// Transform the tile mode
auto v_tensor = t_tensor.compose(frg_layout_v, _);
// ((atom_vals,rest_vals),(1,RM,RN,...))
// Unfold and return
return v_tensor(_, append<R>(Int<0>{},_));
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutS_MN()
{
// (M,N) -> (M,N)
auto ref_S = make_layout(TiledShape_MN{});
// (thr_idx,val_idx) -> (M,N)
auto layoutS_TV = tidfrg_S(ref_S);
// (M,K) -> (thr_idx,val_idx)
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S));
// athrid = (v,m,k) -> thr_idx
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
return cute::make_tuple(layoutS_MK, thrID_S);
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutS_TV()
{
// (M,N) -> (M,N)
auto ref_S = make_layout(TiledShape_MN{});
// (thr_idx,val_idx) -> (M,N)
return tidfrg_S(ref_S)(_,_,Int<0>{});
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutD_MN()
{
// (M,N) -> (M,N)
auto ref_D = make_layout(TiledShape_MN{});
// (thr_idx,val_idx) -> (M,N)
auto layoutD_TV = tidfrg_D(ref_D);
// (M,K) -> (thr_idx,val_idx)
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D));
// athrid = (v,m,k) -> thr_idx
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
return cute::make_tuple(layoutD_MK, thrID_D);
}
CUTE_HOST_DEVICE constexpr static
auto
get_layoutD_TV()
{
// (M,N) -> (M,N)
auto ref_D = make_layout(TiledShape_MN{});
// (thr_idx,val_idx) -> (M,N)
return tidfrg_D(ref_D)(_,_,Int<0>{});
}
template <class ThrIdx>
struct ThrCopy : Copy_Atom
{
ThrIdx thr_idx_;
CUTE_HOST_DEVICE
ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {}
template <class STensor>
CUTE_HOST_DEVICE
auto
partition_S(STensor&& stensor) {
//static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename Copy_Atom::ValType),
// "Expected ValType for tiling SrcTensor.");
auto thr_tensor = make_tensor(std::forward<STensor>(stensor).data(), tidfrg_S(stensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<STensor>>(_));
}
template <class DTensor>
CUTE_HOST_DEVICE
auto
partition_D(DTensor&& dtensor) {
//static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename Copy_Atom::ValType),
// "Expected ValType for tiling DstTensor.");
auto thr_tensor = make_tensor(std::forward<DTensor>(dtensor).data(), tidfrg_D(dtensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<DTensor>>(_));
}
template <class STensor>
CUTE_HOST_DEVICE static
auto
retile_S(STensor&& stensor) {
static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename Copy_Atom::ValType),
"Expected ValType for tiling SrcTensor.");
return make_tensor(std::forward<STensor>(stensor).data(), TiledCopy::retile(stensor.layout()));
}
template <class DTensor>
CUTE_HOST_DEVICE static
auto
retile_D(DTensor&& dtensor) {
static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename Copy_Atom::ValType),
"Expected ValType for tiling DstTensor.");
return make_tensor(std::forward<DTensor>(dtensor).data(), TiledCopy::retile(dtensor.layout()));
}
};
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static
auto
get_slice(ThrIdx const& thr_idx)
{
return ThrCopy<ThrIdx>(thr_idx);
}
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static
auto
get_thread_slice(ThrIdx const& thr_idx)
{
return get_slice(thr_idx);
}
};
template <class... Args,
class LayoutCopy_TV,
class... TLayout>
CUTE_HOST_DEVICE
auto
make_tiled_copy_impl(Copy_Atom<Args...> const& atom,
LayoutCopy_TV const&,
Tile<TLayout...> const&)
{
return TiledCopy<Copy_Atom<Args...>, LayoutCopy_TV, Tile<TLayout...>>{atom};
}
//
// These tile the Copy_Atom as a whole
//
template <class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
{
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{})));
}
template <class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_B(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
{
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{})));
}
template <class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma)
{
using MNK = typename TiledMMA::TiledShape_MNK;
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{})));
}
template <class... Args,
class ThrLayout,
class ValLayout = Layout<_1>>
CUTE_HOST_DEVICE
auto
make_tiled_copy(Copy_Atom<Args...> const& copy_atom,
ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx
ValLayout const& val_layout = {})
{
constexpr int R = cute::max(rank_v<ThrLayout>, rank_v<ValLayout>);
auto thr_layout_mn = append<R>(thr_layout, Layout<_1>{});
auto val_layout_mn = append<R>(val_layout, Layout<_1>{});
// Take the raked_products to compute the Layout_MN
auto layout_mn = raked_product(thr_layout_mn, val_layout_mn);
auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
//print("thr_layout: "); print(thr_layout_mn); print("\n");
//print("val_layout: "); print(val_layout_mn); print("\n");
//print("layout_mn : "); print(layout_mn); print("\n");
//print("layout_tv : "); print(layout_tv); print("\n");
return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn)));
}
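// Hedged usage sketch (not part of the original header): a TiledCopy built from a
// 128-bit vectorized atom. The element type, thread layout, value layout, and the
// gmem_tensor/smem_tensor names below are illustrative choices.
#if 0
auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{},
                                  Layout<Shape<_32,_8>>{},   // 256 threads: (m,n) -> thr_idx
                                  Layout<Shape< _1,_4>>{});  // each thread moves a 1x4 fragment (128 bits)
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
auto tS = thr_copy.partition_S(gmem_tensor);                 // gmem_tensor is hypothetical
auto tD = thr_copy.partition_D(smem_tensor);                 // smem_tensor is hypothetical
copy(tiled_copy, tS, tD);
#endif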
// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy
template <class... Args,
class TiledCopy>
CUTE_HOST_DEVICE
auto
make_tiled_copy_S(Copy_Atom<Args...> const& copy_atom,
TiledCopy const& tiled_copy)
{
return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{});
}
// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy
template <class... Args,
class TiledCopy>
CUTE_HOST_DEVICE
auto
make_tiled_copy_D(Copy_Atom<Args...> const& copy_atom,
TiledCopy const& tiled_copy)
{
return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{});
}
//
// Size
//
// The logical size of a TiledCopy
template <int... I, class... Args>
CUTE_HOST_DEVICE constexpr
auto
tile_size(TiledCopy<Args...> const&)
{
return size<I...>(typename TiledCopy<Args...>::TiledShape_MN{});
}
// The number of threads involved in a TiledCopy
template <class... Args>
CUTE_HOST_DEVICE constexpr
auto
size(TiledCopy<Args...> const&)
{
return typename TiledCopy<Args...>::TiledNumThr{};
}
//
// Display utilities
//
template <class... Args>
CUTE_HOST_DEVICE
auto
print_latex(TiledCopy<Args...> const& copy)
{
auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN();
auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN();
print_latex_copy(layoutS_MN, thrID_S,
layoutD_MN, thrID_D);
}
// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread
template <class LayoutS, class ThrIDS,
class LayoutD, class ThrIDD>
CUTE_HOST_DEVICE
void
print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx
LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx
{
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
assert(size<0>(S) == size<0>(D));
assert(size<1>(S) == size<1>(D));
char const* latex_header =
"\\documentclass{standalone}\n"
"\\usepackage{tikz}\n"
"\\usetikzlibrary{external}\n"
"\\tikzexternalize\n"
"\\begin{document}\n"
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n";
char const* latex_footer =
"\\end{tikzpicture}\n"
"\\end{document}\n";
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
"{rgb,255:red,175;green,255;blue,175}",
"{rgb,255:red,255;green,255;blue,175}",
"{rgb,255:red,255;green,175;blue,175}",
"{rgb,255:red,210;green,210;blue,255}",
"{rgb,255:red,210;green,255;blue,210}",
"{rgb,255:red,255;green,255;blue,210}",
"{rgb,255:red,255;green,210;blue,210}",};
// Header
printf("%% LayoutS: "); print(S); printf("\n");
printf("%% ThrIDS : "); print(TS); printf("\n");
printf("%% LayoutD: "); print(D); printf("\n");
printf("%% ThrIDD : "); print(TD); printf("\n\n");
printf(latex_header);
// S starting at 0,0
for (int i = 0; i < size<0>(S); ++i) {
for (int j = 0; j < size<1>(S); ++j) {
int thrid = S(i,j) % size(TS);
int val_idx = S(i,j) / size(TS);
int thr_idx = TS(thrid);
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color_map[thr_idx % 8],
i, j,
thr_idx, val_idx);
}
}
// D starting at 0,size<1>(S)+3
for (int i = 0; i < size<0>(D); ++i) {
for (int j = 0; j < size<1>(D); ++j) {
int thrid = D(i,j) % size(TD);
int val_idx = D(i,j) / size(TD);
int thr_idx = TD(thrid);
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
color_map[thr_idx % 8],
i, j + size<1>(S) + 3,
thr_idx, val_idx);
}
}
// S Labels
for (int i = 0, j = -1; i < size<0>(S); ++i) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
}
for (int j = 0, i = -1; j < size<1>(S); ++j) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
}
// D Labels
for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i);
}
for (int j = 0, i = -1; j < size<1>(D); ++j) {
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j);
}
// Footer
printf(latex_footer);
}
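// Hedged sketch (illustration only): host-side visualization of a TiledCopy's source and
// destination thread-value layouts as a standalone LaTeX/TikZ document. The atom and
// layouts below are arbitrary choices for the sketch.
#if 0
auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<uint32_t>, float>{},
                                  Layout<Shape<_4,_8>>{},   // 32 threads
                                  Layout<Shape<_2,_1>>{});  // 2 values per thread
print_latex(tiled_copy);   // pipe stdout through pdflatex to render the picture
#endif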
} // end namespace cute
////////////////////////////////////////////////////////////////////////////////////////////////////
#include <cute/atom/copy_traits.hpp>
#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/copy_traits_sm80.hpp>
#include <cute/atom/copy_traits_sm90.hpp>
// Config
#if (__CUDACC_VER_MAJOR__ >= 12)
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED
#endif
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
#include <cute/atom/copy_traits_sm90_tma.hpp>
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,76 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/copy.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <class CopyOperation, class... CopyOpArgs>
struct Copy_Traits
{
static_assert(sizeof(CopyOperation) == 0, "Copy_Traits not implemented for this Copy_Operation.");
};
template <class S, class D>
struct Copy_Traits<UniversalCopy<S,D>>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <>
struct Copy_Traits<DefaultCopy>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
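// Hedged illustration (not part of the original header): these traits are consumed by
// Copy_Atom (cute/atom/copy_atom.hpp), which pairs them with a concrete value type.
// The element type below is an arbitrary choice for the sketch.
#if 0
Copy_Atom<UniversalCopy<uint32_t>, float> atom;   // 1 thread, one 32-bit value per copy
static_assert(decltype(atom)::NumValSrc == 1);    // SrcLayout vectorizes to a single value
#endif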
} // end namespace cute

View File

@ -0,0 +1,143 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/copy_sm75.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <>
struct Copy_Traits<SM75_U32x1_LDSM_N>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <Shape < _8,_4>,_128>,
Stride<Stride<_128,_0>, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_32,_32>,
Stride<_32, _1>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};
template <>
struct Copy_Traits<SM75_U32x2_LDSM_N>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <Shape < _16,_2>,_128>,
Stride<Stride<_128,_0>, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_32,Shape <_32, _2>>,
Stride<_32,Stride< _1,_1024>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};
template <>
struct Copy_Traits<SM75_U32x4_LDSM_N>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape < _32,_128>,
Stride<_128, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_32,Shape <_32, _4>>,
Stride<_32,Stride< _1,_1024>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};
template <>
struct Copy_Traits<SM75_U16x2_LDSM_T>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <Shape < _8,_4>,_128>,
Stride<Stride<_128,_0>, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <Shape < _4, _8>,Shape <_16, _2>>,
Stride<Stride<_256,_16>,Stride< _1,_128>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};
template <>
struct Copy_Traits<SM75_U16x4_LDSM_T>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <Shape < _16,_2>,_128>,
Stride<Stride<_128,_0>, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <Shape < _4, _8>,Shape <_16, _2, _2>>,
Stride<Stride<_256,_16>,Stride< _1,_128,_1024>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};
template <>
struct Copy_Traits<SM75_U16x8_LDSM_T>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape < _32,_128>,
Stride<_128, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <Shape < _4, _8>,Shape <_16, _2, _4>>,
Stride<Stride<_256,_16>,Stride< _1,_128,_1024>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};
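// Hedged usage sketch (not part of the original header): LDSM atoms are typically retiled
// to match an MMA operand, e.g. for the A operand of a hypothetical tiled_mma with
// half-precision inputs.
#if 0
auto s2r_atom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>{};
auto s2r_copy = make_tiled_copy_A(s2r_atom, tiled_mma);   // tiled_mma is hypothetical
auto thr_copy = s2r_copy.get_thread_slice(threadIdx.x);
#endif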
} // end namespace cute

View File

@ -0,0 +1,98 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/copy_sm80.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <class S, class D>
struct Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<S,D>>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <class S, class D>
struct Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL<S,D>>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Element copy selector
template <class SrcTensor, class DstTensor>
CUTE_HOST_DEVICE constexpr
auto
select_elementwise_copy(SrcTensor const&, DstTensor const&)
{
using SrcType = typename SrcTensor::value_type;
using DstType = typename DstTensor::value_type;
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
if constexpr (is_gmem<SrcTensor>::value && is_smem<DstTensor>::value &&
sizeof(SrcType) == sizeof(DstType) &&
(sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16))
{
return SM80_CP_ASYNC_CACHEALWAYS<SrcType,DstType>{};
} else {
return UniversalCopy<SrcType,DstType>{};
}
CUTE_GCC_UNREACHABLE;
#else
return UniversalCopy<SrcType,DstType>{};
#endif
}
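// Hedged usage sketch (illustration only; gA and sA are hypothetical gmem/smem tensors of
// the same value type): let the selector pick cp.async for eligible gmem->smem copies and
// wrap the result in a Copy_Atom for use with copy().
#if 0
auto op   = select_elementwise_copy(gA, sA);
auto atom = Copy_Atom<decltype(op), typename decltype(gA)::value_type>{};
// copy(atom, thr_gA, thr_sA);   // per-thread partitions of gA / sA
#endif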
} // end namespace cute

View File

@ -0,0 +1,132 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/copy_sm90.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <>
struct Copy_Traits<SM90_U32x1_STSM_N>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = typename Copy_Traits<SM75_U32x1_LDSM_N>::DstLayout;
// Map from (dst-thr,dst-val) to bit
using DstLayout = typename Copy_Traits<SM75_U32x1_LDSM_N>::SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <>
struct Copy_Traits<SM90_U32x2_STSM_N>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = typename Copy_Traits<SM75_U32x2_LDSM_N>::DstLayout;
// Map from (dst-thr,dst-val) to bit
using DstLayout = typename Copy_Traits<SM75_U32x2_LDSM_N>::SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <>
struct Copy_Traits<SM90_U32x4_STSM_N>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = typename Copy_Traits<SM75_U32x4_LDSM_N>::DstLayout;
// Map from (dst-thr,dst-val) to bit
using DstLayout = typename Copy_Traits<SM75_U32x4_LDSM_N>::SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <>
struct Copy_Traits<SM90_U16x2_STSM_T>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = typename Copy_Traits<SM75_U16x2_LDSM_T>::DstLayout;
// Map from (dst-thr,dst-val) to bit
using DstLayout = typename Copy_Traits<SM75_U16x2_LDSM_T>::SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <>
struct Copy_Traits<SM90_U16x4_STSM_T>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = typename Copy_Traits<SM75_U16x4_LDSM_T>::DstLayout;
// Map from (dst-thr,dst-val) to bit
using DstLayout = typename Copy_Traits<SM75_U16x4_LDSM_T>::SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
template <>
struct Copy_Traits<SM90_U16x8_STSM_T>
{
// Logical thread id to thread idx (warp)
using ThrID = Layout<_32>;
// Map from (src-thr,src-val) to bit
using SrcLayout = typename Copy_Traits<SM75_U16x8_LDSM_T>::DstLayout;
// Map from (dst-thr,dst-val) to bit
using DstLayout = typename Copy_Traits<SM75_U16x8_LDSM_T>::SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
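// Hedged usage note (not part of the original header): STSM atoms are commonly used for
// register->smem epilogue stores tiled over the accumulator, e.g.
//   auto r2s_copy = make_tiled_copy_C(Copy_Atom<SM90_U32x4_STSM_N, half_t>{}, tiled_mma);
// where tiled_mma is a hypothetical TiledMMA.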
} // end namespace cute

View File

@ -0,0 +1,795 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cuda.h>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/tensor.hpp>
namespace cute
{
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD ///////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {};
// The executable SM90_TMA_LOAD with tma_desc and tma_mbar
template <class NumBits>
struct Copy_Traits<SM90_TMA_LOAD_OP, NumBits>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD arguments
TmaDescriptor const& tma_desc_;
uint64_t& tma_load_mbar_;
template <class Coord, int... Is>
CUTE_HOST_DEVICE constexpr
void
copy_unpack_(void const* const dst_ptr,
Coord const& src_coord, seq<Is...>) const
{
#if 0
print("THR (%d,%d,%d) BLK (%d,%d,%d)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z);
print(" TMA Coord "); print(src_coord); print("\n");
print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_),
uint64_t(tma_desc_.size1_),
uint64_t(tma_desc_.size2_),
uint64_t(tma_desc_.size3_))); print("\n");
#endif
SM90_TMA_LOAD::copy(&tma_desc_,
tma_load_mbar_,
dst_ptr,
get<Is>(src_coord)...);
}
// This is the copy_unpack dispatch for this Copy_Traits
// Src needs to be a gmem tensor with TmaCoordIterator .data()
// Dst needs to be a smem tensor
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
//static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_TMA_LOAD");
traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq<decltype(src.data().coord_)>{});
}
};
// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
template <class NumBits, class GmemStrides>
struct Copy_Traits<SM90_TMA_LOAD, NumBits, GmemStrides>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD arguments
TmaDescriptor tma_desc_;
GmemStrides g_stride_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM90_TMA_LOAD with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_OP, NumBits>
with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
// assert(multicast_mask == 0);
(void) multicast_mask;
return {tma_desc_, tma_mbar};
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(g_stride_)>::value);
constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value;
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat<tma_rank>(Int<0>{}))),
g_shape,
g_stride_);
}
// Don't try to execute a copy with SM90_TMA_LOAD before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
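// Hedged usage sketch (not part of the original header): the non-executable traits carry
// only the descriptor; inside the kernel, .with(mbar) binds a shared-memory barrier to
// produce an executable atom. "tma" below is a hypothetical atom/TiledCopy built on
// SM90_TMA_LOAD.
#if 0
__shared__ uint64_t smem_mbar;
auto tma_with_mbar = tma.with(smem_mbar);       // executable SM90_TMA_LOAD traits
// copy(tma_with_mbar, tAgA, tAsA);             // issue the bulk-tensor load
#endif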
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {};
template <class NumBits>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBits>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_MULTICAST arguments
TmaDescriptor const& tma_desc_;
uint64_t& tma_load_mbar_;
uint16_t const& multicast_mask_;
template <class Coord, int... Is>
CUTE_HOST_DEVICE constexpr
void
copy_unpack_(void const* const dst_ptr,
Coord const& src_coord, seq<Is...>) const
{
#if 0
print("THR (%d,%d,%d) BLK (%d,%d,%d)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z);
print(" TMA Coord "); print(src_coord); print("\n");
print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_),
uint64_t(tma_desc_.size1_),
uint64_t(tma_desc_.size2_),
uint64_t(tma_desc_.size3_))); print("\n");
#endif
SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_,
tma_load_mbar_,
multicast_mask_,
dst_ptr,
get<Is>(src_coord)...);
}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
//static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST");
traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq<decltype(src.data().coord_)>{});
}
};
template <class NumBits, class GmemStrides>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBits, GmemStrides>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_MULTICAST arguments
TmaDescriptor tma_desc_;
GmemStrides g_stride_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBits>
with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
return {tma_desc_, tma_load_mbar, multicast_mask};
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(g_stride_)>::value);
constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value;
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat<tma_rank>(Int<0>{}))),
g_shape,
g_stride_);
}
// Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_STORE //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// The executable SM90_TMA_STORE with tma_desc
template <class NumBits, class GmemStrides>
struct Copy_Traits<SM90_TMA_STORE, NumBits, GmemStrides>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_STORE arguments
TmaDescriptor tma_desc_;
GmemStrides g_stride_;
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(g_stride_)>::value);
constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value;
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat<tma_rank>(Int<0>{}))),
g_shape,
g_stride_);
}
template <class Coord, int... Is>
CUTE_HOST_DEVICE constexpr
void
copy_unpack_(void const* const src_ptr,
Coord const& dst_coord, seq<Is...>) const
{
#if 0
print("THR (%d,%d,%d) BLK (%d,%d,%d)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z);
print(" TMA Coord "); print(dst_coord); print("\n");
print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_),
uint64_t(tma_desc_.size1_),
uint64_t(tma_desc_.size2_),
uint64_t(tma_desc_.size3_))); print("\n");
#endif
SM90_TMA_STORE::copy(&tma_desc_,
src_ptr,
get<Is>(dst_coord)...);
}
// This is the copy_unpack dispatch for this Copy_Traits
// Src needs to be a smem tensor
// Dst needs to be a gmem tensor with TmaCoordIterator .data()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
//static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed dst tensor
traits.copy_unpack_(src.data().get(), dst.data().coord_, tuple_seq<decltype(dst.data().coord_)>{});
}
};
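// Usage sketch (illustrative, not part of this header): unlike the load paths,
// SM90_TMA_STORE needs no mbarrier or multicast mask, so the TiledCopy returned
// by make_tma_copy is directly executable. `tma_store`, `sC`, and `gC` are
// assumed to be a TiledCopy, an SMEM tensor, and a TMA coord tensor built as in
// the make_tma_copy documentation below.
//   auto cta_tma = tma_store.get_slice(Int<0>{});   // no multicast partitioning for stores
//   Tensor tCsC  = cta_tma.partition_S(sC);         // SMEM source partition
//   Tensor tCgC  = cta_tma.partition_D(gC);         // GMEM (TMA coord) destination partition
//   copy(tma_store, tCsC, tCgC);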
//
// MAKE_TMA_COPY and related
//
template <int B, int M, int S, class Offset, class SLayout>
TMA::SmemSwizzleBits
get_tma_swizzle_bits(ComposedLayout<Swizzle<B,M,S>,Offset,SLayout>)
{
static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle.");
static_assert(S == 3, "Unsupported layout swizzle");
switch (B) {
default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3. Unsupported layout swizzle.");
case 3: return TMA::SmemSwizzleBits::B128;
case 2: return TMA::SmemSwizzleBits::B64;
case 1: return TMA::SmemSwizzleBits::B32;
case 0: return TMA::SmemSwizzleBits::DISABLE;
}
}
template <class Shape, class Stride>
TMA::SmemSwizzleBits
get_tma_swizzle_bits(Layout<Shape,Stride>)
{
return TMA::SmemSwizzleBits::DISABLE;
}
template <int B, int M, int S, class Offset, class SLayout>
auto
get_nonswizzle_layout(ComposedLayout<Swizzle<B,M,S>,Offset,SLayout> const& slayout)
{
return slayout.layout_fn();
}
template <class Shape, class Stride>
auto
get_nonswizzle_layout(Layout<Shape,Stride> const& slayout)
{
return slayout;
}
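// Usage sketch (illustrative): these helpers split a swizzled SMEM layout into
// the swizzle pattern reported to the TMA descriptor and the underlying layout
// used for the vectorization analysis. The atom name matches the GMMA example
// in the make_tma_copy documentation below.
//   auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom<half_t>{}, make_shape(_128{},_64{}));
//   auto bits    = get_tma_swizzle_bits(slayout);     // -> TMA::SmemSwizzleBits::B128
//   auto inner   = get_nonswizzle_layout(slayout);    // plain Layout with the swizzle stripped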
/** Make a CuTe CTA-collective TiledCopy for a TMA operation.
*
* @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE
* @param gtensor The GMEM Tensor to be involved in the TMA.
* @param slayout The SMEM Layout to be involved in the TMA.
* @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with.
* This is often the blk_shape that is used to tile the GMEM for CTAs:
* local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor
* @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16
*                   defining the multicast size (used to further partition the SMEM).
*                   Otherwise, it must be static-1.
*
* This code attempts to maximize the TMA box size. It does this by tracing
* the SMEM "vector" -- the inverse of the smem layout -- to find the largest
* contiguous array of smem that can be written to/from global memory given
* the constraints that the TMA instruction imposes.
*
* This is accomplished by assigning "basis" strides to the GMEM to track which
* modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according
* to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc.
*
* Examples:
using T = float;
T* gptr = nullptr;
{
// Simple 2D
Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM
auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM
auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
}
{
// GMMA 2D
Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM
auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom<T>{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM
auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
}
{
// 3D
Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM
auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode
auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
}
{
// cuTENSOR 4D
Tensor gtensor = make_tensor(gptr, make_shape(make_shape(32,40),make_shape(make_shape(8,8),656))); // GMEM
auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling:
// Take 128-elem from m: m0 must divide 128,
// m-last may be predicated
// Take 32-elem from k0, 2-elem from k1
auto slayout = make_layout(cta_tile); // Col-Major SMEM
auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{});
}
*
* Check the TMA box size and desc:
print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n");
print("TMA desc : "); print(tma.tma_desc_); print("\n");
*
* Usage:
Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor
Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA
Tensor sA = make_tensor(make_smem_ptr<T>(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor
auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning
Tensor tAgA = cta_tma.partition_S(gA); // Partition for src
Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst
copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params
*/
template <class CopyOp,
class GEngine, class GLayout,
class SLayout,
class CTA_Tile,
class Cluster_Size>
CUTE_HOST
auto
make_tma_copy(CopyOp,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout,
CTA_Tile const& cta_tile,
Cluster_Size const& cluster_size)
{
static_assert((std::is_same<CopyOp, SM90_TMA_LOAD>::value && is_constant<1, Cluster_Size>::value) ||
(std::is_same<CopyOp, SM90_TMA_LOAD_MULTICAST>::value) ||
(std::is_same<CopyOp, SM90_TMA_STORE>::value && is_constant<1, Cluster_Size>::value));
using T = typename Tensor<GEngine,GLayout>::value_type;
//
// TMA parameter checking
//
auto flat_glayout = flatten(gtensor.layout());
CUTE_STATIC_ASSERT_V(rank(flatten(cta_tile)) <= Int<5>{},
"CTA_Tile cannot have more than five modes, TMA arch restriction.");
CUTE_STATIC_ASSERT_V(rank(flat_glayout) <= Int<5>{} || rank(flatten(cta_tile)) <= Int<4>{},
"If GTensor has more than five modes, then CTA_Tile cannot have more than four modes. TMA multimode.");
CUTE_STATIC_ASSERT_V(compatible(product_each(shape(slayout)), shape(cta_tile)),
"CTA_Tile must be compatible with SLayout.");
CUTE_STATIC_ASSERT_V(is_integral<Cluster_Size>{} && has_single_bit(cluster_size) && cluster_size <= Int<16>{},
"Expecting a pow2 integral Cluster_Size leq 16.");
CUTE_STATIC_ASSERT_V(size(slayout) % cluster_size == Int<0>{},
"ClusterShape must divide domain size of slayout.");
//
// TMA slayout manipulation
//
auto tma_multimode = rank(flat_glayout) > Int<5>{};
// Invert the smem to get the largest contiguous vector in the smem layout
auto inv_smem_layout = right_inverse(get_nonswizzle_layout(slayout));
// trunc_smem_idx -> trunc_smem_coord
// Map from smem idx to a gmem mode
auto sidx_to_gmode = flatten(composition(make_identity_layout(cta_tile), inv_smem_layout));
// Truncate any incompatibilities
auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e){
[[maybe_unused]] auto v = basis_value(e);
return not is_constant<1,decltype(v)>{};
});
static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA.");
constexpr int smem_tma_rank = cute::min(int(smem_rank), (tma_multimode ? 4 : 5));
// Keep only the static-1 basis modes into gmem
auto sidx_to_gmode_cluster_trunc = take<0,smem_tma_rank>(sidx_to_gmode);
// Keep only the portion each multicast CTA will be responsible for
auto sidx_to_gmode_cta_trunc = composition(sidx_to_gmode_cluster_trunc, shape_div(size(sidx_to_gmode_cluster_trunc), cluster_size));
//
// TMA gtensor manipulation
//
// Generate a TupleBasis for the gtensor
auto flat_gbasis = make_basis_like(shape(flat_glayout));
// Fold the flat_gbasis into the glayout
auto glayout_basis = make_layout(shape(gtensor),
stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis),
make_layout(repeat_like(shape(gtensor), Int<2>{})))));
// Tile the modes of gtensor with cta_tile
auto cta_glayout_basis = composition(glayout_basis, cta_tile);
// Check that the cta_tile selects modes from gtensor properly
for_each(flatten(stride(cta_glayout_basis)), [](auto d) {
static_assert(is_constant<1, decltype(d.value())>::value,
"CTA_Tile does not faithfully partition the GMEM, it should select the number of elements from each mode of glayout.");
});
// Tile the modes of gtensor again with the truncated cta_tile o inv_smem_layout
auto tma_layout_cta_trunc = flatten(composition(glayout_basis, sidx_to_gmode_cta_trunc));
// Append any missing basis on the end as size-1 modes b/c they got truncated
auto missing_basis = fold(stride(tma_layout_cta_trunc), flat_gbasis, [](auto init, auto e){
auto k = find(init, e);
return remove<k>(init);
});
// The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode
auto tma_layout_cta = flatten(make_layout(tma_layout_cta_trunc,
make_layout(repeat<rank(missing_basis)>(Int<1>{}), missing_basis)));
#if 0
print("g_layout : "); print(gtensor.layout()); print("\n");
print("s_layout : "); print(slayout); print("\n");
print("cta_tile : "); print(cta_tile); print("\n");
print("cluster_size : "); print(cluster_size); print("\n");
print("flat_gbasis : "); print(flat_gbasis); print("\n");
print("cta_glayout : "); print(cta_glayout_basis); print("\n");
print("inv_smem : "); print(inv_smem_layout); print("\n");
print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n");
print("missing_b : "); print(missing_basis); print("\n");
print("tma_layout_cta: "); print(tma_layout_cta); print("\n");
#endif
//
// TMA gmem desc info
//
constexpr int TmaRANK = cute::min(rank(flat_glayout), 5);
void* gmem_address = (void*) gtensor.data();
cute::array<cuuint64_t, 5> gmem_prob_shape = {1,1,1,1,1};
cute::array<cuuint64_t, 5> gmem_prob_stride = {0,0,0,0,0};
for_each(make_seq<rank(tma_layout_cta)>{}, [&](auto i) {
// NOTE : WAR g++-7.3.5, let it deduce e rather than fuse with below
auto e = stride<i>(tma_layout_cta);
constexpr int j = decltype(e.mode())::value;
constexpr int tma_i = i < 5 ? i : 4;
// Problem stride
uint64_t stride_j = stride<j>(flat_glayout) * sizeof(T);
uint64_t old_stride = gmem_prob_stride[tma_i];
gmem_prob_stride[tma_i] = gcd(gmem_prob_stride[tma_i], stride_j);
// Problem shape
uint64_t shape_j = shape<j>(flat_glayout);
if (gmem_prob_stride[tma_i] != 0) {
// We're "resetting" this TMA mode and using it as a "multimode"
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
gmem_prob_shape[tma_i] = (gmem_prob_shape[tma_i]-1) * (old_stride / gmem_prob_stride[tma_i])
+ (shape_j-1) * (stride_j / gmem_prob_stride[tma_i])
+ 1;
} else {
gmem_prob_shape[tma_i] = shape_j;
}
});
assert((reinterpret_cast<uint64_t>(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned
assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1
assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32
assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1
assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32
assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1
assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32
assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1
assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32
assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1
assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32
assert((gmem_prob_stride[0]) == sizeof(T)); // First stride is implicitly 1
assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40
assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40
assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40
assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40
assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
//
// TMA smem desc info
//
// TMA smem box size
cute::array<cuuint32_t, 5> smem_box_shape = {1,1,1,1,1};
for_each(make_seq<rank(tma_layout_cta)>{}, [&](auto i) {
uint32_t shape_i = shape<i>(tma_layout_cta);
constexpr int tma_i = i < 5 ? i : 4;
if (tma_multimode && tma_i == 4) {
// We're "reusing" this TMA mode and using it as a "multimode"
smem_box_shape[tma_i] = 1;
} else {
smem_box_shape[tma_i] = shape_i;
}
});
// TMA smem mode strides
[[maybe_unused]] cute::array<cuuint32_t, 5> smem_box_stride = {1,1,1,1,1};
assert(smem_box_shape[0] >= (uint32_t(1)));       // Size must be min 1
assert(smem_box_shape[0] <= (uint32_t(1) << 8));  // Size must be max 2^8
assert(smem_box_shape[1] >= (uint32_t(1)));       // Size must be min 1
assert(smem_box_shape[1] <= (uint32_t(1) << 8));  // Size must be max 2^8
assert(smem_box_shape[2] >= (uint32_t(1)));       // Size must be min 1
assert(smem_box_shape[2] <= (uint32_t(1) << 8));  // Size must be max 2^8
assert(smem_box_shape[3] >= (uint32_t(1)));       // Size must be min 1
assert(smem_box_shape[3] <= (uint32_t(1) << 8));  // Size must be max 2^8
assert(smem_box_shape[4] >= (uint32_t(1)));       // Size must be min 1
assert(smem_box_shape[4] <= (uint32_t(1) << 8));  // Size must be max 2^8
assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1
assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3
assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1
assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3
assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1
assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3
assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1
assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3
assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1
assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3
//
// Construct the descriptor
//
TmaDescriptor tma_desc = {0};
#if (__CUDACC_VER_MAJOR__ >= 12)
//
// TMA general info
//
cuuint32_t tma_dim = TmaRANK;
CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType<T>();
CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;
CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
// TMA smem swizzle type
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(slayout));
CUresult result = cuTensorMapEncodeTiled(
&tma_desc,
tma_format,
tma_dim,
gmem_address,
gmem_prob_shape.data(),
gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1
smem_box_shape.data(),
smem_box_stride.data(),
tma_interleave,
smem_swizzle,
tma_l2Promotion,
tma_oobFill);
if (result != CUDA_SUCCESS) {
std::cerr << "TMA Desc Addr: " << &tma_desc
<< "\nformat " << tma_format
<< "\ndim " << tma_dim
<< "\ngmem_address " << gmem_address
<< "\nglobalDim " << gmem_prob_shape
<< "\nglobalStrides " << gmem_prob_stride
<< "\nboxDim " << smem_box_shape
<< "\nelementStrides " << smem_box_stride
<< "\ninterleave " << tma_interleave
<< "\nswizzle " << smem_swizzle
<< "\nl2Promotion " << tma_l2Promotion
<< "\noobFill " << tma_oobFill << std::endl;
std::cerr << "Error: Failed to intialize the TMA descriptor " << result << std::endl;
assert(false);
}
#endif // (__CUDACC_VER_MAJOR__ >= 12)
//
// Construct the Copy_Traits
//
// Finally, get the inverse permutation of the E<i> bases for the mocked gmem stride
auto gmem_stride_bases_flat = transform(make_seq<rank(tma_layout_cta)>{}, [&](auto i) {
auto k = find(stride(tma_layout_cta), E<i>{});
// NOTE: gcc 7.3.5 WAR -- avoid if constexpr
int32_t tma_coord_stride = int32_t(stride<i>(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16));
return conditional_return(tma_multimode && (k >= Int<4>{}),
E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride
E<k>{});
});
// Fold the flat gmem bases back into the hierarchical profile of gtensor
auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat),
make_layout(repeat_like(shape(gtensor), Int<2>{}))));
constexpr int num_bits = size(sidx_to_gmode_cta_trunc) * sizeof(T) * 8;
using Traits = Copy_Traits<CopyOp, Int<num_bits>, decltype(gmem_stride_bases)>;
#if 0
print("num_bits : "); print(num_bits); print("\n");
print("g_stride_bases: "); print(gmem_stride_bases); print("\n");
#endif
//
// Construct the TiledCopy
//
// The ThrVal layout for 1 TMA instruction within cta_tile
auto layout_tv_1 = composition(inv_smem_layout, make_layout(make_shape(cluster_size, size(sidx_to_gmode_cta_trunc)), GenRowMajor{}));
// The ThrVal layout for N TMA instructions within cta_tile
auto layout_tv = tile_to_shape(layout_tv_1, make_shape(cluster_size, size(cta_tile)/cluster_size));
#if 0
print("layout_tv : "); print(layout_tv); print("\n");
#endif
return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases};
}
// Explicit defaulting
template <class CopyOp,
class GEngine, class GLayout,
class SLayout>
CUTE_HOST
auto
make_tma_copy(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout)
{
return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{});
}
template <class CopyOp,
class GEngine, class GLayout,
class SLayout,
class Cluster_Size>
CUTE_HOST
auto
make_tma_copy(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout,
Cluster_Size const& cluster_size)
{
return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size);
}
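// Usage sketch (illustrative): the convenience overloads default the CTA tile to
// the SMEM shape, so the common single-CTA and multicast cases reduce to
//   auto tma_a = make_tma_copy(SM90_TMA_LOAD{},           gA, sA_layout);            // cluster size 1
//   auto tma_b = make_tma_copy(SM90_TMA_LOAD_MULTICAST{}, gB, sB_layout, Int<2>{});  // 2-CTA multicast
// where gA/gB are GMEM tensors and sA_layout/sB_layout are SMEM layouts (assumed names).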
} // end namespace cute

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,70 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <class MMAOperation, class... MMAOpArgs>
struct MMA_Traits
{
static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation.");
};
template <class D, class A, class B, class C>
struct MMA_Traits<UniversalFMA<D,A,B,C>>
{
using ElementDVal = D;
using ElementAVal = A;
using ElementBVal = B;
using ElementCVal = C;
// Logical shape of the MMA
using Shape_MNK = Shape<_1,_1,_1>;
// Logical thread id (tid) -> tidx
using ThrID = Layout<_1>;
// (Logical thread id (tid), Logical value id (vid)) -> coord
// (tid,vid) -> (m,k)
using ALayout = Layout<Shape<_1,_1>>;
// (tid,vid) -> (n,k)
using BLayout = Layout<Shape<_1,_1>>;
// (tid,vid) -> (m,n)
using CLayout = Layout<Shape<_1,_1>>;
};
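// Usage sketch (illustrative): any MMA_Traits specialization can be wrapped in an
// MMA_Atom and tiled across threads; for the scalar fallback this yields a plain
// SIMT FMA tiling (the thread layout below is an arbitrary example).
//   using FMA      = UniversalFMA<float, float, float, float>;
//   auto tiled_mma = make_tiled_mma(MMA_Atom<FMA>{}, Layout<Shape<_16,_8,_1>>{});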
} // namespace cute

View File

@ -0,0 +1,73 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma_sm61.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <>
struct MMA_Traits<SM61_DP4A>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_1,_1,_4>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape<_1,_4>>;
using BLayout = Layout<Shape<_1,_4>>;
using CLayout = Layout<Shape<_1,_1>>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM61_DP2A>
{
using ElementDVal = int32_t;
using ElementAVal = int16_t;
using ElementBVal = int16_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_1,_1,_2>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape<_1,_2>>;
using BLayout = Layout<Shape<_1,_2>>;
using CLayout = Layout<Shape<_1,_1>>;
};
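// Reading these traits (illustrative): both ops are single-thread dot products, so
// ThrID is trivial and each value index walks the K mode. For DP4A,
//   Shape_MNK = (1,1,4)  and  ALayout/BLayout map (tid=0, vid=k) -> (m=0, k),
// i.e. one thread consumes four int8 values of A and of B per instruction.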
} // namespace cute

View File

@ -0,0 +1,198 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma_sm70.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>
namespace cute
{
namespace {
// Logical thread id to thread idx (quadpair)
using SM70_QuadPair = Layout<Shape <_4, _2>,
Stride<_1,_16>>;
// (T8,V4) -> (M8,K4)
using SM70_8x4_Row = Layout<Shape <_8,_4>,
Stride<_1,_8>>;
// (T8,V4) -> (M8,K4)
using SM70_8x4_Col = Layout<Shape <Shape <_4,_2>,_4>,
Stride<Stride<_8,_4>,_1>>;
// (T8,V8) -> (M8,N8)
using SM70_8x8_16b = Layout<Shape <_8,_8>,
Stride<_1,_8>>;
// (T8,V8) -> (M8,N8)
using SM70_8x8_32b = Layout<Shape <Shape <_2, _2,_2>,Shape <_2,_2, _2>>,
Stride<Stride<_1,_16,_4>,Stride<_8,_2,_32>>>;
}
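// Reading SM70_QuadPair (illustrative): logical tid in [0,8) maps to the two quads
// of a quadpair, e.g.
//   SM70_QuadPair{}(0) == 0,  SM70_QuadPair{}(3) == 3,   // first quad:  threads 0..3
//   SM70_QuadPair{}(4) == 16, SM70_QuadPair{}(7) == 19   // second quad: threads 16..19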
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Row;
using BLayout = SM70_8x4_Row;
using CLayout = SM70_8x8_16b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NT>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Col;
using BLayout = SM70_8x4_Col;
using CLayout = SM70_8x8_16b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Col;
using BLayout = SM70_8x4_Row;
using CLayout = SM70_8x8_16b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TT>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Row;
using BLayout = SM70_8x4_Col;
using CLayout = SM70_8x8_16b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Row;
using BLayout = SM70_8x4_Row;
using CLayout = SM70_8x8_32b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Col;
using BLayout = SM70_8x4_Col;
using CLayout = SM70_8x8_32b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Col;
using BLayout = SM70_8x4_Row;
using CLayout = SM70_8x8_32b;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TT>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = SM70_QuadPair;
using ALayout = SM70_8x4_Row;
using BLayout = SM70_8x4_Col;
using CLayout = SM70_8x8_32b;
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cute

View File

@ -0,0 +1,81 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma_sm75.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>
namespace cute
{
template <>
struct MMA_Traits<SM75_16x8x8_F32F16F16F32_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
Stride<Stride<_32,_2>,Stride<_16,_1>>>;
using BLayout = Layout<Shape <Shape < _4,_8>,_2>,
Stride<Stride<_16,_1>,_8>>;
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
Stride<Stride<_32,_2>,Stride<_16,_1>>>;
};
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM75_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_8,_8,_16>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,_4>,
Stride<Stride<_32,_1>,_8>>;
using BLayout = Layout<Shape <Shape < _4,_8>,_4>,
Stride<Stride<_32,_1>,_8>>;
using CLayout = Layout<Shape <Shape < _4,_8>,_2>,
Stride<Stride<_16,_1>,_8>>;
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cute

View File

@ -0,0 +1,446 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>
#include <cute/numeric/integer_subbyte.hpp>
#include <cutlass/numeric_types.h>
namespace cute
{
namespace {
// (T32,V1) -> (M8,N8)
using SM80_8x4 = Layout<Shape <Shape < _4,_8>,_1>,
Stride<Stride< _8,_1>,_0>>;
// (T32,V2) -> (M8,N8)
using SM80_8x8_Row = Layout<Shape <Shape < _4,_8>,_2>,
Stride<Stride<_16,_1>,_8>>;
// (T32,V4) -> (M8,N16)
using SM80_8x16_Row = Layout<Shape <Shape < _4,_8>,_4>,
Stride<Stride<_32,_1>,_8>>;
// (T32,V4) -> (M16,N8)
using SM80_16x8_Row = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
}
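// Reading SM80_16x8_Row (illustrative): thread t in [0,32) and value v in [0,4)
// land at the standard PTX mma.m16n8 accumulator coordinates
//   m = (t / 4) + 8 * (v / 2),   n = 2 * (t % 4) + (v % 2)
// The other helpers above encode the matching A/B fragment ownership patterns.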
///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp16 = fp16 * fp16 + fp16 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
using ALayout = SM80_16x8_Row;
using BLayout = SM80_8x8_Row;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
using ElementDVal = half_t;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = half_t;
using Shape_MNK = Shape<_16,_8,_16>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2, _2>>,
Stride<Stride<_32,_1>,Stride<_16,_8,_128>>>;
using BLayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _2>>,
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
using CLayout = SM80_16x8_Row;
};
///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp32 = fp16 * fp16 + fp32 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_16x8x8_F32F16F16F32_TN>
: MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
};
template <>
struct MMA_Traits<SM80_16x8x16_F32F16F16F32_TN>
: MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = half_t;
using ElementBVal = half_t;
using ElementCVal = float;
};
///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp32 = bf16 * bf16 + fp32 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_16x8x8_F32BF16BF16F32_TN>
: MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = bfloat16_t;
using ElementBVal = bfloat16_t;
using ElementCVal = float;
};
template <>
struct MMA_Traits<SM80_16x8x16_F32BF16BF16F32_TN>
: MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
using ElementDVal = float;
using ElementAVal = bfloat16_t;
using ElementBVal = bfloat16_t;
using ElementCVal = float;
};
///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp32 = tf32 * tf32 + fp32 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_16x8x4_F32TF32TF32F32_TN>
{
using ElementDVal = float;
using ElementAVal = cutlass::tfloat32_t;
using ElementBVal = cutlass::tfloat32_t;
using ElementCVal = float;
using Shape_MNK = Shape<_16,_8,_4>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,_2>,
Stride<Stride<_16,_1>,_8>>;
using BLayout = SM80_8x4;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x8_F32TF32TF32F32_TN>
{
using ElementDVal = float;
using ElementAVal = cutlass::tfloat32_t;
using ElementBVal = cutlass::tfloat32_t;
using ElementCVal = float;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _2>>,
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
using BLayout = Layout<Shape <Shape <_4,_8>, _2>,
Stride<Stride<_8,_1>,_32>>;
using CLayout = SM80_16x8_Row;
};
///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp64 = fp64 * fp64 + fp64 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using Shape_MNK = Shape<_8,_8,_4>;
using ThrID = Layout<_32>;
using ALayout = SM80_8x4;
using BLayout = SM80_8x4;
using CLayout = SM80_8x8_Row;
};
// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts
template <>
struct MMA_Traits<SM80_8x8x4_C64C64C64C64_TN>
: MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
};
// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts
template <>
struct MMA_Traits<SM80_8x8x4_GC64C64C64GC64_TN>
: MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = s8 * s8 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_8,_8,_16>;
using ThrID = Layout<_32>;
using ALayout = SM80_8x16_Row;
using BLayout = SM80_8x16_Row;
using CLayout = SM80_8x8_Row;
};
template <>
struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_16,_8,_16>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape < _4,_2>>,
Stride<Stride<_64,_1>,Stride<_16,_8>>>;
using BLayout = SM80_8x16_Row;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_16,_8,_32>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape < _4,_2, _2>>,
Stride<Stride<_64,_1>,Stride<_16,_8,_256>>>;
using BLayout = Layout<Shape <Shape < _4,_8>, Shape <_4, _2>>,
Stride<Stride<_32,_1>, Stride<_8,_128>>>;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32S8S8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = s8 * u8 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x16_S32S8U8S32_TN>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_8x8x16_S32S8U8S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x16_S32S8U8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x16_S32S8U8S32_TN>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x16_S32S8U8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x16_S32S8U8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32S8U8S32_TN>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = int8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32S8U8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32S8U8S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = u8 * s8 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x16_S32U8S8S32_TN>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_8x8x16_S32U8S8S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x16_S32U8S8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x16_S32U8S8S32_TN>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x16_S32U8S8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x16_S32U8S8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32U8S8S32_TN>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = int8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32U8S8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U8S8S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = u8 * u8 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x16_S32U8U8S32_TN>
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_8x8x16_S32U8U8S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x16_S32U8U8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x16_S32U8U8S32_TN>
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x16_S32U8U8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x16_S32U8U8S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN>
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
{
using ElementDVal = int32_t;
using ElementAVal = uint8_t;
using ElementBVal = uint8_t;
using ElementCVal = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U8U8S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = b1 ^ b1 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
{
using ElementDVal = int32_t;
using ElementAVal = cute::uint1b_t;
using ElementBVal = cute::uint1b_t;
using ElementCVal = int32_t;
using Shape_MNK = Shape<_16,_8,_256>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <_32,Shape < _8, _4,_2, _2>>,
Stride<_64,Stride<_64,_16,_8,_2048>>>;
using BLayout = Layout<Shape <_32,Shape <_32, _2>>,
Stride<_32,Stride< _1,_1024>>>;
using CLayout = SM80_16x8_Row;
};
} // end namespace cute

View File

@ -0,0 +1,132 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>
namespace cute {
///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp64 = fp64 * fp64 + fp64 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using Shape_MNK = Shape<_16,_8,_4>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,_2>,
Stride<Stride<_16,_1>,_8>>;
using BLayout = Layout<Shape <Shape < _4,_8>,_1>,
Stride<Stride< _8,_1>,_0>>;
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
};
template <>
struct MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using Shape_MNK = Shape<_16,_8,_8>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _2>>,
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
using BLayout = Layout<Shape <Shape < _4,_8>, _2>,
Stride<Stride< _8,_1>,_32>>;
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
};
template <>
struct MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
using ElementDVal = double;
using ElementAVal = double;
using ElementBVal = double;
using ElementCVal = double;
using Shape_MNK = Shape<_16,_8,_16>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _4>>,
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
using BLayout = Layout<Shape <Shape < _4,_8>, _4>,
Stride<Stride< _8,_1>,_32>>;
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
};
///////////////////////////////////////////////////////////////////////////////////
//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM90_16x8x4_C64C64C64C64_TN>
: MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
};
template <>
struct MMA_Traits<SM90_16x8x8_C64C64C64C64_TN>
: MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
};
template <>
struct MMA_Traits<SM90_16x8x16_C64C64C64C64_TN>
: MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
using ElementDVal = complex<double>;
using ElementAVal = complex<double>;
using ElementBVal = complex<double>;
using ElementCVal = complex<double>;
};
} // end namespace cute

File diff suppressed because it is too large Load Diff

121
include/cute/config.hpp Normal file
View File

@ -0,0 +1,121 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
# define CUTE_DEVICE __forceinline__ __device__
# define CUTE_HOST __forceinline__ __host__
#else
# define CUTE_HOST_DEVICE inline
# define CUTE_DEVICE inline
# define CUTE_HOST inline
#endif // CUTE_HOST_DEVICE, CUTE_DEVICE
#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
# define CUTE_UNROLL #pragma unroll
# define CUTE_NO_UNROLL #pragma unroll 1
#else
# define CUTE_UNROLL
# define CUTE_NO_UNROLL
#endif // CUTE_UNROLL
#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
# define CUTE_INLINE_CONSTANT static const __device__
#else
# define CUTE_INLINE_CONSTANT static constexpr
#endif
// Some versions of GCC < 11 have trouble deducing that a
// function with "auto" return type and all of its returns in an "if
// constexpr ... else" statement must actually return. Thus, GCC
// emits spurious "missing return statement" build warnings.
// Developers can suppress these warnings by using the
// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon.
// It's harmless to use the macro for other GCC versions or other
// compilers, but it has no effect.
#if ! defined(CUTE_GCC_UNREACHABLE)
# if defined(__GNUC__) && __GNUC__ < 11
// GCC 10, but not 7.5, 9.4.0, or 11, issues "missing return
// statement" warnings without this little bit of help.
# define CUTE_GCC_UNREACHABLE __builtin_unreachable()
# else
# define CUTE_GCC_UNREACHABLE
# endif
#endif
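// Usage sketch (illustrative, hypothetical helper): place the macro after an
// if-constexpr chain whose branches all return, so GCC 10 does not warn about a
// missing return statement.
//   template <int N>
//   CUTE_HOST_DEVICE constexpr auto select_int() {
//     if constexpr (N == 0) { return Int<0>{}; }
//     else                  { return Int<1>{}; }
//     CUTE_GCC_UNREACHABLE;
//   }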
//
// Assertion helpers
//
#include <cassert>
#define CUTE_STATIC_ASSERT static_assert
#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__)
#if defined(__CUDA_ARCH__)
# define CUTE_RUNTIME_ASSERT(x) asm volatile ("brkpt;\n" ::: "memory")
#else
# define CUTE_RUNTIME_ASSERT(x) assert(0 && x)
#endif
//
// IO
//
#include <cstdio>
#include <iostream>
#include <iomanip>
//
// Support
//
#include <cute/util/type_traits.hpp>
//
// Basic types
//
#include <cute/numeric/int.hpp>
#include <cute/numeric/real.hpp>
#include <cute/numeric/half.hpp>
#include <cute/numeric/float8.hpp>
#include <cute/numeric/bfloat.hpp>
#include <cute/numeric/tfloat.hpp>
#include <cute/numeric/complex.hpp>
//
// Debugging utilities
//
#include <cute/util/print.hpp>
#include <cute/util/debug.hpp>

View File

@ -0,0 +1,70 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/int.hpp>
#include <cute/numeric/math.hpp>
namespace cute
{
// Test if a pointer is aligned to N bytes
template <int N>
CUTE_HOST_DEVICE constexpr
bool
is_byte_aligned(void const* const ptr)
{
static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check");
return (reinterpret_cast<uintptr_t>(ptr) & (N-1)) == 0;
}
#if defined(__CUDACC__)
# define CUTE_ALIGNAS(n) __align__(n)
#else
# define CUTE_ALIGNAS(n) alignas(n)
#endif
template <std::size_t Alignment>
struct aligned_struct {};
template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {};
template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {};
template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {};
template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {};
template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {};
template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {};
template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {};
template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {};
template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {};
} // end namespace cute
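
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): exercising the
// alignment helpers on the host. `MyBlock` is an invented example type.
// ----------------------------------------------------------------------------
#include <cassert>
#include <cute/container/alignment.hpp>

// CUTE_ALIGNAS maps to __align__ under nvcc and to alignas elsewhere.
struct CUTE_ALIGNAS(16) MyBlock { float data[4]; };

int main() {
  // The aligned_struct<N> tag type carries its alignment.
  static_assert(alignof(cute::aligned_struct<64>) == 64, "tag carries alignment");

  MyBlock b{};
  // is_byte_aligned<N> is a runtime pointer check; N must be a power of two.
  assert(cute::is_byte_aligned<16>(&b));
  return 0;
}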

View File

@ -0,0 +1,282 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstddef>
#include <utility>
#include <cute/config.hpp>
namespace cute
{
template <class T, std::size_t N>
struct array
{
using value_type = T;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using iterator = pointer;
using const_iterator = const_pointer;
CUTE_HOST_DEVICE constexpr
reference operator[](size_type pos)
{
return begin()[pos];
}
CUTE_HOST_DEVICE constexpr
const_reference operator[](size_type pos) const
{
return begin()[pos];
}
CUTE_HOST_DEVICE constexpr
reference front()
{
return *begin();
}
CUTE_HOST_DEVICE constexpr
const_reference front() const
{
return *begin();
}
CUTE_HOST_DEVICE constexpr
reference back()
{
// return *rbegin();
return operator[](N-1);
}
CUTE_HOST_DEVICE constexpr
const_reference back() const
{
// return *rbegin();
return operator[](N-1);
}
CUTE_HOST_DEVICE constexpr
T* data()
{
return __elems_;
}
CUTE_HOST_DEVICE constexpr
T const* data() const
{
return __elems_;
}
CUTE_HOST_DEVICE constexpr
iterator begin()
{
return data();
}
CUTE_HOST_DEVICE constexpr
const_iterator begin() const
{
return data();
}
CUTE_HOST_DEVICE constexpr
const_iterator cbegin()
{
return begin();
}
CUTE_HOST_DEVICE constexpr
const_iterator cbegin() const
{
return begin();
}
CUTE_HOST_DEVICE constexpr
iterator end()
{
return data() + size();
}
CUTE_HOST_DEVICE constexpr
const_iterator end() const
{
return data() + size();
}
CUTE_HOST_DEVICE constexpr
const_iterator cend()
{
return end();
}
CUTE_HOST_DEVICE constexpr
const_iterator cend() const
{
return end();
}
CUTE_HOST_DEVICE constexpr
bool empty() const
{
return size() == 0;
}
CUTE_HOST_DEVICE constexpr
size_type size() const
{
return N;
}
CUTE_HOST_DEVICE constexpr
size_type max_size() const
{
return size();
}
CUTE_HOST_DEVICE constexpr
void fill(const T& value)
{
for (auto& e : *this) {
e = value;
}
}
CUTE_HOST_DEVICE constexpr
void clear()
{
fill(T(0));
}
CUTE_HOST_DEVICE constexpr
void swap(array& other)
{
using std::swap;
for (size_type i = 0; i < size(); ++i) {
swap((*this)[i], other[i]);
}
}
value_type __elems_[N > 0 ? N : 1];
};
template<class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
bool operator==(array<T,N> const& lhs, array<T,N> const& rhs)
{
for (std::size_t i = 0; i < N; ++i) {
if (lhs[i] != rhs[i]) {
return false;
}
}
return true;
}
template <typename T, std::size_t N>
CUTE_HOST_DEVICE constexpr
void clear(array<T,N>& a)
{
a.fill(T(0));
}
template <typename T, std::size_t N>
CUTE_HOST_DEVICE constexpr
void fill(array<T,N>& a, T const& value)
{
a.fill(value);
}
template<class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
void swap(array<T,N>& a, array<T,N>& b)
{
a.swap(b);
}
} // end cute
//
// Specialize tuple-related functionality for cute::array
//
#include <tuple>
namespace cute
{
template<std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T& get(array<T,N>& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template<std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T const& get(array<T,N> const& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template<std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T&& get(array<T,N>&& a)
{
static_assert(I < N, "Index out of range");
return std::move(a[I]);
}
} // end namespace cute
namespace std
{
template <class T, std::size_t N>
struct tuple_size<cute::array<T,N>>
: std::integral_constant<std::size_t, N>
{};
template <std::size_t I, class T, std::size_t N>
struct tuple_element<I, cute::array<T,N>>
{
using type = T;
};
} // end std
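
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): cute::array behaves
// like a device-friendly std::array, including the tuple protocol hooked up
// immediately above. Values below are arbitrary.
// ----------------------------------------------------------------------------
#include <cassert>
#include <tuple>
#include <cute/container/array.hpp>

int main() {
  cute::array<int, 3> a = {1, 2, 3};   // aggregate initialization
  a.fill(7);
  assert(a.front() == 7 && a.back() == 7);

  // Tuple protocol: cute::get<I> plus std::tuple_size / std::tuple_element.
  cute::get<0>(a) = 42;
  static_assert(std::tuple_size<cute::array<int, 3>>::value == 3, "three elements");
  assert(a[0] == 42);

  cute::array<int, 3> b{};
  cute::clear(b);                      // free function: fill with T(0)
  assert(!(a == b));
  return 0;
}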

View File

@ -0,0 +1,276 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/alignment.hpp>
#include <cute/numeric/int.hpp>
#include <cute/numeric/math.hpp>
namespace cute
{
template <typename T, std::size_t N, std::size_t Alignment = 16>
struct array_aligned
: public aligned_struct<Alignment>
{
/// Make sure the Alignment makes sense wrt the size of elements.
static_assert(Alignment == 16 || Alignment >= sizeof(T), "Alignment is too small");
/// Alignment must be a power of two
static_assert(has_single_bit(Alignment), "Alignment must be a power of two");
using value_type = T;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using iterator = pointer;
using const_iterator = const_pointer;
CUTE_HOST_DEVICE constexpr
reference operator[](size_type pos)
{
return begin()[pos];
}
CUTE_HOST_DEVICE constexpr
const_reference operator[](size_type pos) const
{
return begin()[pos];
}
CUTE_HOST_DEVICE constexpr
reference front()
{
return *begin();
}
CUTE_HOST_DEVICE constexpr
const_reference front() const
{
return *begin();
}
CUTE_HOST_DEVICE constexpr
reference back()
{
// return *rbegin();
return operator[](N-1);
}
CUTE_HOST_DEVICE constexpr
const_reference back() const
{
// return *rbegin();
return operator[](N-1);
}
CUTE_HOST_DEVICE constexpr
T* data()
{
return reinterpret_cast<T*>(storage);
}
CUTE_HOST_DEVICE constexpr
T const* data() const
{
return reinterpret_cast<T const*>(storage);
}
CUTE_HOST_DEVICE constexpr
iterator begin()
{
return data();
}
CUTE_HOST_DEVICE constexpr
const_iterator begin() const
{
return data();
}
CUTE_HOST_DEVICE constexpr
const_iterator cbegin()
{
return begin();
}
CUTE_HOST_DEVICE constexpr
const_iterator cbegin() const
{
return begin();
}
CUTE_HOST_DEVICE constexpr
iterator end()
{
return data() + size();
}
CUTE_HOST_DEVICE constexpr
const_iterator end() const
{
return data() + size();
}
CUTE_HOST_DEVICE constexpr
const_iterator cend()
{
return end();
}
CUTE_HOST_DEVICE constexpr
const_iterator cend() const
{
return end();
}
CUTE_HOST_DEVICE constexpr
bool empty() const
{
return size() == 0;
}
CUTE_HOST_DEVICE constexpr
size_type size() const
{
return N;
}
CUTE_HOST_DEVICE constexpr
size_type max_size() const
{
return size();
}
CUTE_HOST_DEVICE constexpr
void fill(T const& value)
{
for (auto& e : *this) {
e = value;
}
}
CUTE_HOST_DEVICE constexpr
void clear()
{
fill(T(0));
}
// Not private: we want the type to remain trivial
//private:
/// Storage type to use for Elements
using StorageType = typename uint_byte<static_cast<int>(Alignment)>::type;
/// Ensure that there's enough storage for all elements
static_assert(sizeof(StorageType) <= Alignment, "StorageType is too big for given alignment");
/// Number of elements in the storage
static constexpr std::size_t storageN = (sizeof(T)*N + sizeof(StorageType) - 1) / sizeof(StorageType);
/// The storage.
StorageType storage[storageN > 0 ? storageN : 1];
};
//
// Operators
//
template <typename T, std::size_t N, std::size_t Alignment>
CUTE_HOST_DEVICE constexpr
void clear(array_aligned<T, N, Alignment>& a)
{
a.clear();
}
template <typename T, std::size_t N, std::size_t Alignment>
CUTE_HOST_DEVICE constexpr
void fill(array_aligned<T, N, Alignment>& a, T const& value)
{
a.fill(value);
}
} // end namespace cute
//
// Specialize tuple-related functionality for cute::array
//
#include <tuple>
namespace cute
{
template <std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T& get(array_aligned<T,N>& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template <std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T const& get(array_aligned<T,N> const& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template <std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T&& get(array_aligned<T,N>&& a)
{
static_assert(I < N, "Index out of range");
return std::move(a[I]);
}
} // end namespace cute
namespace std
{
template <class T, std::size_t N>
struct tuple_size<cute::array_aligned<T,N>>
: std::integral_constant<std::size_t, N>
{};
template <std::size_t I, class T, std::size_t N>
struct tuple_element<I, cute::array_aligned<T,N>>
{
using type = T;
};
} // end std
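
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): array_aligned gives
// array semantics on top of aligned raw storage (16B by default). The explicit
// 4-byte alignment case assumes uint_byte<4> maps to a 4-byte storage word;
// the install path in the include is also an assumption.
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <cute/container/array_aligned.hpp>   // assumed install path of this header

int main() {
  // 8 floats backed by 16-byte aligned storage (the default Alignment).
  cute::array_aligned<float, 8> frag;
  frag.fill(1.5f);
  assert(frag.size() == 8 && frag[3] == 1.5f);
  static_assert(alignof(cute::array_aligned<float, 8>) >= 16, "16B aligned object");

  // An explicitly chosen (smaller) alignment.
  cute::array_aligned<std::uint8_t, 8, 4> row;
  cute::clear(row);                     // free function: zero-fill
  assert(row[7] == 0);
  return 0;
}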

View File

@ -0,0 +1,613 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Statically sized array of elements that accommodates subbyte trivial types
in packed storage.
*/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/int.hpp> // sizeof_bits
namespace cute
{
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically sized array for any data type
template <class T, std::size_t N>
class array_subbyte
{
public:
/// Number of total bits in the array
static constexpr int kSizeBits = sizeof_bits<T>::value * N;
/// Storage type
using Storage = typename std::conditional<
(kSizeBits % 32) == 0,
uint32_t,
typename std::conditional<
(kSizeBits % 16) == 0,
uint16_t,
uint8_t
>::type
>::type;
/// Number of logical elements per stored object
static constexpr int kElementsPerStoredItem = sizeof_bits<Storage>::value / sizeof_bits<T>::value;
/// Number of storage elements
static constexpr std::size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;
/// Bitmask for covering one item
static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits<T>::value) - 1);
//
// C++ standard members with reference and iterator types omitted
//
using value_type = T;
using pointer = value_type*;
using const_pointer = value_type const*;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
//
// References
//
/// Reference object inserts or extracts sub-byte items
class reference {
/// Pointer to storage element
Storage* ptr_;
/// Index into elements packed into Storage object
int idx_;
public:
/// Default ctor
CUTE_HOST_DEVICE constexpr
reference() : ptr_(nullptr), idx_(0) {}
/// Ctor
CUTE_HOST_DEVICE constexpr
reference(Storage* ptr, int idx = 0) : ptr_(ptr), idx_(idx) {}
/// Assignment
CUTE_HOST_DEVICE constexpr
reference& operator=(T x) {
Storage item = (reinterpret_cast<Storage const&>(x) & bit_mask_);
Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits<T>::value)));
*ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits<T>::value)));
return *this;
}
CUTE_HOST_DEVICE constexpr
T get() const {
Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & bit_mask_);
return reinterpret_cast<T const&>(item);
}
/// Extract to type T -- disable if T == bool
template <class U = T, __CUTE_REQUIRES(not std::is_same<U,bool>::value)>
CUTE_HOST_DEVICE constexpr
operator T() const {
return get();
}
// Extract to bool -- potentially faster impl
CUTE_HOST_DEVICE constexpr
operator bool() const {
return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits<T>::value)));
}
/// Explicit cast to int
CUTE_HOST_DEVICE constexpr
explicit operator int() const {
return int(get());
}
/// Explicit cast to float
CUTE_HOST_DEVICE constexpr
explicit operator float() const {
return float(get());
}
};
/// Reference object extracts sub-byte items
class const_reference {
/// Pointer to storage element
Storage const* ptr_;
/// Index into elements packed into Storage object
int idx_;
public:
/// Default ctor
CUTE_HOST_DEVICE constexpr
const_reference(): ptr_(nullptr), idx_(0) { }
/// Ctor
CUTE_HOST_DEVICE constexpr
const_reference(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
CUTE_HOST_DEVICE constexpr
const T get() const {
Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & bit_mask_);
return reinterpret_cast<T const&>(item);
}
/// Extract to type T -- disable if T == bool
template <class U = T, __CUTE_REQUIRES(not std::is_same<U,bool>::value)>
CUTE_HOST_DEVICE constexpr
operator T() const {
return get();
}
// Extract to bool -- potentially faster impl
CUTE_HOST_DEVICE constexpr
operator bool() const {
return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits<T>::value)));
}
/// Explicit cast to int
CUTE_HOST_DEVICE constexpr
explicit operator int() const {
return int(get());
}
/// Explicit cast to float
CUTE_HOST_DEVICE constexpr
explicit operator float() const {
return float(get());
}
};
//
// Iterators
//
/// Bidirectional iterator over elements
class iterator {
/// Pointer to storage element
Storage* ptr_;
/// Index into elements packed into Storage object
int idx_;
public:
CUTE_HOST_DEVICE constexpr
iterator(): ptr_(nullptr), idx_(0) { }
CUTE_HOST_DEVICE constexpr
iterator(Storage* ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
CUTE_HOST_DEVICE constexpr
iterator& operator++() {
++idx_;
if (idx_ == kElementsPerStoredItem) {
++ptr_;
idx_ = 0;
}
return *this;
}
CUTE_HOST_DEVICE constexpr
iterator& operator--() {
if (idx_) {
--idx_;
} else {
--ptr_;
idx_ = kElementsPerStoredItem - 1;
}
return *this;
}
CUTE_HOST_DEVICE constexpr
iterator operator++(int) {
iterator ret(*this);
++(*this);
return ret;
}
CUTE_HOST_DEVICE constexpr
iterator operator--(int) {
iterator ret(*this);
--(*this);
return ret;
}
CUTE_HOST_DEVICE constexpr
iterator& operator+=(int k) {
idx_ += k;
ptr_ += idx_ / kElementsPerStoredItem;
idx_ = idx_ % kElementsPerStoredItem;
return *this;
}
CUTE_HOST_DEVICE constexpr
iterator operator+(int k) const {
return iterator(ptr_,idx_) += k;
}
CUTE_HOST_DEVICE constexpr
reference operator*() const {
return reference(ptr_, idx_);
}
CUTE_HOST_DEVICE constexpr
reference operator[](int k) const {
return *(*this + k);
}
CUTE_HOST_DEVICE constexpr
bool operator==(iterator const& other) const {
return ptr_ == other.ptr_ && idx_ == other.idx_;
}
CUTE_HOST_DEVICE constexpr
bool operator!=(iterator const& other) const {
return !(*this == other);
}
};
/// Bidirectional constant iterator over elements
class const_iterator {
/// Pointer to storage element
Storage const* ptr_;
/// Index into elements packed into Storage object
int idx_;
public:
CUTE_HOST_DEVICE constexpr
const_iterator(): ptr_(nullptr), idx_(0) { }
CUTE_HOST_DEVICE constexpr
const_iterator(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
CUTE_HOST_DEVICE constexpr
const_iterator& operator++() {
++idx_;
if (idx_ == kElementsPerStoredItem) {
++ptr_;
idx_ = 0;
}
return *this;
}
CUTE_HOST_DEVICE constexpr
const_iterator& operator--() {
if (idx_) {
--idx_;
} else {
--ptr_;
idx_ = kElementsPerStoredItem - 1;
}
return *this;
}
CUTE_HOST_DEVICE constexpr
const_iterator operator++(int) {
const_iterator ret(*this);
++idx_;
if (idx_ == kElementsPerStoredItem) {
++ptr_;
idx_ = 0;
}
return ret;
}
CUTE_HOST_DEVICE constexpr
const_iterator operator--(int) {
const_iterator ret(*this);
if (idx_) {
--idx_;
} else {
--ptr_;
idx_ = kElementsPerStoredItem - 1;
}
return ret;
}
CUTE_HOST_DEVICE constexpr
const_iterator& operator+=(int k) {
idx_ += k;
ptr_ += idx_ / kElementsPerStoredItem;
idx_ = idx_ % kElementsPerStoredItem;
return *this;
}
CUTE_HOST_DEVICE constexpr
const_iterator operator+(int k) const {
return const_iterator(ptr_,idx_) += k;
}
CUTE_HOST_DEVICE constexpr
const_reference operator*() const {
return const_reference(ptr_, idx_);
}
CUTE_HOST_DEVICE constexpr
const_reference operator[](int k) const {
return *(*this + k);
}
CUTE_HOST_DEVICE constexpr
bool operator==(const_iterator const& other) const {
return ptr_ == other.ptr_ && idx_ == other.idx_;
}
CUTE_HOST_DEVICE constexpr
bool operator!=(const_iterator const& other) const {
return !(*this == other);
}
};
private:
/// Internal storage
Storage storage[kStorageElements];
public:
CUTE_HOST_DEVICE constexpr
array_subbyte() { }
CUTE_HOST_DEVICE constexpr
array_subbyte(array_subbyte const& x) {
CUTE_UNROLL
for (unsigned i = 0; i < kStorageElements; ++i) {
storage[i] = x.storage[i];
}
}
CUTE_HOST_DEVICE constexpr
size_type size() const {
return N;
}
CUTE_HOST_DEVICE constexpr
size_type max_size() const {
return N;
}
CUTE_HOST_DEVICE constexpr
bool empty() const {
return !N;
}
/// Efficient clear method
CUTE_HOST_DEVICE constexpr
void clear() {
CUTE_UNROLL
for (unsigned i = 0; i < kStorageElements; ++i) {
storage[i] = Storage(0);
}
}
// Efficient fill method
CUTE_HOST_DEVICE constexpr
void fill(T const& value) {
Storage item = (reinterpret_cast<Storage const&>(value) & bit_mask_);
// Reproduce the value over the bits of the storage item
CUTE_UNROLL
for (unsigned s = sizeof_bits<T>::value; s < sizeof_bits<Storage>::value; s *= 2) {
item |= item << s;
}
CUTE_UNROLL
for (unsigned i = 0; i < kStorageElements; ++i) {
storage[i] = item;
}
}
CUTE_HOST_DEVICE constexpr
reference at(size_type pos) {
return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
}
CUTE_HOST_DEVICE constexpr
const_reference at(size_type pos) const {
return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
}
CUTE_HOST_DEVICE constexpr
reference operator[](size_type pos) {
return at(pos);
}
CUTE_HOST_DEVICE constexpr
const_reference operator[](size_type pos) const {
return at(pos);
}
CUTE_HOST_DEVICE constexpr
reference front() {
return at(0);
}
CUTE_HOST_DEVICE constexpr
const_reference front() const {
return at(0);
}
CUTE_HOST_DEVICE constexpr
reference back() {
return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
}
CUTE_HOST_DEVICE constexpr
const_reference back() const {
return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
}
CUTE_HOST_DEVICE constexpr
pointer data() {
return reinterpret_cast<pointer>(storage);
}
CUTE_HOST_DEVICE constexpr
const_pointer data() const {
return reinterpret_cast<const_pointer>(storage);
}
CUTE_HOST_DEVICE constexpr
Storage* raw_data() {
return storage;
}
CUTE_HOST_DEVICE constexpr
Storage const* raw_data() const {
return storage;
}
CUTE_HOST_DEVICE constexpr
iterator begin() {
return iterator(storage);
}
CUTE_HOST_DEVICE constexpr
const_iterator begin() const {
return const_iterator(storage);
}
CUTE_HOST_DEVICE constexpr
const_iterator cbegin() const {
return begin();
}
CUTE_HOST_DEVICE constexpr
iterator end() {
return iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem);
}
CUTE_HOST_DEVICE constexpr
const_iterator end() const {
return const_iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem);
}
CUTE_HOST_DEVICE constexpr
const_iterator cend() const {
return end();
}
//
// Comparison operators
//
};
//
// Operators
//
template <class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
void clear(array_subbyte<T,N>& a)
{
a.clear();
}
template <class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
void fill(array_subbyte<T,N>& a, T const& value)
{
a.fill(value);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute
//
// Specialize tuple-related functionality for cute::array_subbyte
//
#include <tuple>
namespace cute
{
template <std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T& get(array_subbyte<T,N>& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template <std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T const& get(array_subbyte<T,N> const& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template <std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T&& get(array_subbyte<T,N>&& a)
{
static_assert(I < N, "Index out of range");
return std::move(a[I]);
}
} // end namespace cute
namespace std
{
template <class T, std::size_t N>
struct tuple_size<cute::array_subbyte<T,N>>
: std::integral_constant<std::size_t, N>
{};
template <std::size_t I, class T, std::size_t N>
struct tuple_element<I, cute::array_subbyte<T,N>>
{
using type = T;
};
} // end namespace std
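
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): array_subbyte packs
// logical elements into 8/16/32-bit storage words and hands out proxy
// references. uint8_t is used here for portability; element types that
// specialize cute::sizeof_bits to fewer than 8 bits pack more densely with the
// same interface. The install path in the include is an assumption.
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <cute/container/array_subbyte.hpp>   // assumed install path of this header

int main() {
  // Eight uint8_t elements -> kSizeBits = 64, stored as two uint32_t words
  // (4 logical elements per storage word).
  cute::array_subbyte<std::uint8_t, 8> a;
  a.clear();
  a.fill(std::uint8_t(0x3));
  assert(int(a[5]) == 0x3);             // proxy read

  a[5] = std::uint8_t(0x7);             // proxy write touches only those bits
  assert(int(a[5]) == 0x7 && int(a[4]) == 0x3);

  static_assert(sizeof(a) == 2 * sizeof(std::uint32_t), "packed storage");
  return 0;
}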

View File

@ -0,0 +1,274 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstddef>
#include <utility>
#include <cute/config.hpp>
namespace cute
{
template <class T, std::size_t N>
struct array_view
{
using value_type = T;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using iterator = pointer;
using const_iterator = const_pointer;
CUTE_HOST_DEVICE
array_view(array<T,N>& a)
: __elems_(a.data()) {}
CUTE_HOST_DEVICE
reference operator[](size_type pos)
{
return begin()[pos];
}
CUTE_HOST_DEVICE
const_reference operator[](size_type pos) const
{
return begin()[pos];
}
CUTE_HOST_DEVICE
reference front()
{
return *begin();
}
CUTE_HOST_DEVICE
const_reference front() const
{
return *begin();
}
CUTE_HOST_DEVICE
reference back()
{
// return *rbegin();
return operator[](N-1);
}
CUTE_HOST_DEVICE
const_reference back() const
{
// return *rbegin();
return operator[](N-1);
}
CUTE_HOST_DEVICE
T* data()
{
return __elems_;
}
CUTE_HOST_DEVICE
const T* data() const
{
return __elems_;
}
CUTE_HOST_DEVICE
iterator begin()
{
return data();
}
CUTE_HOST_DEVICE
const_iterator begin() const
{
return data();
}
CUTE_HOST_DEVICE
const_iterator cbegin()
{
return begin();
}
CUTE_HOST_DEVICE
const_iterator cbegin() const
{
return begin();
}
CUTE_HOST_DEVICE
iterator end()
{
return data() + size();
}
CUTE_HOST_DEVICE
const_iterator end() const
{
return data() + size();
}
CUTE_HOST_DEVICE
const_iterator cend()
{
return end();
}
CUTE_HOST_DEVICE
const_iterator cend() const
{
return end();
}
CUTE_HOST_DEVICE constexpr
bool empty() const
{
return size() == 0;
}
CUTE_HOST_DEVICE constexpr
size_type size() const
{
return N;
}
CUTE_HOST_DEVICE constexpr
size_type max_size() const
{
return size();
}
CUTE_HOST_DEVICE
void fill(const T& value)
{
for(auto& e : *this)
{
e = value;
}
}
CUTE_HOST_DEVICE
void swap(array_view& other)
{
using std::swap;
swap(__elems_, other.__elems_);
}
value_type* __elems_;
};
template<class T, std::size_t N>
CUTE_HOST_DEVICE
bool operator==(const array_view<T,N>& lhs, const array_view<T,N>& rhs)
{
for(std::size_t i = 0; i < N; ++i)
{
if(lhs[i] != rhs[i]) return false;
}
return true;
}
template <typename T, std::size_t N>
CUTE_HOST_DEVICE
void clear(array_view<T, N>& a)
{
a.fill(T(0));
}
template<class T, std::size_t N>
CUTE_HOST_DEVICE
void swap(array_view<T,N>& a, array_view<T,N>& b)
{
a.swap(b);
}
} // end cute
//
// Specialize tuple-related functionality for cute::array_view
//
#include <tuple>
namespace cute
{
template<std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T&
get(array_view<T,N>& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template<std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
const T&
get(const array_view<T,N>& a)
{
static_assert(I < N, "Index out of range");
return a[I];
}
template<std::size_t I, class T, std::size_t N>
CUTE_HOST_DEVICE constexpr
T&&
get(array_view<T,N>&& a)
{
static_assert(I < N, "Index out of range");
return std::move(a[I]);
}
} // end namespace cute
namespace std
{
template<class T, std::size_t N>
struct tuple_size<cute::array_view<T,N>>
: std::integral_constant<std::size_t, N>
{};
template<std::size_t I, class T, std::size_t N>
struct tuple_element<I, cute::array_view<T,N>>
{
using type = T;
};
} // end std
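
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): array_view is a
// non-owning view over a cute::array; writes through the view alias the
// original elements. Values are arbitrary; the install path in the second
// include is an assumption.
// ----------------------------------------------------------------------------
#include <cassert>
#include <cute/container/array.hpp>
#include <cute/container/array_view.hpp>   // assumed install path of this header

int main() {
  cute::array<int, 4> a = {1, 2, 3, 4};

  cute::array_view<int, 4> v(a);   // view borrows a.data()
  v[2] = 30;
  assert(a[2] == 30);

  v.fill(9);
  assert(a[0] == 9 && a[3] == 9);
  return 0;
}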

View File

@ -0,0 +1,131 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Portable bit field that supports byte and word straddling and can
be used in unions to define parameters bit-wise.
*/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/int.hpp> // uint_bit_t
namespace cute
{
class dummy_type {};
template <uint32_t BitStart, uint32_t NumBits, class OtherValueType = dummy_type>
struct bit_field
{
static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported.");
// value_type: Use the smallest value type that fits NumBits
static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 :
(NumBits <= 16) ? 16 :
(NumBits <= 32) ? 32 : 64;
using value_type = cute::uint_bit_t<value_type_bits>;
// storage_type: Use the smallest storage_type that avoids boundary crossing
static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 :
(BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 :
(BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64;
using storage_type = cute::uint_bit_t<storage_type_bits>;
static_assert(sizeof(OtherValueType) == sizeof(value_type) || std::is_same<OtherValueType,dummy_type>::value,
"sizeof(OtherValueType) must be same as sizeof(value_type).");
// Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits)
static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits;
// Index of storage value for BitStart
static constexpr uint32_t idx = BitStart / storage_type_bits;
// Bit of data_[idx] for BitStart
static constexpr uint32_t bit_lo = BitStart % storage_type_bits;
// Number of bits in data_[idx] used for NumBits if straddling, else 0
static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0;
// NumBits mask
static constexpr value_type mask = (NumBits < 64) ? ((uint64_t(1) << NumBits) - 1) : uint64_t(-1);
// NumBits mask for BitStart
static constexpr storage_type mask_lo = storage_type(mask) << bit_lo;
// NumBits mask for leftover bits in data_[idx+1] if straddling, else 0
static constexpr storage_type mask_hi = (idx + 1 < N) ? (storage_type(mask) >> bit_hi) : 0;
storage_type data_[N];
// Get value
CUTE_HOST_DEVICE constexpr
value_type get() const {
storage_type result = (data_[idx] & mask_lo) >> bit_lo;
if constexpr (bit_hi) {
result |= (data_[idx+1] & mask_hi) << bit_hi;
}
return static_cast<value_type>(result);
}
// Set value
CUTE_HOST_DEVICE constexpr
void set(value_type x) {
storage_type item = static_cast<storage_type>(x & mask);
data_[idx] = static_cast<storage_type>((data_[idx] & ~mask_lo) | (item << bit_lo));
if constexpr (bit_hi) {
data_[idx+1] = static_cast<storage_type>((data_[idx+1] & ~mask_hi) | (item >> bit_hi));
}
}
// Assign value
CUTE_HOST_DEVICE constexpr
bit_field& operator=(value_type x) {
set(x);
return *this;
}
// Cast to value
CUTE_HOST_DEVICE constexpr
operator value_type () const {
return get();
}
// Assign OtherValueType
CUTE_HOST_DEVICE constexpr
bit_field& operator=(OtherValueType x) {
return *this = *reinterpret_cast<value_type*>(&x);
}
// Cast to OtherValueType
CUTE_HOST_DEVICE constexpr
operator OtherValueType () const {
value_type x = get();
return *reinterpret_cast<OtherValueType*>(&x);
}
};
} // end namespace cute
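
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): packing two fields
// into one 16-bit word through a union, as the \brief above describes. This
// relies on the usual little-endian union punning of CUDA targets; `Descriptor`
// and its field layout are invented, and the install path in the include is an
// assumption.
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <cute/container/bit_field.hpp>   // assumed install path of this header

union Descriptor {
  std::uint16_t raw;
  cute::bit_field<0, 10> offset;   // bits [0,10): 16-bit value/storage types
  cute::bit_field<10, 6> mode;     // bits [10,16): fits in the second byte
};

int main() {
  Descriptor d;
  d.raw    = 0;
  d.offset = std::uint16_t(0x155);
  d.mode   = std::uint8_t(0x2A);
  assert(std::uint16_t(d.offset) == 0x155);   // fields read back independently
  assert(std::uint8_t(d.mode) == 0x2A);
  assert(d.raw == std::uint16_t((0x2A << 10) | 0x155));   // little-endian view
  return 0;
}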

View File

@ -0,0 +1,671 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <tuple>
#include <utility>
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/integral_constant.hpp> // cute::true_type, cute::false_type
//#include <cute/container/array.hpp> // Advanced optimizations
#if 0
//
// Use of agency::tuple is functional, but is over-engineered for our purposes...
// This tends to result in slow compilation times and unintentionally propagated cvref types
//
#include <agency/tuple.hpp>
namespace cute
{
using agency::tuple;
using agency::make_tuple;
using agency::tuple_cat;
} // end namespace cute
#endif
// cute::tuple is like std::tuple, with two differences.
//
// 1. It works on both host and device.
// 2. Its template arguments must be semiregular types.
//
// Semiregular types are default constructible and copyable.
// They include "value types" like int or float,
// but do _not_ include references like int& or float&.
// (See std::tie for an example of a tuple of references.)
//
// This implementation is simpler than the std:: and agency:: ones: it ignores much of
// the conversion SFINAE and special overloading, and it avoids cvref template types.
// Furthermore, the empty base optimization (EBO) is more aggressive: it avoids
// construction calls and ignores any need for unique element addresses.
//
// Compared with the agency::tuple implementation, this appears to speed up compilation by over 3x.
namespace cute
{
namespace detail
{
// EBO stands for "empty base optimization."
// We use this technique to ensure that cute::tuple
// doesn't need to waste space storing any template arguments
// of cute::tuple that have no data (like integral_constant).
// Otherwise, cute::tuple would need to spend at least 1 byte
// for each of its template arguments.
//
// EBO always "holds" a single value of type T.
// N is like an array index that TupleBase uses
// to access the desired tuple element.
template <std::size_t N, class T, bool IsEmpty = std::is_empty<T>::value>
struct EBO;
// Specialization for types T that have no data;
// the "static tuple leaf." Valid T here include
// integral_constant<U, Value>, Int<Value>,
// and any other semiregular type
// for which std::is_empty_v<T> is true.
template <std::size_t N, class T>
struct EBO<N, T, true>
{
CUTE_HOST_DEVICE constexpr
EBO() {}
CUTE_HOST_DEVICE constexpr
EBO(T const&) {}
};
template <std::size_t N, class T>
CUTE_HOST_DEVICE constexpr T getv(EBO<N, T, true> const&)
{ return {}; }
// Specialization for types T that are not empty;
// the "dynamic tuple leaf." Valid T here include int,
// any other integral or floating-point type,
// or any semiregular type for which std::is_empty_v<T> is false.
template <std::size_t N, class T>
struct EBO<N, T, false>
{
CUTE_HOST_DEVICE constexpr
EBO() : t_{} {}
template <class U>
CUTE_HOST_DEVICE constexpr
EBO(U const& u) : t_{u} {}
T t_;
};
template <std::size_t N, class T>
CUTE_HOST_DEVICE constexpr T const& getv(EBO<N, T, false> const& x)
{ return x.t_; }
template <std::size_t N, class T>
CUTE_HOST_DEVICE constexpr T& getv(EBO<N, T, false>& x)
{ return x.t_; }
template <std::size_t N, class T>
CUTE_HOST_DEVICE constexpr T&& getv(EBO<N, T, false>&& x)
{ return static_cast<T&&>(x.t_); }
template <class IdxSeq, class... T>
struct TupleBase;
// Base class of cute::tuple.
// It inherits from EBO<i, t> for each (i, t) in (I..., T...).
// The actual storage (for nonempty t) lives in the base classes.
// index_sequence is a way to wrap up a sequence of zero or more
// compile-time integer values in a single type.
// We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice,
// as the type alias TupleBase below indicates.
template <std::size_t... I, class... T>
struct TupleBase<std::index_sequence<I...>, T...>
: EBO<I,T>...
{
CUTE_HOST_DEVICE constexpr
TupleBase() {}
template <class... U>
CUTE_HOST_DEVICE constexpr explicit
TupleBase(U const&... u)
: EBO<I,T>(u)... {}
template <class... U>
CUTE_HOST_DEVICE constexpr
TupleBase(TupleBase<std::index_sequence<I...>, U...> const& u)
: EBO<I,T>(getv(static_cast<EBO<I,U> const&>(u)))... {}
};
} // end namespace detail
// make_index_sequence<K> returns index_sequence<0, 1, ..., K-1>.
template <class... T>
using TupleBase = detail::TupleBase<std::make_index_sequence<sizeof...(T)>, T...>;
// This is the actual cute::tuple class.
// The storage (if any) lives in TupleBase's EBO base classes.
template <class... T>
struct tuple : TupleBase<T...>
{
CUTE_HOST_DEVICE constexpr
tuple() {}
template <class... U>
CUTE_HOST_DEVICE constexpr
tuple(U const&... u) : TupleBase<T...>(u...) {}
template <class... U>
CUTE_HOST_DEVICE constexpr
tuple(tuple<U...> const& u)
: TupleBase<T...>(static_cast<TupleBase<U...> const&>(u)) {}
};
//
// get for cute::tuple (just like std::get for std::tuple)
//
template <std::size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(tuple<T...> const& t) noexcept
{
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(t);
}
template <std::size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(tuple<T...>& t) noexcept
{
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(t);
}
template <std::size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(tuple<T...>&& t) noexcept
{
static_assert(I < sizeof...(T), "Index out of range");
return detail::getv<I>(static_cast<tuple<T...>&&>(t));
}
//
// Custom is_tuple trait simply checks the existence of std::tuple_size
// and assumes std::get<I>(.), std::tuple_element<I,.>
//
namespace detail {
template <class T>
std::integral_constant<bool, std::tuple_size<T>::value >= 0> has_tuple_size(int);
template <class T>
std::false_type has_tuple_size(...);
} // end namespace detail
template <class T>
struct is_tuple : decltype(detail::has_tuple_size<T>(0)) {};
//
// make_tuple (value-based implementation)
//
template <class... T>
CUTE_HOST_DEVICE constexpr
tuple<T...>
make_tuple(T const&... t)
{
return {t...};
}
//
// tuple_cat concatenates multiple cute::tuple into a single cute::tuple,
// just like std::tuple_cat for std::tuple.
//
#if 0
// Original implementation
namespace detail {
template <class T0, class T1,
std::size_t... I0, std::size_t... I1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1,
std::index_sequence<I0...>, std::index_sequence<I1...>)
{
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)...);
}
} // end namespace detail
CUTE_HOST_DEVICE constexpr
tuple<>
tuple_cat()
{
return {};
}
template <class Tuple,
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST_DEVICE constexpr
Tuple const&
tuple_cat(Tuple const& t)
{
return t;
}
template <class T0, class T1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1)
{
return detail::tuple_cat(t0, t1,
std::make_index_sequence<std::tuple_size<T0>::value>{},
std::make_index_sequence<std::tuple_size<T1>::value>{});
}
template <class T0, class T1, class T2, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts)
{
return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...);
}
#endif
#if 1
// Extended implementation
namespace detail {
template <class T0, class T1,
std::size_t... I0, std::size_t... I1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1,
std::index_sequence<I0...>, std::index_sequence<I1...>)
{
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)...);
}
template <class T0, class T1, class T2,
std::size_t... I0, std::size_t... I1, std::size_t... I2>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2,
std::index_sequence<I0...>, std::index_sequence<I1...>, std::index_sequence<I2...>)
{
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)...);
}
template <class T0, class T1, class T2, class T3,
std::size_t... I0, std::size_t... I1, std::size_t... I2, std::size_t... I3>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3,
std::index_sequence<I0...>, std::index_sequence<I1...>, std::index_sequence<I2...>, std::index_sequence<I3...>)
{
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)..., get<I3>(t3)...);
}
template <class T0, class T1, class T2, class T3, class T4,
std::size_t... I0, std::size_t... I1, std::size_t... I2, std::size_t... I3, std::size_t... I4>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4,
std::index_sequence<I0...>, std::index_sequence<I1...>, std::index_sequence<I2...>, std::index_sequence<I3...>, std::index_sequence<I4...>)
{
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)..., get<I3>(t3)..., get<I4>(t4)...);
}
} // end namespace detail
CUTE_HOST_DEVICE constexpr
tuple<>
tuple_cat()
{
return {};
}
template <class Tuple,
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST_DEVICE constexpr
Tuple const&
tuple_cat(Tuple const& t)
{
return t;
}
template <class T0, class T1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1)
{
return detail::tuple_cat(t0, t1,
std::make_index_sequence<std::tuple_size<T0>::value>{},
std::make_index_sequence<std::tuple_size<T1>::value>{});
}
template <class T0, class T1, class T2>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2)
{
return detail::tuple_cat(t0, t1, t2,
std::make_index_sequence<std::tuple_size<T0>::value>{},
std::make_index_sequence<std::tuple_size<T1>::value>{},
std::make_index_sequence<std::tuple_size<T2>::value>{});
}
template <class T0, class T1, class T2, class T3>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3)
{
return detail::tuple_cat(t0, t1, t2, t3,
std::make_index_sequence<std::tuple_size<T0>::value>{},
std::make_index_sequence<std::tuple_size<T1>::value>{},
std::make_index_sequence<std::tuple_size<T2>::value>{},
std::make_index_sequence<std::tuple_size<T3>::value>{});
}
template <class T0, class T1, class T2, class T3, class T4>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4)
{
return detail::tuple_cat(t0, t1, t2, t3, t4,
std::make_index_sequence<std::tuple_size<T0>::value>{},
std::make_index_sequence<std::tuple_size<T1>::value>{},
std::make_index_sequence<std::tuple_size<T2>::value>{},
std::make_index_sequence<std::tuple_size<T3>::value>{},
std::make_index_sequence<std::tuple_size<T4>::value>{});
}
template <class T0, class T1, class T2, class T3, class T4, class T5, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts)
{
return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), t5, ts...);
}
#endif
#if 0
// Outer-Inner indexing trick to concat all tuples at once
namespace detail {
template <std::size_t... Ns>
struct tuple_cat_helper
{
static constexpr cute::array<std::size_t,sizeof...(Ns)> ns = {Ns...};
static constexpr std::size_t total_size() {
std::size_t sum = 0;
for (std::size_t n : ns) sum += n;
return sum;
}
static constexpr std::size_t total_size_ = total_size();
static constexpr auto values() {
cute::array<std::size_t[2],total_size_> outer_inner = {};
std::size_t idx = 0;
for (std::size_t i = 0; i < ns.size(); ++i) {
for (std::size_t j = 0; j < ns[i]; ++j, ++idx) {
outer_inner[idx][0] = i;
outer_inner[idx][1] = j;
}
}
return outer_inner;
}
static constexpr auto outer_inner_ = values();
using total_sequence = std::make_index_sequence<total_size_>;
};
template <class Helper, class Tuple, std::size_t... I>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(Tuple const& t, std::index_sequence<I...>)
{
return cute::make_tuple(get<Helper::outer_inner_[I][1]>(get<Helper::outer_inner_[I][0]>(t))...);
}
template <class T0, class T1,
std::size_t... I0, std::size_t... I1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1,
std::index_sequence<I0...>, std::index_sequence<I1...>)
{
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)...);
}
} // end namespace detail
CUTE_HOST_DEVICE constexpr
tuple<>
tuple_cat()
{
return {};
}
template <class Tuple,
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST_DEVICE constexpr
Tuple const&
tuple_cat(Tuple const& t)
{
return t;
}
template <class T0, class T1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1)
{
return detail::tuple_cat(t0, t1,
std::make_index_sequence<std::tuple_size<T0>::value>{},
std::make_index_sequence<std::tuple_size<T1>::value>{});
}
template <class... Tuples>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(Tuples const&... ts)
{
using Helper = detail::tuple_cat_helper<std::tuple_size<Tuples>::value...>;
return detail::tuple_cat<Helper>(make_tuple(ts...), typename Helper::total_sequence{});
}
#endif
//
// Equality operators
//
namespace detail {
template <std::size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
equal_impl(TupleA const& a, TupleB const& b)
{
if constexpr (I == std::tuple_size<TupleA>::value) {
return cute::true_type{}; // Terminal: TupleA is exhausted
} else if constexpr (I == std::tuple_size<TupleB>::value) {
return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted
} else {
return (get<I>(a) == get<I>(b)) && equal_impl<I+1>(a,b);
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class TupleT, class TupleU,
__CUTE_REQUIRES(is_tuple<TupleT>::value && is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator==(TupleT const& t, TupleU const& u)
{
return detail::equal_impl<0>(t, u);
}
template <class TupleT, class TupleU,
__CUTE_REQUIRES(is_tuple<TupleT>::value ^ is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator==(TupleT const& t, TupleU const& u)
{
return cute::false_type{};
}
template <class TupleT, class TupleU,
__CUTE_REQUIRES(is_tuple<TupleT>::value && is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator!=(TupleT const& t, TupleU const& u)
{
return !(t == u);
}
template <class TupleT, class TupleU,
__CUTE_REQUIRES(is_tuple<TupleT>::value ^ is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator!=(TupleT const& t, TupleU const& u)
{
return cute::true_type{};
}
//
// Comparison operators
//
//
// There are many ways to compare tuples of elements, and because CuTe is built
// on parameterizing layouts of coordinates, some comparisons are appropriate
// only in certain cases.
// -- lexicographical comparison [reverse, reflected, revref]
// -- colexicographical comparison [reverse, reflected, revref]
// -- element-wise comparison [any,all]
// This can be very confusing. To avoid errors in selecting the appropriate
// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple.
//
// That said, see int_tuple for more explicitly named common comparison ops.
//
//
// Shortcuts
//
//using std::get;
using std::tuple_size;
using std::tuple_element;
using std::tuple_element_t;
//
// Display utilities
//
namespace detail {
template <class Tuple, std::size_t... Is>
CUTE_HOST_DEVICE void print_tuple(Tuple const& t,
std::index_sequence<Is...>, char s = '(', char e = ')')
{
using eat = int[];
using cute::print;
(void) eat {(print(s), 0),
(print(Is == 0 ? "" : ","), print(get<Is>(t)), 0)...,
(print(e), 0)};
}
template <class Tuple, std::size_t... Is>
CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t,
std::index_sequence<Is...>, char s = '(', char e = ')')
{
using eat = int[];
(void) eat {(void(os << s), 0),
(void(os << (Is == 0 ? "" : ",") << get<Is>(t)), 0)...,
(void(os << e), 0)};
return os;
}
} // end namespace detail
template <class Tuple,
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST_DEVICE void print(Tuple const& t)
{
return detail::print_tuple(t, std::make_index_sequence<std::tuple_size<Tuple>::value>{});
}
template <class Tuple,
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t)
{
return detail::print_tuple_os(os, t, std::make_index_sequence<std::tuple_size<Tuple>::value>{});
}
} // end namespace cute
//
// std:: compatibility
//
namespace std
{
template <class... T>
struct tuple_size<cute::tuple<T...>>
: std::integral_constant<std::size_t, sizeof...(T)>
{};
template <std::size_t I, class... T>
struct tuple_element<I, cute::tuple<T...>>
: std::tuple_element<I, std::tuple<T...>>
{};
} // end std
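
// ----------------------------------------------------------------------------
// Editorial usage sketch (not part of the header above): mixing static and
// dynamic elements in a cute::tuple. cute::Int<> is assumed, per the EBO
// comment above, to be an empty compile-time integer type defined in
// <cute/numeric/integral_constant.hpp>; values are arbitrary.
// ----------------------------------------------------------------------------
#include <cassert>
#include <type_traits>
#include <cute/container/tuple.hpp>
#include <cute/numeric/integral_constant.hpp>

int main() {
  // Int<2> is empty, so only the int and float leaves occupy storage (EBO).
  auto t = cute::make_tuple(cute::Int<2>{}, 42, 3.5f);
  assert(cute::get<1>(t) == 42);
  static_assert(decltype(cute::get<0>(t))::value == 2, "static leaf recovered from its type");

  // A tuple made purely of empty (static) elements is itself an empty type.
  static_assert(std::is_empty<cute::tuple<cute::Int<1>, cute::Int<8>>>::value, "EBO");

  // tuple_cat, as for std::tuple.
  auto u = cute::tuple_cat(t, cute::make_tuple(7));
  static_assert(std::tuple_size<decltype(u)>::value == 4, "3 + 1 elements");
  assert(cute::get<3>(u) == 7);
  return 0;
}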

View File

@ -0,0 +1,84 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
namespace cute
{
template <class T>
struct type_c {
using type = T;
};
template <class... T>
struct type_list {};
} // end namespace cute
//
// Specialize tuple-related functionality for cute::type_list
//
#include <tuple>
#include <cute/container/tuple.hpp>
namespace cute
{
template <int I, class... T>
CUTE_HOST_DEVICE constexpr
std::tuple_element_t<I, type_list<T...>>
get(type_list<T...>&) noexcept {
return {};
}
template <int I, class... T>
CUTE_HOST_DEVICE constexpr
std::tuple_element_t<I, type_list<T...>>
get(type_list<T...> const& t) noexcept {
return {};
}
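/** Usage sketch, assuming cute::Int from cute/numeric/integral_constant.hpp:
 * elements are recovered by value-initializing their type, so type_list is most
 * useful for lists of empty tag or constant types.
 * \code
 * auto l = type_list<Int<2>, Int<7>>{};
 * auto v = get<1>(l);   // Int<7>{}
 * \endcode
 */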
} // end namespace cute
namespace std
{
template <class... T>
struct tuple_size<cute::type_list<T...>>
: std::integral_constant<std::size_t, sizeof...(T)>
{};
template <std::size_t I, class... T>
struct tuple_element<I, cute::type_list<T...>>
: cute::type_c<typename std::tuple_element<I, std::tuple<T...>>::type>
{};
} // end namespace std

827
include/cute/int_tuple.hpp Normal file
View File

@@ -0,0 +1,827 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/container/array.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
template <class... Ts>
using IntTuple = cute::tuple<Ts...>;
// Construct an IntTuple with all value-elements
template <class... Ts>
CUTE_HOST_DEVICE constexpr
IntTuple<Ts...>
make_int_tuple(Ts const&... t)
{
return {t...};
}
/** if rank(int) == 1, then get<0>(int) should work too
*/
template <std::size_t I, class T, __CUTE_REQUIRES(is_integral<remove_cvref_t<T>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(T&& t) noexcept
{
static_assert(I == 0, "Index out of range");
return static_cast<T&&>(t);
}
/** Custom recursive get for anything that implements get<I>(.)
*/
template <std::size_t I0, std::size_t I1, std::size_t... Is, class Tuple>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(Tuple&& t) noexcept
{
return get<I1,Is...>(get<I0>(static_cast<Tuple&&>(t)));
}
//
// rank
//
template <int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
rank(IntTuple const& t)
{
if constexpr (sizeof...(Is) == 0) {
if constexpr (is_tuple<IntTuple>::value) {
return Int<tuple_size<IntTuple>::value>{};
} else {
return Int<1>{};
}
} else {
return rank(get<Is...>(t));
}
CUTE_GCC_UNREACHABLE;
}
template <class IntTuple>
using rank_t = decltype(rank(std::declval<IntTuple>()));
template <class IntTuple>
static constexpr int rank_v = rank_t<IntTuple>::value;
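/** Usage sketch, assuming cute::make_tuple and Int<> from this commit: rank is the
 * top-level tuple size as a static integer; plain integers are rank-1.
 * \code
 * auto t  = make_tuple(2, make_tuple(3,4), 5);
 * auto r  = rank(t);      // Int<3>{}
 * auto r1 = rank<1>(t);   // Int<2>{}  -- rank of mode-1
 * auto ri = rank(7);      // Int<1>{}
 * \endcode
 */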
//
// shape
//
template <class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
shape(IntTuple const& s)
{
if constexpr (is_tuple<IntTuple>::value) {
return transform(s, [](auto const& a) { return shape(a); });
} else {
return s;
}
CUTE_GCC_UNREACHABLE;
}
template <int I, int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
shape(IntTuple const& s)
{
if constexpr (is_tuple<IntTuple>::value) {
return shape<Is...>(get<I>(s));
} else {
return get<I,Is...>(shape(s));
}
CUTE_GCC_UNREACHABLE;
}
//
// max
//
template <class T0, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
max(T0 const& t0, Ts const&... ts)
{
if constexpr (is_tuple<T0>::value) {
return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...);
} else if constexpr (sizeof...(Ts) == 0) {
return t0;
} else {
return cute::max(t0, cute::max(ts...));
}
CUTE_GCC_UNREACHABLE;
}
//
// min
//
template <class T0, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
min(T0 const& t0, Ts const&... ts)
{
if constexpr (is_tuple<T0>::value) {
return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...);
} else if constexpr (sizeof...(Ts) == 0) {
return t0;
} else {
return cute::min(t0, cute::min(ts...));
}
CUTE_GCC_UNREACHABLE;
}
//
// depth
//
template <int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
depth(IntTuple const& t)
{
if constexpr (sizeof...(Is) == 0) {
if constexpr (is_tuple<IntTuple>::value) {
return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); });
} else {
return Int<0>{};
}
} else {
return depth(get<Is...>(t));
}
CUTE_GCC_UNREACHABLE;
}
template <class Tuple>
using depth_t = decltype(depth(std::declval<Tuple>()));
template <class Tuple>
static constexpr int depth_v = depth_t<Tuple>::value;
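/** Usage sketch: depth counts the nesting levels of a tuple; a plain integer has depth 0.
 * \code
 * auto d0 = depth(8);                               // Int<0>{}
 * auto d1 = depth(make_tuple(8));                   // Int<1>{}
 * auto d2 = depth(make_tuple(2, make_tuple(3,4)));  // Int<2>{}
 * \endcode
 */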
//
// product
//
template <class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
product(IntTuple const& a)
{
if constexpr (is_tuple<IntTuple>::value) {
return cute::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); });
} else {
return a;
}
CUTE_GCC_UNREACHABLE;
}
// Product of a subrange
template <int B, int E, class Tuple>
CUTE_HOST_DEVICE constexpr
auto
product(Tuple const& a)
{
return detail::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }, make_range<B,E>{});
}
template <class Tuple>
CUTE_HOST_DEVICE constexpr
auto
product_each(Tuple const& t)
{
return transform(t, [](auto const& x) { return product(x); });
}
// Return the product of elements in a mode
template <int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
size(IntTuple const& a)
{
if constexpr (sizeof...(Is) == 0) {
return product(a);
} else {
return product(get<Is...>(a));
}
CUTE_GCC_UNREACHABLE;
}
template <class IntTuple>
static constexpr int size_v = decltype(size(std::declval<IntTuple>()))::value;
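/** Usage sketch: product/size multiply through an IntTuple; the result stays a
 * static integer only if every factor in the mode is static.
 * \code
 * auto s  = make_tuple(Int<2>{}, make_tuple(3,4), Int<5>{});
 * auto p  = product(s);        // 120  (dynamic int, since 3 and 4 are dynamic)
 * auto p1 = size<1>(s);        // 12   -- product of mode-1 only
 * auto pe = product_each(s);   // (_2,12,_5)
 * \endcode
 */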
//
// sum
//
template <class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
sum(IntTuple const& a)
{
if constexpr (is_tuple<IntTuple>::value) {
return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); });
} else {
return a;
}
CUTE_GCC_UNREACHABLE;
}
//
// inner_product
//
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
inner_product(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); },
[](auto const&... v) { return (Int<0>{} + ... + v); });
} else {
return a * b;
}
CUTE_GCC_UNREACHABLE;
}
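/** Usage sketch: inner_product recurses through congruent tuples and reduces with +,
 * e.g. a coordinate dotted with a stride tuple.
 * \code
 * auto coord  = make_tuple(1, make_tuple(2,3));
 * auto stride = make_tuple(4, make_tuple(5,6));
 * auto idx    = inner_product(coord, stride);   // 1*4 + 2*5 + 3*6 = 32
 * \endcode
 */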
//
// ceil_div
//
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
ceil_div(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); });
} else {
return (a + b - Int<1>{}) / b;
}
CUTE_GCC_UNREACHABLE;
}
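/** Usage sketch: ceil_div applies elementwise across modes; missing trailing modes
 * of the divisor behave as 1.
 * \code
 * auto c0 = ceil_div(10, 4);                              // 3
 * auto c1 = ceil_div(make_tuple(10,7), make_tuple(4));    // (3,7)
 * \endcode
 */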
/** Division for Shapes
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
shape_div(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value) {
if constexpr (is_tuple<IntTupleB>::value) { // tuple tuple
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); });
} else { // tuple int
auto const [result, rest] = fold(a, make_tuple(make_tuple(), b),
[] (auto const& init, auto const& ai) {
return make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai));
});
return result;
}
} else {
if constexpr (is_tuple<IntTupleB>::value) { // int tuple
return shape_div(a, product(b));
} else { // int int
//assert(a % b == 0 || b % a == 0);
return a / b != 0 ? a / b : signum(a) * signum(b); // divide with rounding away from zero
}
}
CUTE_GCC_UNREACHABLE;
}
/** Division for Shapes that are static constants
* @pre t % u == 0 || u % t == 0
* @result if t % u == 0, then t / u
* if u % t == 0, then signum(t) * signum(u)
*/
template <class T, T t, class U, U u>
CUTE_HOST_DEVICE constexpr
constant<decltype(shape_div(t,u)), shape_div(t,u)>
shape_div(constant<T, t> const&, constant<U, u> const&)
{
static_assert(t % u == 0 || u % t == 0, "Static shape_div failure");
return {};
}
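/** Usage sketch: for static operands, shape_div divides when divisible and otherwise
 * collapses to +/-1; for a tuple divided by an integer, the divisor is consumed
 * across the modes from left to right.
 * \code
 * auto d0 = shape_div(Int<6>{}, Int<2>{});                        // _3
 * auto d1 = shape_div(Int<2>{}, Int<6>{});                        // _1
 * auto d2 = shape_div(make_tuple(Int<4>{}, Int<6>{}), Int<8>{});  // (_1,_3)
 * \endcode
 */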
/** Return a tuple with the same profile as A, scaled by the corresponding elements in B
*/
template <class A, class B>
CUTE_HOST_DEVICE constexpr
auto
elem_scale(A const& a, B const& b)
{
if constexpr (is_tuple<A>::value) {
return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); });
} else {
return a * product(b);
}
CUTE_GCC_UNREACHABLE;
}
/** Test if two IntTuple have the same profile (hierarchical rank division)
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
congruent(IntTupleA const& a, IntTupleB const& b)
{
return bool_constant<std::is_same<decltype(repeat_like(shape(a),_0{})),
decltype(repeat_like(shape(b),_0{}))>::value>{};
}
template <class A, class B>
using is_congruent = decltype(congruent(std::declval<A>(), std::declval<B>()));
/** Test if Shape B is compatible with Shape A:
* Any coordinate into A can also be used as a coordinate into B
* A <= B is a partially ordered set of factored shapes
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
compatible(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
if constexpr (tuple_size<IntTupleA>::value != tuple_size<IntTupleB>::value) {
return false_type{};
} else {
return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); },
[](auto const&... z) { return (true_type{} && ... && z); });
}
} else if constexpr (is_integral<IntTupleA>::value) {
return a == size(b);
} else if constexpr (is_integral<IntTupleB>::value) {
return false_type{};
} else {
return compatible(shape(a), shape(b));
}
CUTE_GCC_UNREACHABLE;
}
template <class A, class B>
using is_compatible = decltype(compatible(std::declval<A>(), std::declval<B>()));
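/** Usage sketch: congruent demands identical hierarchical profiles, while compatible
 * only demands that every coordinate for A is also a valid coordinate for B.
 * \code
 * auto a  = make_tuple(Int<2>{}, make_tuple(Int<3>{}, Int<4>{}));
 * auto c  = make_tuple(Int<2>{}, Int<12>{});
 * auto b0 = congruent(a, c);    // false_type{}
 * auto b1 = compatible(c, a);   // true_type{}   -- _2 == size(_2), _12 == size((_3,_4))
 * auto b2 = compatible(a, c);   // false_type{}  -- mode-1 of a is a tuple, mode-1 of c is not
 * \endcode
 */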
/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1>
*/
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value) {
return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); });
} else if constexpr (is_constant<0, IntTupleA>::value) {
return Int<1>{};
} else {
return b;
}
CUTE_GCC_UNREACHABLE;
}
template <class Tuple>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tuple const& t)
{
return filter_zeros(t, t);
}
//
// Converters and constructors with arrays and params
//
/** Make an IntTuple of rank N from an Indexable array.
* Access elements up to a dynamic index n, then use init (requires compatible types)
* Consider cute::take<B,E> if all indexing is known to be valid
* \code
* std::vector<int> a = {6,3,4};
* auto tup = make_int_tuple<5>(a, a.size(), 0);  // (6,3,4,0,0)
* \endcode
*/
template <int N, class Indexable, class T>
CUTE_HOST_DEVICE constexpr
auto
make_int_tuple(Indexable const& t, int n, T const& init)
{
static_assert(N > 0);
if constexpr (N == 1) {
return 0 < n ? t[0] : init;
} else {
return transform(make_seq<N>{}, [&](auto i) { return i < n ? t[i] : init; });
}
CUTE_GCC_UNREACHABLE;
}
/** Fill the dynamic values of a Tuple with values from another Tuple
* \code
* auto params = make_int_tuple(6,3,4);
* cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>> result;
* fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2)
* \endcode
*/
template <class Tuple, class TupleV>
CUTE_HOST_DEVICE constexpr
auto
fill_int_tuple_from(Tuple& result, TupleV const& vals)
{
return fold(result, vals, [](auto const& init, auto&& r) {
if constexpr (is_static<remove_cvref_t<decltype(r)>>::value) { // Skip static elements of result
return init;
} else if constexpr (is_tuple<remove_cvref_t<decltype(r)>>::value) { // Recurse into tuples
return fill_int_tuple_from(r, init);
} else { // Assign and consume arg
static_assert(tuple_size<remove_cvref_t<decltype(init)>>::value > 0, "Not enough values to fill with!");
r = get<0>(init);
return remove<0>(init);
}
CUTE_GCC_UNREACHABLE;
});
}
/** Make a "Tuple" by filling in the dynamic values in order from the arguments
* \code
* using result_t = cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>>;
* auto result = make_int_tuple_from<result_t>(6,3,4); // (_1,(6,3,_3),4,_2)
* \endcode
*/
template <class Tuple, class... Ts>
CUTE_HOST_DEVICE constexpr
Tuple
make_int_tuple_from(Ts const&... ts)
{
Tuple result = Tuple{};
fill_int_tuple_from(result, make_tuple(ts...));
return result;
}
/** Convert a tuple to a flat homogeneous array of type T
* \code
* auto tup = make_tuple(Int<1>{}, make_tuple(6,3,Int<3>{}),4,Int<2>{});
* cute::array<uint64_t,6> result = to_array<uint64_t>(tup); // [1,6,3,3,4,2]
* \endcode
*/
template <class T = int64_t, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
to_array(IntTuple const& t)
{
auto flat_t = flatten_to_tuple(t);
constexpr int N = tuple_size<decltype(flat_t)>::value;
cute::array<T,N> result;
for_each(make_seq<N>{}, [&] (auto i) { result[i] = get<i>(flat_t); });
return result;
}
//
// Comparison operators
//
//
// There are many ways to compare tuple of elements and because CuTe is built
// on parameterizing layouts of coordinates, some comparisons are appropriate
// only in certain cases.
// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout
// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout
// -- element-wise comparison [any,all] :
// This can be very confusing. To avoid errors in selecting the appropriate
// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple.
//
// When actually desiring to order coordinates, the user should map them to
// their indices within the Layout they came from:
// e.g. layoutX(coordA) < layoutX(coordB)
// That said, we implement the three most common ways to compare tuples below.
// These are implemented with slightly more explicit names than op<.
//
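/** Usage sketch of the three comparison families on dynamic coordinates:
 * \code
 * auto a = make_tuple(1, 2);
 * auto b = make_tuple(2, 1);
 * bool l = lex_less(a, b);     // true   -- mode-0 compared first (RowMajor-style order)
 * bool c = colex_less(a, b);   // false  -- last mode compared first (ColMajor-style order)
 * bool e = elem_less(a, b);    // false  -- requires every element of a to be less
 * \endcode
 */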
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
lex_less(IntTupleA const& a, IntTupleB const& b);
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
colex_less(IntTupleA const& a, IntTupleB const& b);
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
elem_less(IntTupleA const& a, IntTupleB const& b);
namespace detail {
template <std::size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
lex_less_impl(TupleA const& a, TupleB const& b)
{
if constexpr (I == tuple_size<TupleB>::value) {
return cute::false_type{}; // Terminal: TupleB is exhausted
} else if constexpr (I == tuple_size<TupleA>::value) {
return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted
} else {
return lex_less(get<I>(a), get<I>(b)) || (get<I>(a) == get<I>(b) && lex_less_impl<I+1>(a,b));
}
CUTE_GCC_UNREACHABLE;
}
template <std::size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
colex_less_impl(TupleA const& a, TupleB const& b)
{
if constexpr (I == tuple_size<TupleB>::value) {
return cute::false_type{}; // Terminal: TupleB is exhausted
} else if constexpr (I == tuple_size<TupleA>::value) {
return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted
} else {
constexpr std::size_t A = tuple_size<TupleA>::value - 1 - I;
constexpr std::size_t B = tuple_size<TupleB>::value - 1 - I;
return colex_less(get<A>(a), get<B>(b)) || (get<A>(a) == get<B>(b) && colex_less_impl<I+1>(a,b));
}
CUTE_GCC_UNREACHABLE;
}
template <std::size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
elem_less_impl(TupleA const& a, TupleB const& b)
{
if constexpr (I == tuple_size<TupleA>::value) {
return cute::true_type{}; // Terminal: TupleA is exhausted
} else if constexpr (I == tuple_size<TupleB>::value) {
return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted
} else {
return elem_less(get<I>(a), get<I>(b)) && elem_less_impl<I+1>(a,b);
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
// Lexicographical comparison
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
lex_less(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
return detail::lex_less_impl<0>(a, b);
} else {
return a < b;
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
lex_leq(T const& t, U const& u) {
return !lex_less(u, t);
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
lex_gtr(T const& t, U const& u) {
return lex_less(u, t);
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
lex_geq(T const& t, U const& u) {
return !lex_less(t, u);
}
// Colexicographical comparison
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
colex_less(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
return detail::colex_less_impl<0>(a, b);
} else {
return a < b;
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
colex_leq(T const& t, U const& u) {
return !colex_less(u, t);
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
colex_gtr(T const& t, U const& u) {
return colex_less(u, t);
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
colex_geq(T const& t, U const& u) {
return !colex_less(t, u);
}
// Elementwise [all] comparison
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
elem_less(IntTupleA const& a, IntTupleB const& b)
{
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
return detail::elem_less_impl<0>(a, b);
} else {
return a < b;
}
CUTE_GCC_UNREACHABLE;
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
elem_leq(T const& t, U const& u) {
return !elem_less(u, t);
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
elem_gtr(T const& t, U const& u) {
return elem_less(u, t);
}
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
elem_geq(T const& t, U const& u) {
return !elem_less(t, u);
}
/** Increment a (dynamic) coord lexicographically within a shape
* \code
* auto shape = make_shape(1,2,make_shape(2,3),3);
*
* int i = 0;
* for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) {
* std::cout << i++ << ": " << coord << std::endl;
* }
* assert(i == size(shape));
* \endcode
*/
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape);
namespace detail {
template <class Coord, class Shape, int I0, int... Is>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape, seq<I0,Is...>)
{
cute::increment(get<I0>(coord), get<I0>(shape));
if constexpr (sizeof...(Is) != 0) {
if (back(get<I0>(coord)) == back(get<I0>(shape))) {
back(get<I0>(coord)) = 0;
increment(coord, shape, seq<Is...>{});
}
}
}
} // end namespace detail
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape)
{
if constexpr (is_integral<Coord>::value && is_integral<Shape>::value) {
++coord;
} else if constexpr (is_tuple<Coord>::value && is_tuple<Shape>::value) {
static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched ranks");
detail::increment(coord, shape, tuple_seq<Coord>{});
} else {
static_assert(sizeof(Coord) == 0, "Invalid parameters");
}
}
struct ForwardCoordIteratorSentinal
{};
// A forward iterator for a coordinate that starts from zero and goes to shape
template <class Coord, class Shape>
struct ForwardCoordIterator
{
static_assert(is_congruent<Coord, Shape>::value);
CUTE_HOST_DEVICE constexpr
Coord const& operator*() const { return coord; }
CUTE_HOST_DEVICE constexpr
ForwardCoordIterator& operator++() { increment(coord, shape); return *this; }
// Sentinel for the end of the implied range
CUTE_HOST_DEVICE constexpr
bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); }
CUTE_HOST_DEVICE constexpr
bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); }
CUTE_HOST_DEVICE constexpr
bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); }
// NOTE: These are expensive, avoid use
CUTE_HOST_DEVICE constexpr
bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); }
CUTE_HOST_DEVICE constexpr
bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; }
CUTE_HOST_DEVICE constexpr
bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; }
Coord coord;
Shape const& shape;
};
// A forward iterator for a coordinate that starts from zero
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_coord_iterator(Shape const& shape)
{
auto coord = repeat_like(shape, int(0));
return ForwardCoordIterator<decltype(coord),Shape>{coord,shape};
}
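/** Usage sketch, assuming cute::make_shape from this commit: iterate all coordinates
 * of a shape in the same first-mode-fastest order that increment() produces.
 * \code
 * auto shape = make_shape(make_shape(2,2), 3);
 * int n = 0;
 * for (auto it = make_coord_iterator(shape); it != ForwardCoordIteratorSentinal{}; ++it) {
 *   ++n;   // *it visits ((0,0),0), ((1,0),0), ((0,1),0), ((1,1),0), ((0,0),1), ...
 * }
 * // n == size(shape) == 12
 * \endcode
 */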
} // end namespace cute

1638
include/cute/layout.hpp Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,388 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
namespace cute
{
template <class... T>
struct ArithmeticTuple : tuple<T...>
{
template <class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple(ArithmeticTuple<U...> const& u)
: tuple<T...>(static_cast<tuple<U...> const&>(u)) {}
template <class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple(tuple<U...> const& u)
: tuple<T...>(u) {}
template <class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple(U const&... u)
: tuple<T...>(u...) {}
};
template <class... T>
struct is_tuple<ArithmeticTuple<T...>> : true_type {};
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
make_arithmetic_tuple(T const&... t) {
return ArithmeticTuple<T...>(t...);
}
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(tuple<T...> const& t) {
return ArithmeticTuple<T...>(t);
}
//
// Numeric operators
//
// Addition
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
}
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
}
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
}
//
// Special cases
//
template <class T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(constant<T,0>, ArithmeticTuple<U...> const& u) {
return u;
}
template <class... T, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, constant<U,0>) {
return t;
}
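/** Usage sketch: addition zero-pads the shorter operand, so ArithmeticTuples of
 * different rank can be accumulated; static zeros are preserved where possible.
 * \code
 * auto a = make_arithmetic_tuple(1, Int<2>{});
 * auto b = make_arithmetic_tuple(10, Int<20>{}, 300);
 * auto c = a + b;   // (11,_22,300)
 * \endcode
 */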
//
// ArithmeticTupleIterator
//
template <class ArithTuple>
struct ArithmeticTupleIterator
{
ArithTuple coord_;
CUTE_HOST_DEVICE constexpr
ArithmeticTupleIterator() : coord_() {}
CUTE_HOST_DEVICE constexpr
ArithmeticTupleIterator(ArithTuple const& coord) : coord_(coord) {}
CUTE_HOST_DEVICE constexpr
ArithTuple const& operator*() const { return coord_; }
template <class Coord>
CUTE_HOST_DEVICE constexpr
auto operator+(Coord const& c) const {
return ArithmeticTupleIterator<decltype(coord_ + c)>(coord_ + c);
}
template <class Coord>
CUTE_HOST_DEVICE constexpr
auto operator[](Coord const& c) const { return *(*this + c); }
};
template <class ArithTuple>
CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter) {
printf("ArithTuple"); print(iter.coord_);
}
//
// ArithmeticTuple "basis" elements
//
// Abstract value:
//   A ScaledBasis<T,N> is an (at least) rank-(N+1) ArithmeticTuple:
// (_0,_0,...,T,_0,...)
template <class T, int N>
struct ScaledBasis : private tuple<T>
{
CUTE_HOST_DEVICE constexpr
ScaledBasis(T const& t = {}) : tuple<T>(t) {}
CUTE_HOST_DEVICE constexpr
decltype(auto) value() { return get<0>(static_cast<tuple<T> &>(*this)); }
CUTE_HOST_DEVICE constexpr
decltype(auto) value() const { return get<0>(static_cast<tuple<T> const&>(*this)); }
CUTE_HOST_DEVICE static constexpr
auto mode() { return Int<N>{}; }
};
template <class T>
struct is_scaled_basis : false_type {};
template <class T, int N>
struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
template <class T, int N>
struct is_integral<ScaledBasis<T,N>> : true_type {};
template <class T>
CUTE_HOST_DEVICE constexpr auto
basis_value(T const& e) {
return e;
}
template <class T, int N>
CUTE_HOST_DEVICE constexpr auto
basis_value(ScaledBasis<T,N> const& e) {
return basis_value(e.value());
}
namespace detail {
template <int... Ns>
struct Basis;
template <>
struct Basis<> {
using type = Int<1>;
};
template <int N, int... Ns>
struct Basis<N,Ns...> {
using type = ScaledBasis<typename Basis<Ns...>::type, N>;
};
} // end namespace detail
template <int... N>
using E = typename detail::Basis<N...>::type;
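/** Usage sketch: E<N> is the unit basis for mode N, and E<N0,N1,...> nests into the
 * corresponding hierarchical mode. Scaling and adding basis elements assembles a
 * sparse ArithmeticTuple "stride".
 * \code
 * auto s = E<1>{} * 5;        // ScaledBasis holding 5 at mode 1, i.e. (_0,5)
 * auto t = s + E<0>{} * 2;    // ArithmeticTuple (2,5)
 * \endcode
 */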
namespace detail {
template <class T, int... I, int... J>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(T const& t, seq<I...>, seq<J...>) {
return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...);
}
template <class... T, int... I, int... J>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ArithmeticTuple<T...> const& t, seq<I...>, seq<J...>) {
return make_arithmetic_tuple(get<I>(t)..., (void(J),Int<0>{})...);
}
} // end namespace detail
// Turn a ScaledBasis<T,N> into a rank-M ArithmeticTuple
//    with N prefix _0s: (_0,...,_0,T,_0,...,_0)
template <int M, class T, int N>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
static_assert(M > N, "Mismatched ranks");
return detail::as_arithmetic_tuple(t.value(), make_seq<N>{}, make_seq<M-N-1>{});
}
// Turn an ArithmeticTuple into a rank-M ArithmeticTuple
// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0)
template <int M, class... T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
static_assert(M >= sizeof...(T), "Mismatched ranks");
return detail::as_arithmetic_tuple(t, make_seq<int(sizeof...(T))>{}, make_seq<M-int(sizeof...(T))>{});
}
// Return a tuple of basis elements with the same hierarchical profile as Shape (one per leaf mode)
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_basis_like(Shape const& shape)
{
if constexpr (is_integral<Shape>::value) {
return Int<1>{};
} else {
// Generate bases for each rank of shape
return transform(tuple_seq<Shape>{}, [&](auto I) {
// Generate bases for each rank of shape_i and add an i on front
constexpr int i = decltype(I)::value; // NOTE: nvcc workaround
return transform_leaf(make_basis_like(get<i>(shape)), [&](auto e) { return ScaledBasis<decltype(e),i>{}; });
});
}
CUTE_GCC_UNREACHABLE;
}
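/** Usage sketch: make_basis_like produces one basis element per leaf of Shape, with
 * the same hierarchical profile as Shape.
 * \code
 * auto bs = make_basis_like(make_shape(2, make_shape(3,4)));
 * // bs ~ (E<0>{}, (E<1,0>{}, E<1,1>{}))
 * \endcode
 */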
// Equality
template <class T, int N, int M>
CUTE_HOST_DEVICE constexpr
auto
operator==(ScaledBasis<T,N>, Int<M>) {
return false_type{};
}
template <int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator==(Int<N>, ScaledBasis<U,M>) {
return false_type{};
}
template <class T, int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
return bool_constant<M == N>{} && t.value() == u.value();
}
// Multiplication
template <class A, int N, class T,
__CUTE_REQUIRES(cute::is_integral<A>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator*(A const& a, ScaledBasis<T,N> const& e) {
return ScaledBasis<decltype(a*e.value()),N>{a*e.value()};
}
template <int N, class T, class B,
__CUTE_REQUIRES(cute::is_integral<B>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator*(ScaledBasis<T,N> const& e, B const& b) {
return ScaledBasis<decltype(e.value()*b),N>{e.value()*b};
}
// Addition
template <int N, class T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
constexpr int R = cute::max(N+1, int(sizeof...(U)));
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
}
template <class... T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), M+1);
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
}
template <int N, class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
constexpr int R = cute::max(N+1,M+1);
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
}
template <class T, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(constant<T,0>, ScaledBasis<U,M> const& u) {
return u;
}
template <class T, int N, class U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, constant<U,0>) {
return t;
}
//
// Display utilities
//
template <class T, int N>
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e) {
printf("%d:", N); print(e.value());
}
template <class T, int N>
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e) {
return os << N << ":" << e.value();
}
} // end namespace cute
namespace std
{
template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
: std::integral_constant<std::size_t, sizeof...(T)>
{};
template <std::size_t I, class... T>
struct tuple_element<I, cute::ArithmeticTuple<T...>>
: std::tuple_element<I, std::tuple<T...>>
{};
} // end namespace std

View File

@@ -0,0 +1,51 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <vector_types.h>
#include <cutlass/numeric_types.h>
namespace cute {
using cutlass::bfloat16_t;
//
// Display utilities
//
CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v)
{
return os << float(v);
}
} // end namespace cute

View File

@@ -0,0 +1,163 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstdint>
//#if defined(__CUDA_ARCH__)
//# include <cuda/std/complex>
//#else
//# include <complex>
//#endif
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
// on line 656 of thrust/detail/type_traits.h.
// These pragmas suppress the warnings.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
#include <thrust/complex.h>
#pragma GCC diagnostic pop
#include <cute/config.hpp>
namespace cute
{
//#if defined(__CUDA_ARCH__)
//template <class T>
//using complex = cuda::std::complex<T>;
//#else
//template <class T>
//using complex = std::complex<T>;
//#endif
//template <class T>
//using complex = thrust::complex<T>;
using thrust::complex;
template <class T>
CUTE_HOST_DEVICE
T real(complex<T> const& z) {
return z.real();
}
template <class T>
CUTE_HOST_DEVICE
T imag(complex<T> const& z) {
return z.imag();
}
template <class T>
CUTE_HOST_DEVICE
complex<T> conj(complex<T> const& z) {
return complex<T>(real(z), -imag(z));
}
// cute::conj forwards scalars
template <class T>
CUTE_HOST_DEVICE
T conj(T z) {
return z;
}
//CUTE_HOST_DEVICE constexpr
//float conj(float z) { return z; }
//CUTE_HOST_DEVICE constexpr
//double conj(double z) { return z; }
/// Fused multiply-add for complex numbers
template <class T>
CUTE_HOST_DEVICE constexpr
void
fma(complex<T> & d,
complex<T> const& a,
complex<T> const& b,
complex<T> const& c)
{
d.real(c.real() + a.real() * b.real());
d.imag(c.imag() + a.real() * b.imag());
d.real(d.real() - a.imag() * b.imag());
d.imag(d.imag() + a.imag() * b.real());
}
/// Fused multiply-add for triplets
template <class T>
CUTE_HOST_DEVICE constexpr
void
fma(complex<T> const& a,
complex<T> const& b,
complex<T> & c)
{
return fma(c, a, b, c);
}
/// Used to determine the real-valued underlying type of a numeric type T
template <class T>
struct RealType {
using Type = T;
};
/// Partial specialization for complex-valued type
template <class T>
struct RealType<complex<T>> {
using Type = T;
};
//////////////////////////////////////////////////////////////////////////////////////////////////
template <class T>
struct is_complex {
static bool const value = false;
};
template <class T>
struct is_complex<complex<T>> {
static bool const value = true;
};
//////////////////////////////////////////////////////////////////////////////////////////////////
// Display utilities
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
{
T _r = z.real();
T _i = z.imag();
if (bool(_i)) {
return os << _r << "+i" << _i;
} else {
return os << _r;
}
}
} // end namespace cute

View File

@@ -0,0 +1,43 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <vector_types.h>
#include <cutlass/numeric_types.h>
namespace cute {
using cutlass::float_e4m3_t;
using cutlass::float_e5m2_t;
} // end namespace cute

View File

@@ -0,0 +1,41 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <vector_types.h>
#include <cutlass/numeric_types.h>
namespace cute {
using cutlass::half_t;
} // end namespace cute

View File

@@ -0,0 +1,129 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include <cute/numeric/integer_subbyte.hpp>
#include <cute/numeric/uint128.hpp>
namespace cute
{
//
// Signed integers
//
using int8_t = std::int8_t;
using int16_t = std::int16_t;
using int32_t = std::int32_t;
using int64_t = std::int64_t;
template <int N> struct int_bit;
template <> struct int_bit< 2> { using type = cute::int2b_t; };
template <> struct int_bit< 4> { using type = cute::int4b_t; };
template <> struct int_bit< 8> { using type = int8_t; };
template <> struct int_bit< 16> { using type = int16_t; };
template <> struct int_bit< 32> { using type = int32_t; };
template <> struct int_bit< 64> { using type = int64_t; };
template <int N>
using int_bit_t = typename int_bit<N>::type;
template <int N>
using int_byte = int_bit<8*N>;
template <int N>
using int_byte_t = typename int_byte<N>::type;
//
// Unsigned integers
//
using uint8_t = std::uint8_t;
using uint16_t = std::uint16_t;
using uint32_t = std::uint32_t;
using uint64_t = std::uint64_t;
template <int N> struct uint_bit;
template <> struct uint_bit< 1> { using type = cute::uint1b_t; };
template <> struct uint_bit< 2> { using type = cute::uint2b_t; };
template <> struct uint_bit< 4> { using type = cute::uint4b_t; };
template <> struct uint_bit< 8> { using type = uint8_t; };
template <> struct uint_bit< 16> { using type = uint16_t; };
template <> struct uint_bit< 32> { using type = uint32_t; };
template <> struct uint_bit< 64> { using type = uint64_t; };
template <> struct uint_bit<128> { using type = cute::uint128_t; };
template <int N>
using uint_bit_t = typename uint_bit<N>::type;
template <int N>
using uint_byte = uint_bit<8*N>;
template <int N>
using uint_byte_t = typename uint_byte<N>::type;
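/** Usage sketch: the *_bit / *_byte traits map a width to the narrowest matching
 * type, including the sub-byte types above.
 * \code
 * using A = uint_bit_t<16>;   // uint16_t
 * using B = uint_byte_t<4>;   // uint_bit_t<32> == uint32_t
 * using C = int_bit_t<4>;     // cute::int4b_t
 * \endcode
 */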
//
// sizeof_bytes
//
template <class T>
struct sizeof_bytes {
static constexpr std::size_t value = sizeof(T);
};
template <class T>
static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
//
// sizeof_bits
//
template <class T>
struct sizeof_bits {
static constexpr std::size_t value = sizeof(T) * 8;
};
template <>
struct sizeof_bits<bool> {
static constexpr std::size_t value = 1;
};
template <int Bits, bool Signed>
struct sizeof_bits<integer_subbyte<Bits,Signed>> {
static constexpr std::size_t value = Bits;
};
template <class T>
static constexpr int sizeof_bits_v = sizeof_bits<T>::value;
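/** Usage sketch: sizeof_bits reports widths that sizeof cannot, e.g. sub-byte types.
 * \code
 * static_assert(sizeof_bits<uint32_t>::value == 32);
 * static_assert(sizeof_bits<bool>::value     ==  1);
 * static_assert(sizeof_bits<int4b_t>::value  ==  4);
 * \endcode
 */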
} // namespace cute

View File

@@ -0,0 +1,139 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <utility> // std::integer_sequence
#include <cute/config.hpp>
namespace cute
{
using std::integer_sequence;
using std::make_integer_sequence;
namespace detail {
template <class T, class S, T Begin>
struct make_integer_range_impl;
template <class T, T... N, T Begin>
struct make_integer_range_impl<T, integer_sequence<T, N...>, Begin> {
using type = integer_sequence<T, N+Begin...>;
};
} // end namespace detail
template <class T, T Begin, T End>
using make_integer_range = typename detail::make_integer_range_impl<
T,
make_integer_sequence<T, (End-Begin > 0) ? (End-Begin) : 0>,
Begin>::type;
//
// Common aliases
//
// int_sequence
template <int... Ints>
using int_sequence = integer_sequence<int, Ints...>;
template <int N>
using make_int_sequence = make_integer_sequence<int, N>;
template <int Begin, int End>
using make_int_range = make_integer_range<int, Begin, End>;
// index_sequence
template <std::size_t... Ints>
using index_sequence = integer_sequence<std::size_t, Ints...>;
template <std::size_t N>
using make_index_sequence = make_integer_sequence<std::size_t, N>;
template <std::size_t Begin, std::size_t End>
using make_index_range = make_integer_range<std::size_t, Begin, End>;
//
// Shortcuts
//
template <int... Ints>
using seq = int_sequence<Ints...>;
template <int N>
using make_seq = make_int_sequence<N>;
template <int Min, int Max>
using make_range = make_int_range<Min, Max>;
template <class Tuple>
using tuple_seq = make_seq<std::tuple_size<std::remove_reference_t<Tuple>>::value>;
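/** Usage sketch: these shortcuts generate the integer packs used throughout CuTe.
 * \code
 * using S0 = make_seq<3>;                       // seq<0,1,2>
 * using S1 = make_range<2,5>;                   // seq<2,3,4>
 * using S2 = tuple_seq<std::tuple<int,int>>;    // seq<0,1>
 * \endcode
 */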
} // end namespace cute
//
// Specialize tuple-related functionality for cute::integer_sequence
//
#include <tuple>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
template <std::size_t I, class T, T... Ints>
CUTE_HOST_DEVICE constexpr
std::tuple_element_t<I, integer_sequence<T, Ints...>>
get(integer_sequence<T, Ints...>) {
static_assert(I < sizeof...(Ints), "Index out of range");
return {};
}
} // end namespace cute
namespace std
{
template <class T, T... Ints>
struct tuple_size<cute::integer_sequence<T, Ints...>>
: std::integral_constant<std::size_t, sizeof...(Ints)>
{};
template <std::size_t I, class T, T... Ints>
struct tuple_element<I, cute::integer_sequence<T, Ints...>>
: std::tuple_element<I, std::tuple<cute::integral_constant<T,Ints>...>>
{};
} // end namespace std

View File

@@ -0,0 +1,233 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
namespace cute {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, bool Signed = true>
struct integer_subbyte
{
/// Storage type
using Storage = uint8_t;
/// Number of bits
static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte");
/// External type
using xint_t = typename std::conditional<Signed, int, unsigned>::type;
/// Bitmask for truncation from larger integers
static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1);
/// Bitmask for the sign bit
static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1));
//
// Data members
//
Storage storage;
//
// Methods
//
/// No operation
CUTE_HOST_DEVICE constexpr
integer_subbyte() {}
/// Conversion from integer type
CUTE_HOST_DEVICE constexpr
integer_subbyte(int value) // NOTE: Sign extension?
: storage(reinterpret_cast<Storage const&>(value) & bits_mask_) {}
CUTE_HOST_DEVICE constexpr
integer_subbyte(unsigned value)
: storage(reinterpret_cast<Storage const&>(value) & bits_mask_) {}
/// Convert to int or unsigned
CUTE_HOST_DEVICE constexpr
operator xint_t() const {
if (sign_mask_ & storage) { // Sign extend
return xint_t(storage) | ~xint_t(bits_mask_);
} else {
return xint_t(storage);
}
}
/// Equality
CUTE_HOST_DEVICE constexpr
bool operator==(integer_subbyte const& rhs) const {
return storage == rhs.storage;
}
/// Inequality
CUTE_HOST_DEVICE constexpr
bool operator!=(integer_subbyte const& rhs) const {
return storage != rhs.storage;
}
/// Less than or equal
CUTE_HOST_DEVICE constexpr
bool operator<=(integer_subbyte const& rhs) const {
if (sign_mask_ & storage) {
return !(rhs.storage < storage);
} else {
return storage < rhs.storage;
}
}
/// Less than
CUTE_HOST_DEVICE constexpr
bool operator<(integer_subbyte const& rhs) const {
if (sign_mask_ & storage) {
return !(rhs.storage <= storage);
} else {
return storage < rhs.storage;
}
}
/// Greater than or equal
CUTE_HOST_DEVICE constexpr
bool operator>=(integer_subbyte const& rhs) const {
return !(*this < rhs);
}
/// Greater than
CUTE_HOST_DEVICE constexpr
bool operator>(integer_subbyte const& rhs) const {
return !(*this <= rhs);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// 1-bit unsigned integer type
using uint1b_t = integer_subbyte<1, false>;
/// 2-bit integer type
using int2b_t = integer_subbyte<2, true>;
/// 2-bit unsigned integer type
using uint2b_t = integer_subbyte<2, false>;
/// 4-bit integer type
using int4b_t = integer_subbyte<4, true>;
/// 4-bit unsigned integer type
using uint4b_t = integer_subbyte<4, false>;
/// 1-bit binary type
using bin1_t = bool;
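/** Usage sketch: values are truncated to the low Bits on construction and, for the
 * signed variants, sign-extended on conversion back to int/unsigned.
 * \code
 * int4b_t  x = -3;   // stores 0b1101
 * int      a = x;    // -3 after sign extension
 * uint4b_t y = 13u;  // stores 0b1101
 * unsigned b = y;    // 13
 * \endcode
 */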
} // namespace cute
///////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(__CUDACC_RTC__)
#include <limits>
namespace std {
template <>
struct numeric_limits<cute::uint1b_t> {
CUTE_HOST_DEVICE static constexpr
cute::uint1b_t const lowest() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint1b_t const min() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint1b_t const max() noexcept { return 1; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = false;
};
template <>
struct numeric_limits<cute::int2b_t> {
CUTE_HOST_DEVICE static constexpr
cute::int2b_t lowest() noexcept { return -2; }
CUTE_HOST_DEVICE static constexpr
cute::int2b_t min() noexcept { return -2; }
CUTE_HOST_DEVICE static constexpr
cute::int2b_t max() noexcept { return 1; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = true;
};
template <>
struct numeric_limits<cute::uint2b_t> {
CUTE_HOST_DEVICE static constexpr
cute::uint2b_t const lowest() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint2b_t const min() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint2b_t const max() noexcept { return 3; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = false;
};
template <>
struct numeric_limits<cute::int4b_t> {
CUTE_HOST_DEVICE static constexpr
cute::int4b_t lowest() noexcept { return -8; }
CUTE_HOST_DEVICE static constexpr
cute::int4b_t min() noexcept { return -8; }
CUTE_HOST_DEVICE static constexpr
cute::int4b_t max() noexcept { return 7; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = true;
};
template <>
struct numeric_limits<cute::uint4b_t> {
CUTE_HOST_DEVICE static constexpr
cute::uint4b_t const lowest() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint4b_t const min() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint4b_t const max() noexcept { return 15; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = false;
};
} // namespace std
#endif

View File

@@ -0,0 +1,414 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/math.hpp>
namespace cute
{
template <class T, T v>
struct constant : std::integral_constant<T,v> {
static constexpr T value = v;
using value_type = T;
using type = constant<T,v>;
CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
};
template <class T, T v>
using integral_constant = constant<T,v>;
template <bool b>
using bool_constant = constant<bool,b>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
//
// Traits
//
// Use std::is_integral<T> to match built-in integral types (int, int64_t, unsigned, etc)
// Use cute::is_integral<T> to match both built-in integral types AND constant<T,t>
template <class T>
struct is_integral : bool_constant<std::is_integral<T>::value> {};
template <class T, T v>
struct is_integral<constant<T,v>> : true_type {};
// is_static detects if an (abstract) value is defined completely by its type (no members)
template <class T>
struct is_static : bool_constant<std::is_empty<T>::value> {};
// is_constant detects if a type is a constant<T,v> and if v is equal to a value
template <auto n, class T>
struct is_constant : false_type {};
template <auto n, class T, T v>
struct is_constant<n, constant<T,v> > : bool_constant<v == n> {};
template <auto n, class T, T v>
struct is_constant<n, constant<T,v> const > : bool_constant<v == n> {};
template <auto n, class T, T v>
struct is_constant<n, constant<T,v> const&> : bool_constant<v == n> {};
template <auto n, class T, T v>
struct is_constant<n, constant<T,v> &> : bool_constant<v == n> {};
template <auto n, class T, T v>
struct is_constant<n, constant<T,v> &&> : bool_constant<v == n> {};
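// Example (illustrative):
//   is_constant<1, constant<int,1>>::value == true
//   is_constant<1, constant<int,2>>::value == false
//   is_constant<1, int>::value             == false   (a runtime int is never a constant)
//   is_static<constant<int,3>>::value      == true,   is_static<int>::value == false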
//
// Specializations
//
template <int v>
using Int = constant<int,v>;
using _m32 = Int<-32>;
using _m24 = Int<-24>;
using _m16 = Int<-16>;
using _m12 = Int<-12>;
using _m10 = Int<-10>;
using _m9 = Int<-9>;
using _m8 = Int<-8>;
using _m7 = Int<-7>;
using _m6 = Int<-6>;
using _m5 = Int<-5>;
using _m4 = Int<-4>;
using _m3 = Int<-3>;
using _m2 = Int<-2>;
using _m1 = Int<-1>;
using _0 = Int<0>;
using _1 = Int<1>;
using _2 = Int<2>;
using _3 = Int<3>;
using _4 = Int<4>;
using _5 = Int<5>;
using _6 = Int<6>;
using _7 = Int<7>;
using _8 = Int<8>;
using _9 = Int<9>;
using _10 = Int<10>;
using _12 = Int<12>;
using _16 = Int<16>;
using _24 = Int<24>;
using _32 = Int<32>;
using _64 = Int<64>;
using _96 = Int<96>;
using _128 = Int<128>;
using _192 = Int<192>;
using _256 = Int<256>;
using _512 = Int<512>;
using _1024 = Int<1024>;
using _2048 = Int<2048>;
using _4096 = Int<4096>;
using _8192 = Int<8192>;
/***************/
/** Operators **/
/***************/
#define CUTE_LEFT_UNARY_OP(OP) \
template <class T, T t> \
CUTE_HOST_DEVICE constexpr \
constant<decltype(OP t), (OP t)> \
operator OP (constant<T,t>) { \
return {}; \
}
#define CUTE_RIGHT_UNARY_OP(OP) \
template <class T, T t> \
CUTE_HOST_DEVICE constexpr \
constant<decltype(t OP), (t OP)> \
operator OP (constant<T,t>) { \
return {}; \
}
#define CUTE_BINARY_OP(OP) \
template <class T, T t, class U, U u> \
CUTE_HOST_DEVICE constexpr \
constant<decltype(t OP u), (t OP u)> \
operator OP (constant<T,t>, constant<U,u>) { \
return {}; \
}
CUTE_LEFT_UNARY_OP(+);
CUTE_LEFT_UNARY_OP(-);
CUTE_LEFT_UNARY_OP(~);
CUTE_LEFT_UNARY_OP(!);
CUTE_LEFT_UNARY_OP(*);
CUTE_BINARY_OP( +);
CUTE_BINARY_OP( -);
CUTE_BINARY_OP( *);
CUTE_BINARY_OP( /);
CUTE_BINARY_OP( %);
CUTE_BINARY_OP( &);
CUTE_BINARY_OP( |);
CUTE_BINARY_OP( ^);
CUTE_BINARY_OP(<<);
CUTE_BINARY_OP(>>);
CUTE_BINARY_OP(&&);
CUTE_BINARY_OP(||);
CUTE_BINARY_OP(==);
CUTE_BINARY_OP(!=);
CUTE_BINARY_OP( >);
CUTE_BINARY_OP( <);
CUTE_BINARY_OP(>=);
CUTE_BINARY_OP(<=);
#undef CUTE_BINARY_OP
#undef CUTE_LEFT_UNARY_OP
#undef CUTE_RIGHT_UNARY_OP
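// Example (illustrative): the operators above fold entirely at compile time,
//   e.g. _2{} * _4{} has type Int<8>, _8{} % _3{} has type Int<2>,
//   and _2{} < _3{} has type true_type.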
//
// Mixed static-dynamic special cases
//
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator*(constant<T, 0>, U) {
return {};
}
template <class U, class T,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator*(U, constant<T, 0>) {
return {};
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator/(constant<T, 0>, U) {
return {};
}
template <class U, class T,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator%(U, constant<T, 1>) {
return {};
}
template <class U, class T,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator%(U, constant<T,-1>) {
return {};
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator%(constant<T, 0>, U) {
return {};
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator&(constant<T, 0>, U) {
return {};
}
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
constant<T, 0>
operator&(U, constant<T, 0>) {
return {};
}
template <class T, T t, class U,
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, false>
operator&&(constant<T, t>, U) {
return {};
}
template <class T, T t, class U,
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, false>
operator&&(U, constant<T, t>) {
return {};
}
template <class T, class U, T t,
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, true>
operator||(constant<T, t>, U) {
return {};
}
template <class T, class U, T t,
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
CUTE_HOST_DEVICE constexpr
constant<bool, true>
operator||(U, constant<T, t>) {
return {};
}
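// Example (illustrative): these overloads let a static zero/one absorb a runtime value,
//   e.g. _0{} * n and n % _1{} both yield _0{} for any runtime integer n,
//   so no register is needed for the result.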
//
// Named functions from math.hpp
//
#define CUTE_NAMED_UNARY_FN(OP) \
template <class T, T t> \
CUTE_HOST_DEVICE constexpr \
constant<decltype(OP(t)), OP(t)> \
OP (constant<T,t>) { \
return {}; \
}
#define CUTE_NAMED_BINARY_FN(OP) \
template <class T, T t, class U, U u> \
CUTE_HOST_DEVICE constexpr \
constant<decltype(OP(t,u)), OP(t,u)> \
OP (constant<T,t>, constant<U,u>) { \
return {}; \
} \
\
template <class T, T t, class U, \
__CUTE_REQUIRES(std::is_integral<U>::value)> \
CUTE_HOST_DEVICE constexpr \
auto \
OP (constant<T,t>, U u) { \
return OP(t,u); \
} \
\
template <class T, class U, U u, \
__CUTE_REQUIRES(std::is_integral<T>::value)> \
CUTE_HOST_DEVICE constexpr \
auto \
OP (T t, constant<U,u>) { \
return OP(t,u); \
}
CUTE_NAMED_UNARY_FN(abs);
CUTE_NAMED_UNARY_FN(signum);
CUTE_NAMED_UNARY_FN(has_single_bit);
CUTE_NAMED_BINARY_FN(max);
CUTE_NAMED_BINARY_FN(min);
CUTE_NAMED_BINARY_FN(shiftl);
CUTE_NAMED_BINARY_FN(shiftr);
CUTE_NAMED_BINARY_FN(gcd);
CUTE_NAMED_BINARY_FN(lcm);
#undef CUTE_NAMED_UNARY_FN
#undef CUTE_NAMED_BINARY_FN
//
// Other functions
//
template <class T, T t, class U, U u>
CUTE_HOST_DEVICE constexpr
constant<decltype(t / u), t / u>
safe_div(constant<T, t>, constant<U, u>) {
static_assert(t % u == 0, "Static safe_div requires t % u == 0");
return {};
}
template <class T, T t, class U,
__CUTE_REQUIRES(std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(constant<T, t>, U u) {
return t / u;
}
template <class T, class U, U u,
__CUTE_REQUIRES(std::is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(T t, constant<U, u>) {
return t / u;
}
// cute::true_type prefers standard conversion to std::true_type
// over user-defined conversion to bool
template <class TrueType, class FalseType>
CUTE_HOST_DEVICE constexpr
decltype(auto)
conditional_return(std::true_type, TrueType&& t, FalseType&&) {
return static_cast<TrueType&&>(t);
}
// cute::false_type prefers standard conversion to std::false_type
// over user-defined conversion to bool
template <class TrueType, class FalseType>
CUTE_HOST_DEVICE constexpr
decltype(auto)
conditional_return(std::false_type, TrueType&&, FalseType&& f) {
return static_cast<FalseType&&>(f);
}
// TrueType and FalseType must have a common type
template <class TrueType, class FalseType>
CUTE_HOST_DEVICE constexpr
auto
conditional_return(bool b, TrueType const& t, FalseType const& f) {
return b ? t : f;
}
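// Example (illustrative):
//   conditional_return(true_type{},  _1{}, 2)  yields _1{}  (the two types may differ),
//   conditional_return(false_type{}, _1{}, 2)  yields 2,
//   conditional_return(b, 1, 2) for a runtime bool b requires a common result type.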
//
// Display utilities
//
template <class T, T N>
CUTE_HOST_DEVICE void print(integral_constant<T,N> const&) {
printf("_%d", N);
}
template <class T, T N>
CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant<T,N> const&) {
return os << "_" << N;
}
} // end namespace cute

View File

@@ -0,0 +1,319 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <limits>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include <cute/config.hpp>
namespace cute
{
//
// Common Operations
//
template <class T, class U,
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
std::is_arithmetic<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
max(T const& t, U const& u) {
return t < u ? u : t;
}
template <class T, class U,
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
std::is_arithmetic<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
min(T const& t, U const& u) {
return t < u ? t : u;
}
template <class T,
__CUTE_REQUIRES(std::is_arithmetic<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
abs(T const& t) {
if constexpr (std::is_signed<T>::value) {
return t < T(0) ? -t : t;
} else {
return t;
}
CUTE_GCC_UNREACHABLE;
}
//
// C++17 <numeric> operations
//
// Greatest common divisor of two integers
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<T>::value &&
std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
gcd(T t, U u) {
while (true) {
if (t == 0) { return u; }
u %= t;
if (u == 0) { return t; }
t %= u;
}
}
// Least common multiple of two integers
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<T>::value &&
std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
lcm(T const& t, U const& u) {
return (t / gcd(t,u)) * u;
}
//
// C++20 <bit> operations
//
// Checks if a number is an integral power of two
template <class T>
CUTE_HOST_DEVICE constexpr
bool
has_single_bit(T x) {
return x != 0 && (x & (x - 1)) == 0;
}
// Smallest number of bits needed to represent the given value
// bit_width( 0b0000 ) = 0
// bit_width( 0b0001 ) = 1
// bit_width( 0b0010 ) = 2
// bit_width( 0b0011 ) = 2
// bit_width( 0b0100 ) = 3
// bit_width( 0b0101 ) = 3
// bit_width( 0b0110 ) = 3
// bit_width( 0b0111 ) = 3
template <class T>
CUTE_HOST_DEVICE constexpr
T
bit_width(T x) {
static_assert(std::is_unsigned<T>::value, "Only to be used for unsigned types.");
constexpr int N = (std::numeric_limits<T>::digits == 64 ? 6 :
(std::numeric_limits<T>::digits == 32 ? 5 :
(std::numeric_limits<T>::digits == 16 ? 4 :
(std::numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
T r = 0;
for (int i = N - 1; i >= 0; --i) {
T shift = (x > ((T(1) << (T(1) << i))-1)) << i;
x >>= shift;
r |= shift;
}
return r + (x != 0);
}
// Smallest integral power of two not less than the given value
// bit_ceil( 0b00000000 ) = 0b00000001
// bit_ceil( 0b00000001 ) = 0b00000001
// bit_ceil( 0b00000010 ) = 0b00000010
// bit_ceil( 0b00000011 ) = 0b00000100
// bit_ceil( 0b00000100 ) = 0b00000100
// bit_ceil( 0b00000101 ) = 0b00001000
// bit_ceil( 0b00000110 ) = 0b00001000
// bit_ceil( 0b00000111 ) = 0b00001000
// bit_ceil( 0b00001000 ) = 0b00001000
// bit_ceil( 0b00001001 ) = 0b00010000
template <class T>
CUTE_HOST_DEVICE constexpr
T
bit_ceil(T x) {
return x == 0 ? T(1) : (T(1) << bit_width(x - 1));
}
// Largest integral power of two not greater than the given value
// bit_floor( 0b00000000 ) = 0b00000000
// bit_floor( 0b00000001 ) = 0b00000001
// bit_floor( 0b00000010 ) = 0b00000010
// bit_floor( 0b00000011 ) = 0b00000010
// bit_floor( 0b00000100 ) = 0b00000100
// bit_floor( 0b00000101 ) = 0b00000100
// bit_floor( 0b00000110 ) = 0b00000100
// bit_floor( 0b00000111 ) = 0b00000100
// bit_floor( 0b00001000 ) = 0b00001000
// bit_floor( 0b00001001 ) = 0b00001000
template <class T>
CUTE_HOST_DEVICE constexpr
T
bit_floor(T x) {
return x == 0 ? 0 : (T(1) << (bit_width(x) - 1));
}
template <class T>
CUTE_HOST_DEVICE constexpr T rotl(T x, int s);
template <class T>
CUTE_HOST_DEVICE constexpr T rotr(T x, int s);
// Computes the result of circular bitwise left-rotation
template <class T>
CUTE_HOST_DEVICE constexpr
T
rotl(T x, int s) {
constexpr int N = std::numeric_limits<T>::digits;
return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s);
}
// Computes the result of circular bitwise right-rotation
template <class T>
CUTE_HOST_DEVICE constexpr
T
rotr(T x, int s) {
constexpr int N = std::numeric_limits<T>::digits;
return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s);
}
// Counts the number of consecutive 0 bits, starting from the most significant bit
// countl_zero( 0b00000000 ) = 8
// countl_zero( 0b11111111 ) = 0
// countl_zero( 0b00011100 ) = 3
template <class T>
CUTE_HOST_DEVICE constexpr
T
countl_zero(T x) {
return std::numeric_limits<T>::digits - bit_width(x);
}
// Counts the number of consecutive 1 bits, starting from the most significant bit
// countl_one( 0b00000000 ) = 0
// countl_one( 0b11111111 ) = 8
// countl_one( 0b11100011 ) = 3
template <class T>
CUTE_HOST_DEVICE constexpr
T
countl_one(T x) {
return countl_zero(~x);
}
// Counts the number of consecutive 0 bits, starting from the least significant bit
// countr_zero( 0b00000000 ) = 8
// countr_zero( 0b11111111 ) = 0
// countr_zero( 0b00011100 ) = 2
template <class T>
CUTE_HOST_DEVICE constexpr
T
countr_zero(T x) {
return x == 0 ? std::numeric_limits<T>::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB
}
// Counts the number of consecutive 1 bits, starting from the least significant bit
// countr_one( 0b00000000 ) = 0
// countr_one( 0b11111111 ) = 8
// countr_one( 0b11100011 ) = 2
template <class T>
CUTE_HOST_DEVICE constexpr
T
countr_one(T x) {
return countr_zero(~x);
}
// Counts the number of 1 bits in an unsigned integer
// popcount( 0b00000000 ) = 0
// popcount( 0b11111111 ) = 8
// popcount( 0b00011101 ) = 4
template <class T>
CUTE_HOST_DEVICE constexpr
int
popcount(T x) {
int c = 0;
while (x) {
++c;
x &= x - 1; // clear the least significant bit set
}
return c;
}
//
// Custom operations
//
// Computes the result of bitwise left-shift
template <class T>
CUTE_HOST_DEVICE constexpr
T
shiftl(T x, int s) {
return s >= 0 ? (x << s) : (x >> -s);
}
// Computes the result of bitwise right-shift
template <class T>
CUTE_HOST_DEVICE constexpr
T
shiftr(T x, int s) {
return s >= 0 ? (x >> s) : (x << -s);
}
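// Example (illustrative): shiftl(0b0011, 2) == 0b1100 and shiftl(0b1100, -2) == 0b0011,
// i.e. a negative shift count shifts in the opposite direction.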
// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero.
template <class T,
__CUTE_REQUIRES(std::is_unsigned<T>::value)>
CUTE_HOST_DEVICE constexpr
int
signum(T const& x) {
return T(0) < x;
}
template <class T,
__CUTE_REQUIRES(not std::is_unsigned<T>::value)>
CUTE_HOST_DEVICE constexpr
int
signum(T const& x) {
return (T(0) < x) - (x < T(0));
}
// Safe divide
// @pre t % u == 0
// @result t / u
template <class T, class U,
__CUTE_REQUIRES(std::is_integral<T>::value &&
std::is_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(T const& t, U const& u) {
//assert(t % u == 0);
return t / u;
}
} // namespace cute

View File

@@ -0,0 +1,56 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
namespace cute
{
/// Generic fused multiply-add
template <class D, class A, class B, class C>
CUTE_HOST_DEVICE constexpr
void
fma(D& d, A const& a, B const& b, C const& c)
{
d = a * b + c;
}
/// Fused multiply-add for triplets
template <class A, class B, class C>
CUTE_HOST_DEVICE constexpr
void
fma(A const& a, B const& b, C& c)
{
return fma(c, a, b, c);
}
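// Example usage (illustrative sketch):
//   float d;          cute::fma(d, 2.0f, 3.0f, 1.0f);  // d == 7.0f
//   float c = 1.0f;   cute::fma(2.0f, 3.0f, c);        // c == 7.0f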
} // end namespace cute

View File

@@ -0,0 +1,51 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <vector_types.h>
#include <cutlass/numeric_types.h>
namespace cute {
using cutlass::tfloat32_t;
//
// Display utilities
//
CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v)
{
return os << float(v);
}
} // end namespace cute

View File

@@ -0,0 +1,259 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#include <cstdlib>
#include <cmath>
#include <type_traits>
#include <stdexcept>
#endif
#include <cute/config.hpp>
/// Optionally enable GCC's built-in type
#if defined(__x86_64) && !defined(__CUDA_ARCH__)
# if defined(__GNUC__) && 0
# define CUTE_UINT128_NATIVE
# elif defined(_MSC_VER)
# define CUTE_INT128_ARITHMETIC
# include <intrin.h>
# endif
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
/////////////////////////////////////////////////////////////////////////////////////////////////
///! Unsigned 128b integer type
struct alignas(16) uint128_t
{
/// Size of one part of the uint's storage in bits
static constexpr int storage_bits_ = 64;
struct hilo
{
uint64_t lo;
uint64_t hi;
};
// Use a union to store either low and high parts or, if present, a built-in 128b integer type.
union
{
struct hilo hilo_;
#if defined(CUTE_UINT128_NATIVE)
unsigned __int128 native;
#endif // defined(CUTE_UINT128_NATIVE)
};
//
// Methods
//
/// Default ctor
CUTE_HOST_DEVICE constexpr
uint128_t() : hilo_{0, 0} {}
/// Constructor from uint64
CUTE_HOST_DEVICE constexpr
uint128_t(uint64_t lo_) : hilo_{lo_, 0} {}
/// Constructor from two 64b unsigned integers
CUTE_HOST_DEVICE constexpr
uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {}
/// Optional constructor from native value
#if defined(CUTE_UINT128_NATIVE)
uint128_t(unsigned __int128 value) : native(value) { }
#endif
/// Lossily cast to uint64
CUTE_HOST_DEVICE constexpr
explicit operator uint64_t() const
{
return hilo_.lo;
}
template <class Dummy = bool>
CUTE_HOST_DEVICE constexpr
static void exception()
{
//static_assert(sizeof(Dummy) == 0, "Not implemented exception!");
//abort();
//printf("uint128 not implemented!\n");
}
/// Add
CUTE_HOST_DEVICE constexpr
uint128_t operator+(uint128_t const& rhs) const
{
uint128_t y;
#if defined(CUTE_UINT128_NATIVE)
y.native = native + rhs.native;
#else
y.hilo_.lo = hilo_.lo + rhs.hilo_.lo;
y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (y.hilo_.lo < hilo_.lo);  // carry out of the low 64 bits
#endif
return y;
}
/// Subtract
CUTE_HOST_DEVICE constexpr
uint128_t operator-(uint128_t const& rhs) const
{
uint128_t y;
#if defined(CUTE_UINT128_NATIVE)
y.native = native - rhs.native;
#else
y.hilo_.lo = hilo_.lo - rhs.hilo_.lo;
y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo);
#endif
return y;
}
/// Multiply by unsigned 64b integer yielding 128b integer
CUTE_HOST_DEVICE constexpr
uint128_t operator*(uint64_t const& rhs) const
{
uint128_t y;
#if defined(CUTE_UINT128_NATIVE)
y.native = native * rhs;
#elif defined(CUTE_INT128_ARITHMETIC)
// Multiply by the low part
y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi);
// Add the high part and ignore the overflow
uint64_t overflow;
y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow);
#else
exception();
#endif
return y;
}
/// Divide a 128b value by a 64b divisor, yielding a 64b quotient
CUTE_HOST_DEVICE constexpr
uint64_t operator/(uint64_t const& divisor) const
{
uint64_t quotient = 0;
#if defined(CUTE_UINT128_NATIVE)
quotient = uint64_t(native / divisor);
#elif defined(CUTE_INT128_ARITHMETIC)
// implemented using MSVC's arithmetic intrinsics
uint64_t remainder = 0;
quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
#else
exception();
#endif
return quotient;
}
/// Compute a 128b value modulo a 64b divisor, yielding a 64b remainder
CUTE_HOST_DEVICE constexpr
uint64_t operator%(uint64_t const& divisor) const
{
uint64_t remainder = 0;
#if defined(CUTE_UINT128_NATIVE)
remainder = uint64_t(native % divisor);
#elif defined(CUTE_INT128_ARITHMETIC)
// implemented using MSVC's arithmetic intrinsics
(void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
#else
exception();
#endif
return remainder;
}
/// Computes the quotient and remainder in a single method.
CUTE_HOST_DEVICE constexpr
uint64_t divmod(uint64_t &remainder, uint64_t divisor) const
{
uint64_t quotient = 0;
#if defined(CUTE_UINT128_NATIVE)
quotient = uint64_t(native / divisor);
remainder = uint64_t(native % divisor);
#elif defined(CUTE_INT128_ARITHMETIC)
// implemented using MSVC's arithmetic intrinsics
quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
#else
exception();
#endif
return quotient;
}
/// Left-shifts a 128b unsigned integer
CUTE_HOST_DEVICE constexpr
uint128_t operator<<(int sh) const
{
if (sh == 0) {
return *this;
}
else if (sh >= storage_bits_) {
return uint128_t(0, hilo_.lo << (sh - storage_bits_));
}
else {
return uint128_t(
(hilo_.lo << sh),
(hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh))
);
}
}
/// Right-shifts a 128b unsigned integer
CUTE_HOST_DEVICE constexpr
uint128_t operator>>(int sh) const
{
if (sh == 0) {
return *this;
}
else if (sh >= storage_bits_) {
return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0);
}
else {
return uint128_t(
(hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)),
(hilo_.hi >> sh)
);
}
}
};
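// Example usage (illustrative sketch; the exact backend depends on the macros above):
//   cute::uint128_t x(0ull, 1ull);   // represents 2^64 (lo = 0, hi = 1)
//   uint64_t q = x / 3ull;           // 64b quotient of a 128b / 64b division
//   uint64_t r = x % 3ull;           // 64b remainder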
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute
/////////////////////////////////////////////////////////////////////////////////////////////////

322
include/cute/pointer.hpp Normal file
View File

@@ -0,0 +1,322 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/math.hpp>
namespace cute
{
//
// has_dereference to determine if a type is a pointer concept
//
template <class T, class = void>
struct has_dereference : std::false_type {
};
template <class T>
struct has_dereference<T, void_t<decltype(*std::declval<T>())>> : std::true_type {
};
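// Example (illustrative):
//   has_dereference<int*>::value             == true
//   has_dereference<int>::value              == false
//   has_dereference<gmem_ptr<float>>::value  == true   (gmem_ptr is defined below)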
//
// Pointer categories
//
template <class T>
struct is_gmem : false_type {};
template <class T>
struct is_smem : false_type {};
// Anything that is not gmem or smem is rmem
template <class T>
struct is_rmem : bool_constant< not (is_gmem<T>::value || is_smem<T>::value)> {};
//
// A very simplified wrapper for pointers -- use for constructing tagged pointers
//
template <class T, class DerivedType>
struct device_ptr
{
using value_type = T;
CUTE_HOST_DEVICE constexpr
device_ptr(T* ptr) : ptr_(ptr) {}
CUTE_HOST_DEVICE constexpr
T* get() const { return ptr_; }
CUTE_HOST_DEVICE constexpr
T& operator*() const { return *ptr_; }
template <class Index>
CUTE_HOST_DEVICE constexpr
T& operator[](Index const& i) const { return ptr_[i]; }
template <class Index>
CUTE_HOST_DEVICE constexpr
DerivedType operator+(Index const& i) const { return {ptr_ + i}; }
CUTE_HOST_DEVICE constexpr friend
std::ptrdiff_t operator-(device_ptr<T,DerivedType> const& a,
device_ptr<T,DerivedType> const& b) {
return a.ptr_ - b.ptr_;
}
T* ptr_;
};
//
// gmem_ptr
//
template <class T>
struct gmem_ptr : device_ptr<T, gmem_ptr<T>> {
using device_ptr<T, gmem_ptr<T>>::device_ptr;
};
template <class T>
CUTE_HOST_DEVICE constexpr
gmem_ptr<T>
make_gmem_ptr(T* ptr) {
return {ptr};
}
template <class T>
CUTE_HOST_DEVICE constexpr
gmem_ptr<T>
make_gmem_ptr(void* ptr) {
return {reinterpret_cast<T*>(ptr)};
}
template <class T>
struct is_gmem<gmem_ptr<T>> : true_type {};
//
// smem_ptr
//
template <class T>
struct smem_ptr : device_ptr<T, smem_ptr<T>> {
using device_ptr<T, smem_ptr<T>>::device_ptr;
};
template <class T>
CUTE_HOST_DEVICE constexpr
smem_ptr<T>
make_smem_ptr(T* ptr) {
return {ptr};
}
template <class T>
CUTE_HOST_DEVICE constexpr
smem_ptr<T>
make_smem_ptr(void* ptr) {
return {reinterpret_cast<T*>(ptr)};
}
template <class T>
struct is_smem<smem_ptr<T>> : true_type {};
//
// rmem_ptr
//
template <class T>
struct rmem_ptr : device_ptr<T, rmem_ptr<T>> {
using device_ptr<T, rmem_ptr<T>>::device_ptr;
};
template <class T>
CUTE_HOST_DEVICE constexpr
rmem_ptr<T>
make_rmem_ptr(T* ptr) {
return {ptr};
}
template <class T>
CUTE_HOST_DEVICE constexpr
rmem_ptr<T>
make_rmem_ptr(void* ptr) {
return {reinterpret_cast<T*>(ptr)};
}
template <class T>
struct is_rmem<rmem_ptr<T>> : true_type {};
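// Example (illustrative sketch): tag a raw pointer with its memory space so that
// downstream code can dispatch on the tag at compile time. `p` is a hypothetical
// device allocation.
//   float* p = nullptr;
//   auto g = make_gmem_ptr(p);                        // gmem_ptr<float>
//   static_assert(is_gmem<decltype(g)>::value, "");   // the tag is visible to the type system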
//
// counting iterator -- quick and dirty
//
struct counting
{
using index_type = int;
using value_type = index_type;
CUTE_HOST_DEVICE constexpr
counting() : n_(0) {}
CUTE_HOST_DEVICE constexpr
counting(index_type const& n) : n_(n) {}
CUTE_HOST_DEVICE constexpr
index_type operator[](index_type const& i) const { return n_ + i; }
CUTE_HOST_DEVICE constexpr
index_type const& operator*() const { return n_; }
CUTE_HOST_DEVICE constexpr
counting operator+(index_type const& i) const { return {n_ + i}; }
CUTE_HOST_DEVICE constexpr
counting& operator++() { ++n_; return *this; }
CUTE_HOST_DEVICE constexpr
bool operator==(counting const& other) const { return n_ == other.n_; }
CUTE_HOST_DEVICE constexpr
bool operator!=(counting const& other) const { return n_ != other.n_; }
CUTE_HOST_DEVICE constexpr
bool operator< (counting const& other) const { return n_ < other.n_; }
index_type n_;
};
//
// recast
//
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(T* ptr) {
return reinterpret_cast<NewT*>(ptr);
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(T const* ptr) {
return reinterpret_cast<NewT const*>(ptr);
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(gmem_ptr<T> const& ptr) {
return make_gmem_ptr(recast<NewT>(ptr.ptr_));
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(gmem_ptr<T const> const& ptr) {
return make_gmem_ptr(recast<NewT const>(ptr.ptr_));
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(smem_ptr<T> const& ptr) {
return make_smem_ptr(recast<NewT>(ptr.ptr_));
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(smem_ptr<T const> const& ptr) {
return make_smem_ptr(recast<NewT const>(ptr.ptr_));
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(rmem_ptr<T> const& ptr) {
return make_rmem_ptr(recast<NewT>(ptr.ptr_));
}
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast(rmem_ptr<T const> const& ptr) {
return make_rmem_ptr(recast<NewT const>(ptr.ptr_));
}
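// Example (illustrative): recast reinterprets the pointee type while preserving the tag,
//   e.g. recast<float>(make_gmem_ptr(pd)) yields a gmem_ptr<float> for a hypothetical double* pd.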
//
// Display utilities
//
template <class T>
CUTE_HOST_DEVICE void print(T const* const ptr)
{
printf("raw_ptr_%db(%p)", int(8*sizeof(T)), ptr);
}
template <class T>
CUTE_HOST_DEVICE void print(gmem_ptr<T> const& ptr)
{
printf("gmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get());
}
template <class T>
CUTE_HOST_DEVICE void print(smem_ptr<T> const& ptr)
{
printf("smem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get());
}
template <class T>
CUTE_HOST_DEVICE void print(rmem_ptr<T> const& ptr)
{
printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get());
}
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> const& ptr)
{
return os << "gmem_ptr_" << int(8*sizeof(T)) << "b";
}
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr<T> const& ptr)
{
return os << "smem_ptr_" << int(8*sizeof(T)) << "b";
}
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> const& ptr)
{
return os << "rmem_ptr_" << int(8*sizeof(T)) << "b";
}
} // end namespace cute

411
include/cute/stride.hpp Normal file
View File

@@ -0,0 +1,411 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/int_tuple.hpp>
namespace cute
{
/** crd2idx maps a coordinate within <Shape,Stride> to an index
* This is computed as follows:
* [coord, shape, and stride are all integers => step forward by stride]
* op(c, s, d) => c * d
* [coord is integer, shape and stride are tuple => divmod coord for each mode]
* op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D))
* [coord, shape, and stride are all tuples => consider each mode independently]
* op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D))
*/
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& coord,
Shape const& shape,
Stride const& stride);
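// Example (illustrative), following the rules above:
//   crd2idx(make_coord(1,2), make_shape(3,4), make_stride(1,3)) == 1*1 + 2*3 == 7
//   crd2idx(5,               make_shape(3,4), make_stride(1,3)) == (5%3)*1 + (5/3)*3 == 5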
namespace detail {
template <class Coord, class Shape, class Stride, int... Is>
CUTE_HOST_DEVICE constexpr
auto
crd2idx_ttt(Coord const& coord,
Shape const& shape,
Stride const& stride, seq<Is...>)
{
return (... + crd2idx(get<Is>(coord), get<Is>(shape), get<Is>(stride)));
}
template <class CInt, class STuple, class DTuple, int I0, int... Is>
CUTE_HOST_DEVICE constexpr
auto
crd2idx_itt(CInt const& coord,
STuple const& shape,
DTuple const& stride, seq<I0,Is...>)
{
if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter
return crd2idx(coord, get<I0>(shape), get<I0>(stride));
} else { // General case
return crd2idx(coord % product(get<I0>(shape)), get<I0>(shape), get<I0>(stride))
+ crd2idx_itt(coord / product(get<I0>(shape)), shape, stride, seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& coord,
Shape const& shape,
Stride const& stride)
{
if constexpr (is_tuple<Coord>::value) {
if constexpr (is_tuple<Shape>::value) { // tuple tuple tuple
static_assert(tuple_size<Coord>::value == tuple_size< Shape>::value, "Mismatched Ranks");
static_assert(tuple_size<Coord>::value == tuple_size<Stride>::value, "Mismatched Ranks");
return detail::crd2idx_ttt(coord, shape, stride, tuple_seq<Coord>{});
} else { // tuple "int" "int"
static_assert(sizeof(Coord) == 0, "Invalid parameters");
}
} else {
if constexpr (is_tuple<Shape>::value) { // "int" tuple tuple
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
return detail::crd2idx_itt(coord, shape, stride, tuple_seq<Shape>{});
} else { // "int" "int" "int"
return coord * stride;
}
}
CUTE_GCC_UNREACHABLE;
}
//
// If we know Stride is default [CompactColMajor], then we can take shortcuts
//
namespace detail {
template <class CTuple, class STuple, int I0, int... Is>
CUTE_HOST_DEVICE constexpr
auto
crd2idx_horner(CTuple const& coord,
STuple const& shape, seq<I0,Is...>)
{
if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter
return get<I0>(coord);
} else { // General case
return get<I0>(coord) + get<I0>(shape) * crd2idx_horner(coord, shape, seq<Is...>{});
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& coord,
Shape const& shape)
{
static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks");
if constexpr (is_tuple<Shape>::value) {
// Flatten and apply Horner's method
auto flat_coord = flatten(coord);
auto flat_shape = flatten(shape);
return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
} else {
return coord;
}
CUTE_GCC_UNREACHABLE;
}
/** idx2crd splits an index to a coordinate within <Shape,Stride>.
*
* This is computed as follows:
* [index, shape, and stride are all integers => determine 1D coord]
* op(i, s, d) => (i / d) % s
* [index is integer, shape and stride are tuple => determine component for each mode]
* op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...)
* [index, shape, and stride are all tuples => consider each mode independently]
* op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D)))
*
* NOTE: This only works for compact shape+stride layouts. A more general version would
* apply to all surjective layouts
*/
template <class Index, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
idx2crd(Index const& idx,
Shape const& shape,
Stride const& stride)
{
if constexpr (is_tuple<Index>::value) {
if constexpr (is_tuple<Shape>::value) { // tuple tuple tuple
static_assert(tuple_size<Index>::value == tuple_size< Shape>::value, "Mismatched Ranks");
static_assert(tuple_size<Index>::value == tuple_size<Stride>::value, "Mismatched Ranks");
return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); });
} else { // tuple "int" "int"
static_assert(sizeof(Index) == 0, "Invalid parameters");
}
} else {
if constexpr (is_tuple<Shape>::value) {
if constexpr (is_tuple<Stride>::value) { // "int" tuple tuple
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); });
} else { // "int" tuple "int"
return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); });
}
} else { // "int" "int" "int"
return (idx / stride) % shape;
}
}
CUTE_GCC_UNREACHABLE;
}
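// Example (illustrative), the inverse of the crd2idx example above:
//   idx2crd(7, make_shape(3,4), make_stride(1,3)) == ((7/1)%3, (7/3)%4) == (1,2)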
//
// If we know Stride is default [CompactColMajor], then we can take shortcuts
//
//(idx / 1) % s0
//(idx / s0) % s1
//(idx / (s0 * s1)) % s2
//...
template <class Index, class Shape>
CUTE_HOST_DEVICE constexpr
auto
idx2crd(Index const& idx,
Shape const& shape)
{
if constexpr (is_tuple<Index>::value) {
if constexpr (is_tuple<Shape>::value) { // tuple tuple
static_assert(tuple_size<Index>::value == tuple_size<Shape>::value, "Mismatched Ranks");
return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); });
} else { // tuple "int"
static_assert(sizeof(Index) == 0, "Invalid parameters");
}
} else {
if constexpr (is_tuple<Shape>::value) { // "int" tuple
return idx2crd(idx, shape, compact_col_major(shape));
} else { // "int" "int"
return idx;
}
}
CUTE_GCC_UNREACHABLE;
}
//
// crd2crd
//
template <class Coord, class SShape, class DShape>
CUTE_HOST_DEVICE constexpr
auto
crd2crd(Coord const& coord,
SShape const& src_shape,
DShape const& dst_shape)
{
if constexpr (is_tuple<Coord>::value && is_tuple<SShape>::value && is_tuple<DShape>::value) {
static_assert(tuple_size<Coord>::value == tuple_size<SShape>::value, "Mismatched Ranks");
static_assert(tuple_size<Coord>::value == tuple_size<DShape>::value, "Mismatched Ranks");
return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); });
} else {
// assert(size(src_shape) == size(dst_shape))
return idx2crd(crd2idx(coord, src_shape), dst_shape);
}
CUTE_GCC_UNREACHABLE;
}
//
// Compact Major
//
// General tag for common layouts and dispatching
struct GenColMajor {};
struct GenRowMajor {};
template <class Shape, class Current = Int<1>, class Major = GenColMajor>
CUTE_HOST_DEVICE constexpr
auto
compact_major(Shape const& shape,
Current const& current = {},
Major const& major = {});
namespace detail {
template <class Shape, class Current, int... Is>
CUTE_HOST_DEVICE constexpr
auto
compact_major_ti(Shape const& shape,
Current const& current,
GenColMajor const& major, seq<Is...>)
{
return cute::make_tuple(compact_major(get<Is>(shape), current * product<0,Is>(shape), major)...);
}
template <class Shape, class Current, int... Is>
CUTE_HOST_DEVICE constexpr
auto
compact_major_ti(Shape const& shape,
Current const& current,
GenRowMajor const& major, seq<Is...>)
{
constexpr int E = tuple_size<Shape>::value;
return cute::make_tuple(compact_major(get<Is>(shape), current * product<Is+1,E>(shape), major)...);
}
} // end namespace detail
template <class Shape, class Current, class Major>
CUTE_HOST_DEVICE constexpr
auto
compact_major(Shape const& shape,
Current const& current,
Major const& major)
{
if constexpr (is_tuple<Shape>::value) {
if constexpr (is_tuple<Current>::value) { // tuple tuple
static_assert(tuple_size<Shape>::value == tuple_size<Current>::value, "Mismatched Ranks");
return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c,major); });
} else { // tuple int
return detail::compact_major_ti(shape, current, major, tuple_seq<Shape>{});
}
} else {
if constexpr (is_tuple<Current>::value) { // int tuple
static_assert(sizeof(Shape) == 0, "Invalid parameters");
} else { // int int
if constexpr (is_constant<1, Shape>::value) {
return Int<0>{}; // If current is dynamic, this could save a reg
} else {
return current;
}
}
}
CUTE_GCC_UNREACHABLE;
}
//
// Compact Col Major
//
template <class Shape, class Current = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
compact_col_major(Shape const& shape,
Current const& current = {})
{
return compact_major(shape, current, GenColMajor{});
}
template <class Shape>
using ColMajor = decltype(compact_col_major(std::declval<Shape>()));
//
// Compact Row Major
//
template <class Shape, class Current = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
compact_row_major(Shape const& shape,
Current const& current = {})
{
return compact_major(shape, current, GenRowMajor{});
}
template <class Shape>
using RowMajor = decltype(compact_row_major(std::declval<Shape>()));
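// Example (illustrative) with a dynamic shape (3,4):
//   compact_col_major(make_shape(3,4)) == (_1{}, 3)   -- strides (1,3), column-major
//   compact_row_major(make_shape(3,4)) == (4, _1{})   -- strides (4,1), row-major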
//
// Compact Order -- compute a compact stride based on an ordering of the modes
//
namespace detail {
template <class Shape, class Order, class OrigShape, class OrigOrder>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, Order const& order,
OrigShape const& orig_shape, OrigOrder const& orig_order)
{
if constexpr (is_tuple<Order>::value) {
return transform(shape, order, [&](auto const& x, auto const& y) { return compact_order(x, y, orig_shape, orig_order); });
} else {
auto d = product(transform(orig_shape, orig_order,
[&](auto const& s, auto const& o) {
return conditional_return(o < order, product(s), Int<1>{});
}));
return compact_col_major(shape, d);
}
CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
template <class Shape, class Order>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, Order const& order)
{
static_assert(is_congruent<Shape,Order>::value, "Need congruence of shape and order.");
return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order));
}
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, GenColMajor const& major)
{
return compact_major(shape, Int<1>{}, major);
}
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, GenRowMajor const& major)
{
return compact_major(shape, Int<1>{}, major);
}
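// Example (illustrative): modes with smaller `order` values receive smaller (more minor) strides.
//   compact_order(make_shape(2,3,4), make_tuple(_1{},_0{},_2{})) == (3, _1{}, 6)
//   i.e. mode 1 gets stride 1, mode 0 gets stride 3, and mode 2 gets stride 2*3 == 6.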
} // end namespace cute

497
include/cute/swizzle.hpp Normal file
View File

@@ -0,0 +1,497 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/numeric/integer_sequence.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/math.hpp>
namespace cute
{
// A generic Swizzle functor
/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
* ^--^ MBase is the number of least-sig bits to keep constant
* ^-^ ^-^ BBits is the number of bits in the mask
* ^---------^ SShift is the distance to shift the YYY mask
* (pos shifts YYY to the right, neg shifts YYY to the left)
*
* e.g. Given
* 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
* the result is
* 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
*/
template <int BBits, int MBase, int SShift = BBits>
struct Swizzle
{
static constexpr int num_bits = BBits;
static constexpr int num_base = MBase;
static constexpr int num_shft = SShift;
static_assert(num_base >= 0, "MBase must be non-negative.");
static_assert(num_bits >= 0, "BBits must be non-negative.");
static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be at least BBits.");
// Using 'int' here to avoid unintentionally casting to unsigned; the choice is not final.
using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>;
using msk_sft = cute::constant<int, num_shft>;
static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{});
template <class Offset,
__CUTE_REQUIRES(is_integral<Offset>::value)>
CUTE_HOST_DEVICE constexpr static
auto
apply(Offset const& offset)
{
return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY
}
template <class Offset,
__CUTE_REQUIRES(is_integral<Offset>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator()(Offset const& offset) const
{
return apply(offset);
}
};
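// Example (illustrative): Swizzle<1,2,1> has yyy_msk == 0b1000 and zzz_msk == 0b0100, so
//   Swizzle<1,2,1>::apply(0b1100) == 0b1100 ^ (0b1000 >> 1) == 0b1000
// i.e. bit 2 is XOR-ed with bit 3 while the MBase == 2 least-significant bits are untouched.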
// Translation for legacy SwizzleXor
// TODO: Deprecate
template <uint32_t BBits, uint32_t MBase, uint32_t SShift = 0>
using SwizzleXor = Swizzle<BBits, MBase, SShift+BBits>;
//
// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1>
// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5>
//
template <uint32_t Y, uint32_t Z>
CUTE_HOST_DEVICE constexpr
auto
make_swizzle()
{
constexpr uint32_t BZ = popcount(Y);                  // Number of bits set in the Y mask
constexpr uint32_t BY = popcount(Z);                  // Number of bits set in the Z mask
static_assert(BZ == BY, "Number of set bits in Y and Z must match");
constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y
constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z
constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32;
constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros
static_assert((Y | Z) == Swizzle<BZ,M,S>::swizzle_code, "Something went wrong.");
return Swizzle<BZ,M,S>{};
}
template <int B0, int M0, int S0,
int B1, int M1, int S1>
CUTE_HOST_DEVICE constexpr
auto
composition(Swizzle<B0,M0,S0>, Swizzle<B1,M1,S1>)
{
static_assert(S0 == S1, "Can only merge swizzles of the same shift.");
constexpr uint32_t Y = Swizzle<B0,M0,S0>::yyy_msk::value ^ Swizzle<B1,M1,S1>::yyy_msk::value;
constexpr uint32_t Z = Swizzle<B0,M0,S0>::zzz_msk::value ^ Swizzle<B1,M1,S1>::zzz_msk::value;
return make_swizzle<Y,Z>();
//return ComposedFn<Swizzle<B0,M0,S0>, Swizzle<B1,M1,S1>>{};
}
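// Example (illustrative sketch):
//   composition(Swizzle<2,0,3>{}, Swizzle<1,2,3>{})   // -> Swizzle<3,0,3>{}
//   // YYY masks: 0b011000 ^ 0b100000 = 0b111000;  ZZZ masks: 0b011 ^ 0b100 = 0b111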
//
// Upcast and Downcast
//
template <int N, int B, int M, int S>
CUTE_HOST_DEVICE constexpr
auto
upcast(Swizzle<B,M,S> const& swizzle)
{
static_assert(has_single_bit(N), "N must be a power of two");
constexpr int log2_n = bit_width(uint32_t(N)) - 1;
constexpr int NewM = M - log2_n;
if constexpr (NewM >= 0) {
return Swizzle<B,NewM,S>{};
} else {
return Swizzle<cute::max(B+NewM,0), 0, S>{};
}
CUTE_GCC_UNREACHABLE;
}
template <int N, int B, int M, int S>
CUTE_HOST_DEVICE constexpr
auto
downcast(Swizzle<B,M,S> const& swizzle)
{
static_assert(has_single_bit(N), "N must be a power of two");
constexpr int log2_n = bit_width(uint32_t(N)) - 1;
return Swizzle<B,(M + log2_n),S>{};
}
template <class OldType, class NewType,
int B, int M, int S>
CUTE_HOST_DEVICE constexpr
auto
recast(Swizzle<B,M,S> const& swizzle)
{
if constexpr (sizeof_bits<NewType>::value == sizeof_bits<OldType>::value) {
return swizzle;
} else if constexpr (sizeof_bits<NewType>::value > sizeof_bits<OldType>::value) {
static_assert(sizeof_bits<NewType>::value % sizeof_bits<OldType>::value == 0, "NewType must be a multiple of OldType");
return upcast<sizeof_bits<NewType>::value/sizeof_bits<OldType>::value>(swizzle);
} else if constexpr (sizeof_bits<NewType>::value < sizeof_bits<OldType>::value) {
static_assert(sizeof_bits<OldType>::value % sizeof_bits<NewType>::value == 0, "NewType must be a divisor of OldType");
return downcast<sizeof_bits<OldType>::value/sizeof_bits<NewType>::value>(swizzle);
}
}
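// Example (illustrative sketch):
//   recast<uint32_t, uint16_t>(Swizzle<3,4,3>{})   // downcast by 2 -> Swizzle<3,5,3>{}
//   recast<uint16_t, uint32_t>(Swizzle<3,5,3>{})   // upcast   by 2 -> Swizzle<3,4,3>{}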
//
// Utility for slicing and swizzle "offsets"
//
// For swizzle functions, it is often necessary to keep track of which bits are
// consumed and which bits are free. Furthermore, it is useful to know whether
// each of these bits is known statically or dynamically.
// MixedBits is an integer class in which some bits are known statically and some
// only dynamically. The two sets of bits are disjoint, and which bits are dynamic
// is itself known statically.
// MixedBits can only be manipulated through bitwise operations.
// Abstract value: StaticInt | (dynamic_int_ & StaticFlags)
template <uint32_t StaticInt = 0,
class DynamicType = uint32_t,
uint32_t StaticFlags = 0> // 0: static, 1: dynamic
struct MixedBits
{
// Representation invariants
static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits.");
static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits.");
// assert((dynamic_int_ & ~F) == 0);
DynamicType dynamic_int_;
};
template <class S, S s, class DynamicType, class F, F f>
CUTE_HOST_DEVICE constexpr
auto
make_mixed_bits(constant<S,s> const&, DynamicType const& d, constant<F,f> const&)
{
static_assert(is_integral<DynamicType>::value);
if constexpr (is_static<DynamicType>::value) {
static_assert((s & DynamicType::value & f) == 0, "No static/dynamic overlap allowed.");
return constant<S,s>{} | (d & constant<F,f>{}); // Just return a static int
} else if constexpr (f == 0) {
return constant<S,s>{}; // Just return a static int
} else {
return MixedBits<s, DynamicType, f>{d & f}; // MixedBits
}
CUTE_GCC_UNREACHABLE;
}
//
// Explicit conversion for now -- consider casting on plus or minus
//
template <uint32_t S, class D, uint32_t F>
CUTE_HOST_DEVICE constexpr
auto
to_integral(MixedBits<S,D,F> const& m)
{
//return S | (m.dynamic_int_ & F);
return S | m.dynamic_int_;
}
// Any cute::is_integral
template <class I, __CUTE_REQUIRES(cute::is_integral<I>::value)>
CUTE_HOST_DEVICE constexpr
auto
to_integral(I const& i)
{
return i;
}
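// Example (illustrative sketch):
//   auto m = make_mixed_bits(Int<0b0100>{}, 0b0010, Int<0b0011>{});
//   // -> MixedBits<0b0100, int, 0b0011>{0b0010}: bit 2 is static, bits 0..1 are dynamic
//   to_integral(m);   // == 0b0110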
//
// Operators
//
// Equality
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator==(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
{
return (S0 == (S1 & ~F0)) && (m.dynamic_int_ == (S1 & F0));
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator==(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
{
return m == s;
}
// Bitwise AND
template <uint32_t S0, class D0, uint32_t F0,
uint32_t S1, class D1, uint32_t F1>
CUTE_HOST_DEVICE constexpr
auto
operator&(MixedBits<S0,D0,F0> const& m0, MixedBits<S1,D1,F1> const& m1)
{
// Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F)
// S0D0F0 | 0X0 | 001 | 011 | 1X0 |
// S1D1F1
// 0X0 | 0X0 | 0X0 | 0X0 | 0X0 |
// 001 | 0X0 | 001 | 001 | 001 |
// 011 | 0X0 | 001 | 011 | 011 |
// 1X0 | 0X0 | 001 | 011 | 1X0 |
return make_mixed_bits(constant<uint32_t,S0 & S1>{},
//(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_),
((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_),
constant<uint32_t,(S1 & F0) | (S0 & F1) | (F0 & F1)>{});
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator&(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
{
return make_mixed_bits(constant<uint32_t,S0 & S1>{},
m.dynamic_int_,
constant<uint32_t,S1 & F0>{});
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator&(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
{
return m & s;
}
// Bitwise OR
template <uint32_t S0, class D0, uint32_t F0,
uint32_t S1, class D1, uint32_t F1>
CUTE_HOST_DEVICE constexpr
auto
operator|(MixedBits<S0,D0,F0> const& m0, MixedBits<S1,D1,F1> const& m1)
{
// Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F)
// S0D0F0 | 0X0 | 001 | 011 | 1X0 |
// S1D1F1
// 0X0 | 0X0 | 001 | 011 | 1X0 |
// 001 | 001 | 001 | 011 | 1X0 |
// 011 | 011 | 011 | 011 | 1X0 |
// 1X0 | 1X0 | 1X0 | 1X0 | 1X0 |
return make_mixed_bits(constant<uint32_t,S0 | S1>{},
((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_),
constant<uint32_t,(~S0 & F1) | (~S1 & F0)>{});
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator|(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
{
return make_mixed_bits(constant<uint32_t,S0 | S1>{},
m.dynamic_int_,
constant<uint32_t,~S1 & F0>{});
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator|(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
{
return m | s;
}
// Bitwise XOR
template <uint32_t S0, class D0, uint32_t F0,
uint32_t S1, class D1, uint32_t F1>
CUTE_HOST_DEVICE constexpr
auto
operator^(MixedBits<S0,D0,F0> const& m0, MixedBits<S1,D1,F1> const& m1)
{
// Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F)
// S0D0F0 | 0X0 | 001 | 011 | 1X0 |
// S1D1F1
// 0X0 | 0X0 | 001 | 011 | 1X0 |
// 001 | 001 | 001 | 011 | 011 |
// 011 | 011 | 011 | 001 | 001 |
// 1X0 | 1X0 | 011 | 001 | 0X0 |
return make_mixed_bits(constant<uint32_t,(~S0 & S1 & ~F0) | (S0 & ~S1 & ~F1)>{},
(S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_),
constant<uint32_t,F0 | F1>{});
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator^(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
{
return make_mixed_bits(constant<uint32_t,(~S0 & S1 & ~F0) | (S0 & ~S1)>{},
(S0 | m.dynamic_int_) ^ S1,
constant<uint32_t,F0>{});
}
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
operator^(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
{
return m ^ s;
}
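// Example (illustrative sketch), with m = MixedBits<0b0100, int, 0b0011>{0b0010} from above (value 0b0110):
//   to_integral(m & Int<0b0110>{});   // == 0b0110
//   to_integral(m | Int<0b0001>{});   // == 0b0111
//   to_integral(m ^ Int<0b0101>{});   // == 0b0011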
//
// upcast and downcast
//
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
CUTE_HOST_DEVICE constexpr
auto
safe_div(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const& s)
{
static_assert(has_single_bit(S1), "Only divide MixedBits by powers of two.");
return make_mixed_bits(safe_div(constant<uint32_t,S0>{}, s),
safe_div(m.dynamic_int_, s),
safe_div(constant<uint32_t,F0>{}, s));
}
template <uint32_t N, uint32_t S0, class D0, uint32_t F0>
CUTE_HOST_DEVICE constexpr
auto
upcast(MixedBits<S0,D0,F0> const& m)
{
static_assert(has_single_bit(N), "Only divide MixedBits by powers of two.");
return safe_div(m, constant<uint32_t,N>{});
}
template <uint32_t N, class T, __CUTE_REQUIRES(cute::is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
upcast(T const& m)
{
return safe_div(m, constant<uint32_t,N>{});
}
template <uint32_t N, uint32_t S0, class D0, uint32_t F0>
CUTE_HOST_DEVICE constexpr
auto
downcast(MixedBits<S0,D0,F0> const& m)
{
static_assert(has_single_bit(N), "Only scale MixedBits by powers of two.");
return make_mixed_bits(constant<uint32_t,S0 * N>{},
m.dynamic_int_ * N,
constant<uint32_t,F0 * N>{});
}
template <uint32_t N, class T, __CUTE_REQUIRES(cute::is_integral<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
downcast(T const& m)
{
return m * constant<uint32_t, N>{};
}
//
// Convert a Pow2Layout+Coord to a MixedBits
//
template <class Shape, class Stride, class Coord>
CUTE_HOST_DEVICE constexpr
auto
to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord)
{
if constexpr (is_tuple<Shape>::value && is_tuple<Stride>::value && is_tuple<Coord>::value) {
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched ranks");
static_assert(tuple_size<Shape>::value == tuple_size<Coord >::value, "Mismatched ranks");
return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); },
[](auto const&... a) { return (a ^ ...); });
} else if constexpr (is_integral<Shape>::value && is_integral<Stride>::value && is_integral<Coord>::value) {
static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride.");
return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride);
} else {
static_assert(is_integral<Shape>::value && is_integral<Stride>::value && is_integral<Coord>::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral).");
}
CUTE_GCC_UNREACHABLE;
}
template <class Layout, class Coord>
CUTE_HOST_DEVICE constexpr
auto
to_mixed_bits(Layout const& layout, Coord const& coord)
{
return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape()));
}
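// Example (illustrative sketch):
//   auto layout = make_layout(make_shape(_4{},_8{}));   // compact col-major, strides (1,4)
//   to_mixed_bits(layout, 13);                          // coord 13 -> (1,3)
//   // -> MixedBits<0, int, 0b11111>{0b01101}: all five offset bits are dynamic, value 13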
//
// Display utilities
//
template <uint32_t S, class D, uint32_t F>
CUTE_HOST_DEVICE void print(MixedBits<S,D,F> const& m)
{
printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m));
}
template <uint32_t S, class D, uint32_t F>
CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits<S,D,F> const& m)
{
return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << to_integral(m);
}
template <int B, int M, int S>
CUTE_HOST_DEVICE void print(Swizzle<B,M,S> const&)
{
print("S<%d,%d,%d>", B, M, S);
}
template <int B, int M, int S>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle<B,M,S> const&)
{
return os << "S<" << B << "," << M << "," << S << ">";
}
} // end namespace cute

File diff suppressed because it is too large

View File

@ -0,0 +1,282 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/swizzle.hpp>
#include <cute/swizzle_layout.hpp>
#include <cute/tensor.hpp>
#include <cute/pointer.hpp>
#include <cute/container/array.hpp>
#include <cute/numeric/int.hpp>
/* This implements a swizzle pointer of the form
* InvolutionFn o PtrAdd
* where the InvolutionFn need not be linear.
*
* This differs subtly from swizzle_layout because the smem pointer is used
* as the offset. That means that swizzle_layout implements position-independent
* swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors.
* The hardware architecture was designed with position-dependent swizzles.
*
* For clarity:
* NormalLayout : DeRef <- PtrAdd <- [Layout]
* ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout]
* SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout
*
* Furthermore, for known swizzles, this pointer attempts to decay itself
* to a normal-pointer with a new layout containing dynamic or static strides.
* This is possible by determining the subdomain of the InvolutionFn
* that is identity and testing if the Layout's codomain is contained
* within it.
*/
namespace cute
{
template <class T, class Swizzle>
struct smem_ptr_swizzle
{
static_assert(std::is_empty<Swizzle>::value, "Swizzle can't have state.");
CUTE_HOST_DEVICE constexpr
T* get() const
{
return ptr_;
}
CUTE_HOST_DEVICE constexpr static
Swizzle get_swizzle()
{
return {};
}
CUTE_HOST_DEVICE constexpr static
T* apply_swizzle(T* ptr)
{
return reinterpret_cast<T*>(Swizzle::apply(reinterpret_cast<std::uintptr_t>(ptr)));
}
CUTE_HOST_DEVICE constexpr
T& operator*() const
{
return *apply_swizzle(get());
}
template <class Int>
CUTE_HOST_DEVICE constexpr
T& operator[](Int const& i) const
{
return *apply_swizzle(get() + i);
}
template <class Int>
CUTE_HOST_DEVICE constexpr
smem_ptr_swizzle operator+(Int const& i) const
{
return {ptr_ + i};
}
T* ptr_;
};
template <class T, class S>
struct is_smem<smem_ptr_swizzle<T,S>> : true_type {};
// Make a swizzle pointer
template <class T, class Swizzle>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(T* ptr, Swizzle const& swizzle)
{
return smem_ptr_swizzle<T,Swizzle>{ptr};
}
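// Example (illustrative sketch; the shared-memory buffer and index names are assumed):
//   __shared__ uint32_t smem[256];
//   auto ptr = make_smem_ptr(smem, Swizzle<2,0,3>{});   // smem_ptr_swizzle<uint32_t, Swizzle<2,0,3>>
//   ptr[idx];   // dereference XOR-swizzles the byte address of (smem + idx) before the load/store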
// A model of a nullptr smem_ptr<T> with B == sizeof_bits<T>::value
// that represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr.
template <int Bits>
struct smem_ptr_flag_bits : Int<0> {};
using smem_ptr_flag = smem_ptr_flag_bits<1>;
// A flagged construction method to transform ComposedLayout
// Make a swizzle pointer tensor and check that the intended type size matches
template <class T, class Swizzle, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(smem_ptr<T> const& ptr,
ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
{
static_assert(B == sizeof_bits<T>::value, "Expected a B-bit pointer type.");
return make_tensor(make_smem_ptr(ptr.get(), layout.swizzle_fn()),
layout.layout_fn());
}
// Specialization for immediate decay
template <class T, int M, int S, class LShape, class LStride>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(smem_ptr_swizzle<T,Swizzle<0,M,S>>& p, Layout<LShape,LStride> const& layout)
{
return make_tensor(make_smem_ptr(p.ptr_), layout);
}
template <class T, int M, int S, class LShape, class LStride>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(smem_ptr_swizzle<T,Swizzle<0,M,S>> const& p, Layout<LShape,LStride> const& layout)
{
return make_tensor(make_smem_ptr(p.ptr_), layout);
}
// NOTE: To preserve smem_ptr_flag_bits under recast ops
template <int N, class Swizzle, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
{
return composition(layout.swizzle_fn(), smem_ptr_flag_bits<B*N>{}, upcast<N>(layout.layout_fn()));
}
template <int N, class Swizzle, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
downcast(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
{
return composition(layout.swizzle_fn(), smem_ptr_flag_bits<B/N>{}, downcast<N>(layout.layout_fn()));
}
//
// Recast
// Swizzle operates on the pointer address, so it doesn't care about the type
//
template <class NewT, class T, class Swizzle>
CUTE_HOST_DEVICE constexpr
auto
recast(smem_ptr_swizzle<T,Swizzle> const& ptr)
{
return smem_ptr_swizzle<NewT,Swizzle>{recast<NewT>(ptr.ptr_)};
}
template <class NewT, class T, class Swizzle>
CUTE_HOST_DEVICE constexpr
auto
recast(smem_ptr_swizzle<T const,Swizzle> const& ptr)
{
return smem_ptr_swizzle<NewT const,Swizzle>{recast<NewT const>(ptr.ptr_)};
}
//
// Conversion with swizzle_layout
//
template <class T, class Swizzle, int B, class Layout>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_layout(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
{
return composition(recast<uint_bit_t<8>,uint_bit_t<B>>(layout.swizzle_fn()), Int<0>{}, layout.layout_fn());
}
template <class T, class Swizzle, class Layout>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_tensor(Tensor<ViewEngine<smem_ptr_swizzle<T,Swizzle>>, Layout> const& tensor)
{
{
uint32_t address = cast_smem_ptr_to_uint(tensor.data().get());
uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code);
assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle
}
auto new_swizzle = recast<uint_bit_t<8>,uint_bit_t<sizeof_bits_v<T>>>(tensor.data().get_swizzle());
return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout()));
}
template <class T, class Swizzle, class Layout>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_tensor(Tensor<ViewEngine<smem_ptr_swizzle<T,Swizzle>>, Layout>& tensor)
{
{
uint32_t address = cast_smem_ptr_to_uint(tensor.data().get());
uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code);
assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle
}
auto new_swizzle = recast<uint_bit_t<8>,uint_bit_t<sizeof_bits_v<T>>>(tensor.data().get_swizzle());
return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout()));
}
template <class T, class Swizzle, class Layout>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_tensor(Tensor<ViewEngine<smem_ptr_swizzle<T,Swizzle>>, Layout>&& tensor)
{
return as_position_independent_swizzle_tensor(tensor);
}
//
// Print
//
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
template <class Swizzle, int B, class Layout>
CUTE_HOST_DEVICE
void
print_latex(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
{
auto new_swizzle = recast<uint_bit_t<8>,uint_bit_t<B>>(layout.swizzle_fn());
print_latex(composition(new_swizzle, Int<0>{}, layout.layout_fn()));
}
template <int B>
CUTE_HOST_DEVICE void print(smem_ptr_flag_bits<B> const& ptr)
{
printf("smem_ptr_%db(unset)", B);
}
template <class T, int B, int M, int S>
CUTE_HOST_DEVICE void print(smem_ptr_swizzle<T,Swizzle<B,M,S>> const& ptr)
{
printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get());
}
template <class T, int B, int M, int S>
CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle<T,Swizzle<B,M,S>> const&)
{
return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b";
}
} // end namespace cute

900
include/cute/tensor.hpp Normal file
View File

@ -0,0 +1,900 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/container/tuple.hpp>
#include <cute/container/array_aligned.hpp>
#include <cute/container/array_subbyte.hpp>
#include <cute/container/type_list.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/integer_sequence.hpp>
#include <cute/layout.hpp>
#include <cute/tile.hpp>
#include <cute/pointer.hpp>
namespace cute
{
//
// Engine -- owning or non-owning data store
//
// concept Engine {
// using value_type = ;
// iterator begin();
// };
template <class T, int N>
using ArrayEngine = typename std::conditional<(sizeof_bits<T>::value % 8 == 0),
array_aligned<T,N>,
array_subbyte<T,N>>::type;
template <class Iterator>
struct ViewEngine
{
using value_type = typename cute::remove_cvref<decltype(*std::declval<Iterator>())>::type;
using iterator = Iterator;
iterator storage_;
CUTE_HOST_DEVICE constexpr
iterator const&
begin() const {
return storage_;
}
CUTE_HOST_DEVICE constexpr
iterator&
begin() {
return storage_;
}
};
template <class Iter>
struct is_rmem<ViewEngine<Iter>> : is_rmem<Iter> {};
template <class Iter>
struct is_smem<ViewEngine<Iter>> : is_smem<Iter> {};
template <class Iter>
struct is_gmem<ViewEngine<Iter>> : is_gmem<Iter> {};
template <class Iterator>
struct ConstViewEngine
{
using value_type = typename cute::remove_cvref<decltype(*std::declval<Iterator>())>::type;
using iterator = Iterator;
iterator storage_;
CUTE_HOST_DEVICE constexpr
iterator const&
begin() const {
return storage_;
}
};
template <class Iter>
struct is_rmem<ConstViewEngine<Iter>> : is_rmem<Iter> {};
template <class Iter>
struct is_smem<ConstViewEngine<Iter>> : is_smem<Iter> {};
template <class Iter>
struct is_gmem<ConstViewEngine<Iter>> : is_gmem<Iter> {};
//
// Tensor
//
template <class Engine, class Layout>
struct Tensor
{
using value_type = typename Engine::value_type;
//using pointer = typename engine_traits<Engine>::pointer;
//using const_pointer = typename engine_traits<Engine>::const_pointer;
//using reference = typename engine_traits<Engine>::reference;
//using const_reference = typename engine_traits<Engine>::const_reference;
using engine_type = Engine;
using layout_type = Layout;
CUTE_HOST_DEVICE constexpr
Tensor() {}
template <class Ptr>
CUTE_HOST_DEVICE constexpr
Tensor(Ptr const& ptr, Layout const& layout)
: rep_(layout, ptr) {
}
//
// Accessors
//
static constexpr int rank = Layout::rank;
CUTE_HOST_DEVICE constexpr
decltype(auto)
tensor() const {
return *this;
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
layout() const {
return get<0>(rep_);
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
engine() const {
return get<1>(rep_);
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
engine() {
return get<1>(rep_);
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
data() const {
return engine().begin();
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
data() {
return engine().begin();
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
shape() const {
return layout().shape();
}
CUTE_HOST_DEVICE constexpr
auto
size() const {
return cute::size(shape());
}
CUTE_HOST_DEVICE constexpr
decltype(auto)
stride() const {
return layout().stride();
}
//
// Indexing op() and op[]
//
// Index into this tensor like an array by computing the offset via layout()
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator[](Coord const& coord) {
return data()[layout()(coord)];
}
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator[](Coord const& coord) const {
return data()[layout()(coord)];
}
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord const& coord) {
if constexpr (has_underscore<Coord>::value) {
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
return make_tensor(data() + offset, sliced_layout);
} else {
return data()[layout()(coord)];
}
CUTE_GCC_UNREACHABLE;
}
template <class Coord>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord const& coord) const {
if constexpr (has_underscore<Coord>::value) {
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
return make_tensor(data() + offset, sliced_layout);
} else {
return data()[layout()(coord)];
}
CUTE_GCC_UNREACHABLE;
}
// op() convenience function for multi-dimensional coordinates
template <class Coord0, class Coord1, class... Coords>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) {
return operator()(make_coord(c0,c1,cs...));
}
template <class Coord0, class Coord1, class... Coords>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
return operator()(make_coord(c0,c1,cs...));
}
//
// Compose
//
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
compose(Layouts const&... layouts) {
return make_tensor(data(), layout().compose(layouts...));
}
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
compose(Layouts const&... layouts) const {
return make_tensor(data(), layout().compose(layouts...));
}
//
// Tile
//
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
tile(Layouts const&... layouts) {
return make_tensor(data(), layout().tile(layouts...));
}
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
tile(Layouts const&... layouts) const {
return make_tensor(data(), layout().tile(layouts...));
}
//
// Utility
//
template <class Int,
__CUTE_REQUIRES(is_integral<Int>::value)>
CUTE_HOST_DEVICE constexpr
auto
get_1d_coord(Int const& linear_idx) const {
return layout().get_1d_coord(linear_idx);
}
template <class Int,
__CUTE_REQUIRES(is_integral<Int>::value)>
CUTE_HOST_DEVICE constexpr
auto
get_hier_coord(Int const& linear_idx) const {
return layout().get_hier_coord(linear_idx);
}
template <class Int,
__CUTE_REQUIRES(is_integral<Int>::value)>
CUTE_HOST_DEVICE constexpr
auto
get_flat_coord(Int const& linear_idx) const {
return layout().get_flat_coord(linear_idx);
}
cute::tuple<layout_type, engine_type> rep_;
};
template <class Layout>
struct is_tensor : false_type {};
template <class Engine, class Layout>
struct is_tensor<Tensor<Engine,Layout>> : true_type {};
template <class Engine, class Layout>
struct is_rmem<Tensor<Engine,Layout>> : is_rmem<Engine> {};
template <class Engine, class Layout>
struct is_smem<Tensor<Engine,Layout>> : is_smem<Engine> {};
template <class Engine, class Layout>
struct is_gmem<Tensor<Engine,Layout>> : is_gmem<Engine> {};
//
// Make an owning Tensor that will allocate a static array
//
template <class T, class Layout,
__CUTE_REQUIRES(is_layout<Layout>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Layout const& layout)
{
static_assert(is_static<Layout>::value, "Dynamic owning tensors not supported");
using Engine = ArrayEngine<T, cosize_v<Layout>>;
return Tensor<Engine,Layout>();
}
// e.g. make_tensor<double>(12)
template <class T, class LayoutArg, class... LayoutArgs,
__CUTE_REQUIRES(not is_layout<LayoutArg>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(LayoutArg const& arg, LayoutArgs const&... args)
{
return make_tensor<T>(make_layout(arg, args...));
}
//
// Make a non-owning Tensor that will use a pointer (view)
//
template <class Iterator, class Layout,
__CUTE_REQUIRES(has_dereference<Iterator>::value &&
is_layout<Layout>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Iterator const& iter, Layout const& layout)
{
using Engine = ViewEngine<Iterator>;
return Tensor<Engine,Layout>(iter, layout);
}
// e.g. make_tensor(vec.data(), 12)
template <class Iterator, class LayoutArg, class... LayoutArgs,
__CUTE_REQUIRES(not is_layout<LayoutArg>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Iterator const& iter, LayoutArg const& arg, LayoutArgs const&... args)
{
return make_tensor(iter, make_layout(arg, args...));
}
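// Example (illustrative sketch; names assumed):
//   float buf[6];
//   auto t    = make_tensor(buf, make_shape(2,3));          // non-owning 2x3 view, compact col-major
//   t(1,2)    = 7.0f;                                       // element access through the layout
//   auto col2 = t(_,2);                                     // rank-1 slice: column 2
//   auto frag = make_tensor<float>(make_shape(_2{},_3{}));  // owning, statically-sized register array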
//
// make_tensor_like -- make a register tensor the same type and shape as another
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Tensor<Engine,Layout> const& tensor)
{
using value_type = typename Tensor<Engine,Layout>::value_type;
return make_tensor<value_type>(tensor.shape());
}
//
// make_fragment_like -- make a register tensor the same type, shape, and (if possible) order as another tensor
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Tensor<Engine,Layout> const& tensor)
{
using value_type = typename Tensor<Engine,Layout>::value_type;
return make_tensor<value_type>(make_layout_like(tensor.layout()));
}
//
// make_identity_tensor
//
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_identity_tensor(Shape const& shape)
{
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(shape, Int<0>{}))),
make_identity_layout(shape));
}
//
// Utilities
//
// Return the subtensor of a mode
template <class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
tensor(Tensor&& tensor)
{
return std::forward<Tensor>(tensor);
}
template <int I, int... Is, class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
tensor(Tensor&& tensor)
{
return make_tensor(std::forward<Tensor>(tensor).data(), get<I,Is...>(tensor.layout()));
}
// Return the subtensor of a range of modes
template <int B, int E, class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
take(Tensor&& tensor)
{
return make_tensor(std::forward<Tensor>(tensor).data(), take<B,E>(tensor.layout()));
}
// Return the layout of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
layout(Tensor<Engine,Layout> const& tensor)
{
return layout<Is...>(tensor.layout());
}
// Return the shape of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
shape(Tensor<Engine,Layout> const& tensor)
{
return shape<Is...>(tensor.layout());
}
// Return the stride of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
stride(Tensor<Engine,Layout> const& tensor)
{
return stride<Is...>(tensor.layout());
}
// Return the number of elements in a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
size(Tensor<Engine,Layout> const& tensor)
{
return size<Is...>(tensor.layout());
}
// Return the rank of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
rank(Tensor<Engine,Layout> const& tensor)
{
return rank<Is...>(tensor.layout());
}
// Return the depth of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
depth(Tensor<Engine, Layout> const& tensor)
{
return depth<Is...>(tensor.layout());
}
//
// Operations to manipulate Tensors like a Layout
//
template <class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
flatten(Tensor&& tensor)
{
return make_tensor(std::forward<Tensor>(tensor).data(), flatten(tensor.layout()));
}
template <class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor&& tensor)
{
return make_tensor(std::forward<Tensor>(tensor).data(), coalesce(tensor.layout()));
}
template <class Tensor, class Profile,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor&& tensor, Profile const& profile)
{
return make_tensor(std::forward<Tensor>(tensor).data(), coalesce(tensor.layout(), profile));
}
// Group the modes [B,E) into a single mode
// e.g. group<2,4>(make_tensor<int>(Layout<Shape<_1,_2,_3,_4,_5,_6>>{}))
// => make_tensor<int>(Layout<Shape<_1,_2,Shape<_3,_4>,_5,_6>>{})
template <int B, int E, class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
group_modes(Tensor&& tensor)
{
return make_tensor(std::forward<Tensor>(tensor).data(),
group<B,E>(tensor.layout()));
}
//
// Recast
//
// NOTE: This is very dangerous to do
// -- doesn't check dynamic integer divisibility
// -- doesn't check alignment
// A tagged version for dispatching
template <class NewType, class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
recast(Tensor&& tensor, type_list<NewType>)
{
using OldType = typename remove_cvref_t<Tensor>::value_type;
auto old_layout = tensor.layout();
auto new_layout = recast<OldType,NewType>(old_layout);
// If this is an upcast of a normal Layout with static negative strides, then offset as well
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
return make_tensor(recast<NewType>(std::forward<Tensor>(tensor).data() + offset), new_layout);
} else {
return make_tensor(recast<NewType>(std::forward<Tensor>(tensor).data() ), new_layout);
}
CUTE_GCC_UNREACHABLE;
}
template <class NewType, class Tensor,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
recast(Tensor&& tensor)
{
return recast(std::forward<Tensor>(tensor), type_list<NewType>{});
}
//
// max_common_vector
//
/* Return Int<N> such that N is the maximum number of contiguous elements
 * that logically correspond in the tensors of @a a and @a b. That is,
* the number of elements that could reasonably be vectorized into a single load/store.
*
* @returns Int<N> with N >= 0
*
* A return value of Int<0> indicates that no such conclusion can be made and no
* vectorization should be attempted.
*/
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(Tensor<SrcEngine,SrcLayout> const& a,
Tensor<DstEngine,DstLayout> const& b)
{
using SrcType = typename Tensor<SrcEngine,SrcLayout>::value_type;
using DstType = typename Tensor<DstEngine,DstLayout>::value_type;
using SrcRef = decltype(*(a.data()));
using DstRef = decltype(*(b.data()));
// Determine whether the tensors are vectorization candidates at all
if constexpr (// Should be the same value_types, else the copy is also performing a cast
sizeof(SrcType) == sizeof(DstType) &&
// The types should be trivially copyable so that vectorization is valid
std::is_trivially_copyable<SrcType>::value &&
std::is_trivially_copyable<DstType>::value &&
// Should be load/storing real data, rather than implicit iterators or such
std::is_reference<SrcRef>::value &&
std::is_reference<DstRef>::value)
{
return max_common_vector(a.layout(), b.layout());
} else {
return Int<0>{};
}
CUTE_GCC_UNREACHABLE;
}
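// Example (illustrative sketch): for two tensors of the same value type with identical
// compact layouts of size 32, max_common_vector(a, b) is Int<32>{}, so the whole extent
// may be copied with vectorized accesses (subject to pointer alignment).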
//
// Key algebraic operations
//
template <class Tensor, class Tile,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Tensor && tensor,
Tile const& tile)
{
return make_tensor(std::forward<Tensor>(tensor).data(),
logical_divide(tensor.layout(), tile));
}
// zipped_divide is logical_divide with modes gathered into standard form ((BLK_A,BLK_B),(a,b))
template <class Tensor, class Tile,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(Tensor && tensor,
Tile const& tile) // Layout or Tile<Layout...>
{
return make_tensor(std::forward<Tensor>(tensor).data(),
zipped_divide(tensor.layout(), tile));
}
// tiled_divide is logical_divide with the second output mode flattened ((BLK_A,BLK_B),a,b)
template <class Tensor, class Tile,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(Tensor && tensor,
Tile const& tile) // Layout or Tile<Layout...>
{
return make_tensor(std::forward<Tensor>(tensor).data(),
tiled_divide(tensor.layout(), tile));
}
// logical_product on a Tensor doesn't make sense since it often increases cosize
//
// Logical Divide utilities: local_partition and local_tile
//
template <class Tensor, class Tile, class Coord,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
local_partition(Tensor && tensor,
Tile const& tile,
Coord const& coord)
{
constexpr int R1 = decltype(rank(tensor))::value;
// Split the modes of tensor according to the modes of tile
// zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...))
// The coord indexes into the first mode; the remaining modes are kept with underscores
return zipped_divide(std::forward<Tensor>(tensor), tile)(coord, repeat<R1>(_));
}
template <class Tensor, class Tile, class Coord, class Projection,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
local_partition(Tensor && tensor,
Tile const& tile,
Coord const& coord,
Projection const& proj)
{
return local_partition(std::forward<Tensor>(tensor),
dice(proj, tile),
dice(proj, coord));
}
// Special case with Layout and Integral that extracts the coord first
// e.g. local_partition(tensor, ThrLayout, threadIdx.x)
template <class Tensor, class LShape, class LStride, class Index,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value &&
is_integral<Index>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor && tensor,
Layout<LShape,LStride> const& tile,
Index const& index)
{
return local_partition(std::forward<Tensor>(tensor),
product_each(shape(tile)),
tile.get_flat_coord(index));
}
// Special case with Layout and Integral that extracts the coord first
// e.g. local_partition(tensor, ThrLayout, threadIdx.x, Step<_1,X,_1>{})
template <class Tensor, class LShape, class LStride, class Index, class Projection,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value &&
is_integral<Index>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor && tensor,
Layout<LShape,LStride> const& tile,
Index const& index,
Projection const& proj)
{
return local_partition(std::forward<Tensor>(tensor),
dice(proj, product_each(shape(tile))),
dice(proj, tile).get_flat_coord(index));
}
template <class Tensor, class Tile, class Coord,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
local_tile(Tensor && tensor,
Tile const& tile,
Coord const& coord)
{
constexpr int R0 = decltype(rank(tile))::value;
constexpr int R1 = decltype(rank(tensor))::value;
// Split the modes of tensor according to the modes of tile
// zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...))
// The padded coord indexes into the second mode; the remaining modes are kept with underscores
return zipped_divide(std::forward<Tensor>(tensor), tile)(repeat<R0>(_), append<R1>(coord,_));
}
template <class Tensor, class Tile, class Coord, class Proj,
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_tile(Tensor && tensor,
Tile const& tile,
Coord const& coord,
Proj const& proj)
{
return local_tile(std::forward<Tensor>(tensor),
dice(proj, tile),
dice(proj, coord));
}
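// Example (illustrative sketch; names and tile sizes are assumed):
//   auto mA   = make_tensor(ptr_A, make_shape(M, K), make_stride(Int<1>{}, M));         // (M,K) col-major
//   auto gA   = local_tile(mA, make_shape(_128{},_8{}), make_coord(blockIdx.x, _));     // (_128,_8,k) CTA tiles
//   auto tAgA = local_partition(gA, make_layout(make_shape(_32{},_8{})), threadIdx.x);  // (_4,_1,k) per thread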
//
// Display utilities
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
{
auto format = get_format(tensor(0));
using type = typename decltype(format)::type;
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
printf(format.format, format.digits, type(tensor(m)));
printf("\n");
}
} else
if constexpr (Layout::rank == 2)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
printf(format.format, format.digits, type(tensor(m,n)));
}
printf("\n");
}
} else
if constexpr (Layout::rank == 3)
{
print_tensor(tensor(_,_,0));
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n");
print_tensor(tensor(_,_,k));
}
} else
if constexpr (Layout::rank == 4)
{
print_tensor(tensor(_,_,_,0));
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n");
print_tensor(tensor(_,_,_,p));
}
}
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print(Tensor<Engine,Layout> const& tensor)
{
print(tensor.layout()); print("\n");
print_tensor(tensor);
}
template <class Engine, class Layout>
CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
int digits = 9;
if constexpr (Layout::rank == 1)
{
for (int m = 0; m < size(tensor); ++m) {
os << std::setw(digits) << tensor(m) << std::endl;
}
} else
if constexpr (Layout::rank == 2)
{
for (int m = 0; m < size<0>(tensor); ++m) {
for (int n = 0; n < size<1>(tensor); ++n) {
os << std::setw(digits) << tensor(m,n);
}
os << std::endl;
}
} else
if constexpr (Layout::rank == 3)
{
print_tensor_os(os, tensor(_,_,0));
for (int k = 1; k < size<2>(tensor); ++k) {
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
print_tensor_os(os, tensor(_,_,k));
}
} else
if constexpr (Layout::rank == 4)
{
print_tensor_os(os, tensor(_,_,_,0));
for (int p = 1; p < size<3>(tensor); ++p) {
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
print_tensor_os(os, tensor(_,_,_,p));
}
}
return os;
}
template <class Engine, class Layout>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
os << tensor.layout() << std::endl;
return print_tensor_os(os, tensor);
}
} // end namespace cute
//
// Extended Engines
//
#include <cute/swizzle_ptr.hpp>
//
// Tensor Algorithms
//
#include <cute/algorithm/tensor_algorithms.hpp>
#include <cute/algorithm/fill.hpp>
#include <cute/algorithm/clear.hpp>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/axpby.hpp>
#include <cute/algorithm/gemm.hpp>

View File

@ -0,0 +1,63 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
template <class T>
struct ConstantTensor
{
template <class... Coords>
CUTE_HOST_DEVICE constexpr
T const&
operator()(Coords const&...) const {
return val_;
}
T val_;
};
struct TrivialPredTensor
{
template <class... Coords>
CUTE_HOST_DEVICE constexpr
true_type
operator()(Coords const&...) const {
return {};
}
};
} // end namespace cute

58
include/cute/tile.hpp Normal file
View File

@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/layout.hpp>
namespace cute
{
//
// A Tile is not a Layout; it is a tuple of Layouts, Tiles, or Underscores
//
template <class... Layouts>
using Tile = tuple<Layouts...>;
template <class Tile>
using is_tile = is_tuple<Tile>;
template <class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
make_tile(Layouts const&... layouts)
{
return Tile<Layouts...>(layouts...);
}
} // end namespace cute

148
include/cute/underscore.hpp Normal file
View File

@ -0,0 +1,148 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/integer_sequence.hpp>
namespace cute
{
// For slicing
struct Underscore : Int<0> {};
CUTE_INLINE_CONSTANT Underscore _;
// Treat Underscore as an integral like integral_constant
template <>
struct is_integral<Underscore> : true_type {};
template <class T>
struct is_underscore : false_type {};
template <>
struct is_underscore<Underscore> : true_type {};
// Tuple trait for detecting whether a given static element occurs in the tuple
template <class Tuple, class Elem, class Enable = void>
struct has_elem : false_type {};
template <class Elem>
struct has_elem<Elem, Elem> : true_type {};
template <class Tuple, class Elem>
struct has_elem<Tuple, Elem, std::enable_if_t<is_tuple<Tuple>::value> >
: has_elem<Tuple, Elem, tuple_seq<Tuple> > {};
template <class Tuple, class Elem, int... Is>
struct has_elem<Tuple, Elem, seq<Is...>>
: disjunction<has_elem<std::tuple_element_t<Is, Tuple>, Elem>...> {};
// Tuple trait for detecting whether all elements are a given static element
template <class Tuple, class Elem, class Enable = void>
struct all_elem : false_type {};
template <class Elem>
struct all_elem<Elem, Elem> : true_type {};
template <class Tuple, class Elem>
struct all_elem<Tuple, Elem, std::enable_if_t<is_tuple<Tuple>::value> >
: all_elem<Tuple, Elem, tuple_seq<Tuple> > {};
template <class Tuple, class Elem, int... Is>
struct all_elem<Tuple, Elem, seq<Is...>>
: conjunction<all_elem<std::tuple_element_t<Is, Tuple>, Elem>...> {};
// Tuple trait for detecting Underscore member
template <class Tuple>
using has_underscore = has_elem<Tuple, Underscore>;
template <class Tuple>
using all_underscore = all_elem<Tuple, Underscore>;
template <class Tuple>
using has_int1 = has_elem<Tuple, Int<1>>;
template <class Tuple>
using has_int0 = has_elem<Tuple, Int<0>>;
//
// Slice keeps only the elements of Tuple B that are paired with an Underscore
//
template <class A, class B>
CUTE_HOST_DEVICE constexpr
auto
slice(A const& a, B const& b)
{
if constexpr (is_tuple<A>::value) {
static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks");
return filter_tuple(a, b, [](auto const& x, auto const& y) { return slice(x,y); });
} else if constexpr (is_underscore<A>::value) {
return cute::tuple<B>{b};
} else {
return cute::tuple<>{};
}
CUTE_GCC_UNREACHABLE;
}
//
// Dice keeps only the elements of Tuple B that are paired with an Int
//
template <class A, class B>
CUTE_HOST_DEVICE constexpr
auto
dice(A const& a, B const& b)
{
if constexpr (is_tuple<A>::value) {
static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks");
return filter_tuple(a, b, [](auto const& x, auto const& y) { return dice(x,y); });
} else if constexpr (is_underscore<A>::value) {
return cute::tuple<>{};
} else {
return cute::tuple<B>{b};
}
CUTE_GCC_UNREACHABLE;
}
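// Example (illustrative sketch):
//   slice(make_coord(_, 2, _), make_coord(3, 4, 5));   // -> (3,5)  keep elements of B paired with _
//   dice (make_coord(_, 2, _), make_coord(3, 4, 5));   // -> (4)    keep elements of B paired with an int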
//
// Display utilities
//
CUTE_HOST_DEVICE void print(Underscore const&) {
printf("_");
}
CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) {
return os << "_";
}
} // end namespace cute

153
include/cute/util/debug.hpp Normal file
View File

@ -0,0 +1,153 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/**
* \file
* \brief Debugging and logging functionality
*/
#include <cuda_runtime_api.h>
#include <cute/config.hpp>
namespace cute
{
/******************************************************************************
* Debug and logging macros
******************************************************************************/
/**
* Formats and prints the given message to stdout
*/
#if !defined(CUTE_LOG)
# if !defined(__CUDA_ARCH__)
# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__)
# else
# define CUTE_LOG(format, ...) \
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
blockIdx.x, blockIdx.y, blockIdx.z, \
threadIdx.x, threadIdx.y, threadIdx.z, \
__VA_ARGS__);
# endif
#endif
/**
* Formats and prints the given message to stdout only if DEBUG is defined
*/
#if !defined(CUTE_LOG_DEBUG)
# ifdef DEBUG
# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__)
# else
# define CUTE_LOG_DEBUG(format, ...)
# endif
#endif
/**
* \brief Perror macro with exit
*/
#if !defined(CUTE_ERROR_EXIT)
# define CUTE_ERROR_EXIT(e) \
do { \
cudaError_t code = (e); \
if (code != cudaSuccess) { \
fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \
__FILE__, __LINE__, #e, \
cudaGetErrorName(code), cudaGetErrorString(code)); \
fflush(stderr); \
exit(1); \
} \
} while (0)
#endif
#if !defined(CUTE_CHECK_LAST)
# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize())
#endif
#if !defined(CUTE_CHECK_ERROR)
# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e)
#endif
// A dummy function that uses compilation failure to print a type
template <class T>
CUTE_HOST_DEVICE
void
print_type(T&&) {
static_assert(sizeof(T) < 0, "Printing type T.");
}
//
// Device-specific helpers
//
// e.g.
// if (thread0()) print(...);
// if (block0()) print(...);
// if (thread(42)) print(...);
CUTE_HOST_DEVICE
bool
thread(int tid, int bid)
{
#if defined(__CUDA_ARCH__)
return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid)
&& ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid);
#else
return true;
#endif
}
CUTE_HOST_DEVICE
bool
thread(int tid)
{
return thread(tid, 0);
}
CUTE_HOST_DEVICE
bool
thread0()
{
return thread(0,0);
}
CUTE_HOST_DEVICE
bool
block0()
{
#if defined(__CUDA_ARCH__)
return !(blockIdx.x | blockIdx.y | blockIdx.z);
#else
return true;
#endif
}
} // end namespace cute
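For illustration (not part of the header), the error-checking macros and thread helpers above are typically used as follows; the kernel and buffer names below are hypothetical:

```cpp
#include <cuda_runtime_api.h>
#include <cute/util/debug.hpp>

__global__ void debug_demo_kernel(int n) {
  // Print once, from thread 0 of block 0 only.
  if (cute::thread0()) {
    CUTE_LOG("n = %d\n", n);
  }
}

int main() {
  int* d_buf = nullptr;
  // Abort with file/line plus the CUDA error name/string if the call fails.
  CUTE_CHECK_ERROR(cudaMalloc(&d_buf, 128 * sizeof(int)));

  debug_demo_kernel<<<4, 128>>>(128);
  // Catches both launch errors and asynchronous execution errors.
  CUTE_CHECK_LAST();

  CUTE_CHECK_ERROR(cudaFree(d_buf));
  return 0;
}
```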

140
include/cute/util/print.hpp Normal file
View File

@ -0,0 +1,140 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <type_traits>
#include <cute/config.hpp>
//
// CUDA compatible print and printf
//
namespace cute
{
CUTE_HOST_DEVICE
int
num_digits(int x)
{
return (x < 10 ? 1 :
(x < 100 ? 2 :
(x < 1000 ? 3 :
(x < 10000 ? 4 :
(x < 100000 ? 5 :
(x < 1000000 ? 6 :
(x < 10000000 ? 7 :
(x < 100000000 ? 8 :
(x < 1000000000 ? 9 :
10)))))))));
}
template <class T>
struct format_and_size {
using type = T;
char const* format;
int digits;
};
CUTE_HOST_DEVICE
format_and_size<int>
get_format(bool) {
return {"%*d", 3};
}
CUTE_HOST_DEVICE
format_and_size<int32_t>
get_format(int32_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<uint32_t>
get_format(uint32_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<int64_t>
get_format(int64_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<uint64_t>
get_format(uint64_t) {
return {"%*d", 5};
}
CUTE_HOST_DEVICE
format_and_size<float>
get_format(half_t) {
return {"%*.2f", 8};
}
CUTE_HOST_DEVICE
format_and_size<float>
get_format(float) {
return {"%*.2e", 10};
}
CUTE_HOST_DEVICE
format_and_size<double>
get_format(double) {
return {"%*.3e", 11};
}
//
// print dispatcher
//
CUTE_HOST_DEVICE
void
print(char const& c) {
printf("%c", c);
}
template <class T,
__CUTE_REQUIRES(std::is_integral<T>::value)>
CUTE_HOST_DEVICE
void
print(T const& a) {
printf("%d", int(a));
}
template <class... T>
CUTE_HOST_DEVICE
void
print(char const* format, T const&... t) {
printf(format, t...);
}
} // end namespace cute
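As a usage sketch (the variable names are made up), the print dispatcher forwards to printf, and get_format supplies a printf format string together with a suggested column width for aligned output:

```cpp
#include <cstdio>
#include <cute/util/print.hpp>

int main() {
  cute::print("value: ");   // variadic overload: plain printf pass-through
  cute::print(42);          // integral overload: prints via "%d"
  cute::print('\n');        // char overload

  // "%*d" takes the field width as its first argument.
  auto fmt = cute::get_format(int32_t{});
  std::printf(fmt.format, fmt.digits, 7);   // prints 7 right-aligned in a 5-char field
  std::printf("\n");
  return 0;
}
```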

101
include/cute/util/type_traits.hpp Normal file
View File

@ -0,0 +1,101 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <type_traits>
#include <cute/config.hpp>
#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr
#define __CUTE_REQUIRES_V(...) typename std::enable_if<decltype((__VA_ARGS__))::value>::type* = nullptr
namespace cute
{
using std::conjunction;
using std::conjunction_v;
using std::disjunction;
using std::disjunction_v;
using std::negation;
using std::negation_v;
using std::void_t;
// C++20
// using std::remove_cvref;
template <class T>
struct remove_cvref {
using type = std::remove_cv_t<std::remove_reference_t<T>>;
};
// C++20
// using std::remove_cvref_t;
template <class T>
using remove_cvref_t = typename remove_cvref<T>::type;
//
// is_valid
//
namespace detail {
template <class F, class... Args, class = decltype(std::declval<F&&>()(std::declval<Args&&>()...))>
CUTE_HOST_DEVICE constexpr auto
is_valid_impl(int) { return std::true_type{}; }
template <class F, class... Args>
CUTE_HOST_DEVICE constexpr auto
is_valid_impl(...) { return std::false_type{}; }
template <class F>
struct is_valid_fn {
template <class... Args>
CUTE_HOST_DEVICE constexpr auto
operator()(Args&&...) const { return is_valid_impl<F, Args&&...>(int{}); }
};
} // end namespace detail
template <class F>
CUTE_HOST_DEVICE constexpr auto
is_valid(F&&) {
return detail::is_valid_fn<F&&>{};
}
template <class F, class... Args>
CUTE_HOST_DEVICE constexpr auto
is_valid(F&&, Args&&...) {
return detail::is_valid_impl<F&&, Args&&...>(int{});
}
} // end namespace cute
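For illustration, is_valid turns a generic lambda into a compile-time detector; has_size below is a made-up name used only for this sketch:

```cpp
#include <utility>
#include <vector>
#include <cute/util/type_traits.hpp>

// Detector for whether the expression `t.size()` is well-formed.
constexpr auto has_size =
    cute::is_valid([](auto&& t) -> decltype(t.size()) { return t.size(); });

static_assert( decltype(has_size(std::declval<std::vector<int>>()))::value,
              "std::vector<int> has a .size() member");
static_assert(!decltype(has_size(std::declval<int>()))::value,
              "int has no .size() member");

int main() { return 0; }
```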

404
include/cutlass/arch/barrier.h Normal file
View File

@ -0,0 +1,404 @@
/***************************************************************************************************
* Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Barrier Operations on SM90+
*/
#pragma once
#include <cutlass/arch/memory_sm75.h>
#include <cute/arch/cluster_sm90.hpp>
namespace cutlass {
/// @brief
namespace arch {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12)
#define CUDA_BARRIER_ENABLED 1
#else
#define CUDA_BARRIER_ENABLED 0
#endif
class NamedBarrier {
// Data Members:
// Range = [1, NUM_THREADS_PER_CTA]; must be a multiple of the warp size (i.e., 32)
uint32_t const num_threads_;
// Range : [0, 15]
uint32_t const id_;
public:
CUTLASS_DEVICE
NamedBarrier(uint32_t num_threads, uint32_t id = 0)
: num_threads_(num_threads), id_(id) {}
CUTLASS_DEVICE
void arrive_and_wait() const {
NamedBarrier::arrive_and_wait(num_threads_, id_);
}
CUTLASS_DEVICE
void arrive() const {
NamedBarrier::arrive(num_threads_, id_);
}
CUTLASS_DEVICE
void sync() const {
NamedBarrier::arrive_and_wait();
}
// Static variants
CUTLASS_DEVICE
static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) {
#if CUDA_BARRIER_ENABLED
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTLASS_DEVICE
static void arrive(uint32_t num_threads, uint32_t barrier_id) {
#if CUDA_BARRIER_ENABLED
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTLASS_DEVICE
static void sync(uint32_t num_threads, uint32_t barrier_id) {
NamedBarrier::arrive_and_wait(num_threads, barrier_id);
}
};
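As an illustration (not part of this header), a hypothetical kernel fragment that splits a 256-thread CTA into two independently synchronizing 128-thread groups:

```cpp
#include <cutlass/arch/barrier.h>

__global__ void split_sync_kernel() {
  // Threads 0..127 use barrier id 0, threads 128..255 use barrier id 1.
  int group = threadIdx.x / 128;
  cutlass::arch::NamedBarrier barrier(/*num_threads=*/128, /*id=*/group);

  // ... per-group work ...

  barrier.arrive_and_wait();   // blocks only until this group's 128 threads arrive
}
```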
////////////////////////////////////////////////////////////////////////////////////////////////////
// Hopper introduces a new cluster-wide barrier which handles cluster-wide arrive-wait (AW) behaviour.
// This is an extension of the Ampere AW barriers.
// Note : Ampere AW barriers have a larger max arrive count (2^30) than Hopper AW barriers (2^20).
struct ClusterBarrier {
using ValueType = uint64_t;
protected:
// Can never be initialized - can only be aliased to smem
ValueType barrier_;
public:
CUTLASS_DEVICE
ClusterBarrier() = delete;
CUTLASS_DEVICE
void init(uint32_t arrive_count) const {
ClusterBarrier::init(&this->barrier_, arrive_count);
}
CUTLASS_DEVICE
uint32_t test_wait(uint32_t phase, uint32_t pred=true) const {
return ClusterBarrier::test_wait(&this->barrier_, phase, pred);
}
CUTLASS_DEVICE
void wait(uint32_t phase) const {
ClusterBarrier::wait(&this->barrier_, phase);
}
// Barrier arrive on local smem
CUTLASS_DEVICE
void arrive() const {
ClusterBarrier::arrive(&this->barrier_);
}
// Remote SMEM arrive with a predicate (usually used to pick the thread doing the arrive)
CUTLASS_DEVICE
void arrive(uint32_t cta_id, uint32_t pred = true ) const {
ClusterBarrier::arrive(&this->barrier_, cta_id, pred);
}
//
// Static Versions
//
CUTLASS_DEVICE
static void init(ValueType const* smem_ptr, uint32_t arrive_count) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.init.shared.b64 [%1], %0; \n"
"}"
:
: "r"(arrive_count), "r"(smem_addr));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Static version of wait - in case we don't want to burn a register
CUTLASS_DEVICE
static void wait(ValueType const* smem_ptr, uint32_t phase) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
// Arbitrarily large timer value after which try-wait expires and re-tries.
uint32_t ticks = 0x989680;
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
:
: "r"(smem_addr), "r"(phase), "r"(ticks));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTLASS_DEVICE
static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t waitComplete;
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
".reg .pred P2; \n\t"
"setp.eq.u32 P2, %3, 1;\n\t"
"@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
: "r"(smem_addr), "r"(phase), "r"(pred));
return waitComplete;
#else
asm volatile ("brkpt;\n" ::);
#endif
return 0;
}
// Static Predicated version of the above - in case we know the address.
CUTLASS_DEVICE
static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
".reg .b32 remAddr32;\n\t"
"setp.eq.u32 p, %2, 1;\n\t"
"@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
"@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t"
"}"
:
: "r"(smem_addr), "r"(cta_id), "r"(pred));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Barrier arrive on local smem
CUTLASS_DEVICE
static void arrive(ValueType const* smem_ptr) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint64_t state = 0;
asm volatile(
"{\n\t"
"mbarrier.arrive.shared.b64 %1, [%0];\n\t"
"}"
:
: "r"(smem_addr), "l"(state));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
CUTLASS_DEVICE
static void invalidate(ValueType const* smem_ptr) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.ival.shared.b64 [%0]; \n\t"
"}"
:
: "r"(smem_addr));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
};
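A sketch (SM90 only, names illustrative) of a single CTA driving a ClusterBarrier placed in dynamic shared memory; note the fence_barrier_init() + __syncthreads() composition required before first use (fence_barrier_init() is defined further down in this header):

```cpp
#include <cutlass/arch/barrier.h>

__global__ void cluster_barrier_sketch() {
  extern __shared__ char smem[];
  // The barrier is never constructed; it only aliases shared memory.
  auto& bar = *reinterpret_cast<cutlass::arch::ClusterBarrier*>(smem);

  if (threadIdx.x == 0) {
    bar.init(/*arrive_count=*/blockDim.x);   // expect one arrive per thread
  }
  cutlass::arch::fence_barrier_init();       // make the init visible ...
  __syncthreads();                           // ... before any thread arrives

  // ... work belonging to phase 0 ...

  bar.arrive();           // every thread arrives exactly once
  bar.wait(/*phase=*/0);  // spin until phase 0 of the barrier completes
}
```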
////////////////////////////////////////////////////////////////////////////////////////////////////
// SM90 also introduces a new type of cluster barrier which supports synchronization
// based not only on arrive count, but also on transaction count (in bytes).
struct ClusterTransactionBarrier : public ClusterBarrier {
CUTLASS_DEVICE
ClusterTransactionBarrier() = delete;
// Performs an arrive operation + bytes reset
CUTLASS_DEVICE
void arrive_and_reset_bytes(uint32_t transaction_bytes) const {
ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes);
}
// Performs an arrive operation + bytes reset
CUTLASS_DEVICE
void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const {
ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes , cta_id, true);
}
CUTLASS_DEVICE
void commit(uint32_t transaction_bytes, uint32_t pred = 1) const {
uint32_t cta_rank = cute::block_rank_in_cluster();
ClusterTransactionBarrier::commit(&this->barrier_, cta_rank, transaction_bytes, pred);
}
CUTLASS_DEVICE
void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const {
ClusterTransactionBarrier::commit(&this->barrier_, dst_cta_id, transaction_bytes, pred);
}
//
// Static Versions
//
// Performs an arrive operation + bytes reset
CUTLASS_DEVICE
static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t"
"}"
:
: "r"(transaction_bytes), "r"(smem_addr));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Performs an arrive operation + bytes reset for a remote cta_id in a Cluster
CUTLASS_DEVICE
static void arrive_and_reset_bytes(
ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
".reg .b32 remAddr32;\n\t"
"setp.eq.u32 p, %2, 1;\n\t"
"@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
"@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t"
"}"
:
: "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Performs a bytes reset without doing an arrive operation
CUTLASS_DEVICE
static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"mbarrier.expect_tx.shared.b64 [%1], %0; \n\t"
"}"
:
: "r"(transaction_bytes), "r"(smem_addr));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Records completion of transaction_bytes on the (possibly remote) barrier (complete-tx operation)
CUTLASS_DEVICE
static void commit(
ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) {
#if CUDA_BARRIER_ENABLED
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
smem_addr = cute::set_block_rank(smem_addr, dst_cta_id);
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %2, 1;\n\t"
"@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;"
"}"
:
: "r"(transaction_bytes), "r"(smem_addr), "r"(pred));
#else
asm volatile ("brkpt;\n" ::);
#endif
}
};
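A schematic, non-TMA illustration of the transaction-count mechanism (values and names are made up): a single elected thread arrives with an expected byte count and later completes those bytes manually, standing in for what the TMA unit does in real kernels:

```cpp
#include <cutlass/arch/barrier.h>

__global__ void transaction_barrier_sketch() {
  extern __shared__ char smem[];
  auto& bar = *reinterpret_cast<cutlass::arch::ClusterTransactionBarrier*>(smem);

  constexpr uint32_t kBytes = 4096;   // bytes this phase is expected to transfer

  if (threadIdx.x == 0) {
    bar.init(/*arrive_count=*/1);     // only one elected thread arrives per phase
  }
  cutlass::arch::fence_barrier_init();
  __syncthreads();

  if (threadIdx.x == 0) {
    bar.arrive_and_reset_bytes(kBytes);   // arrive + expect kBytes of transactions
    // ... the async copy would go here; TMA updates the count in real kernels ...
    bar.commit(kBytes);                   // manually complete the expected bytes
  }

  // Phase 0 completes once the arrive count and the transaction bytes are both satisfied.
  bar.wait(/*phase=*/0);
}
```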
// Helps with visibility of barrier init operations across warps / cta / cluster
// Available as a separate function so as to batch inits across barriers and fence once
// Note : It must be composed with an appropriate sync instruction with the right scope
// to ensure visibility, e.g. __syncthreads() or a cluster_arrive() + cluster_wait()
CUTLASS_DEVICE
void fence_barrier_init() {
#if CUDA_BARRIER_ENABLED
asm volatile(
"{\n\t"
"fence.mbarrier_init.release.cluster; \n"
"}"
::);
#else
asm volatile ("brkpt;\n" ::);
#endif
}
// Issue a shared memory fence for async operations
CUTLASS_DEVICE
void fence_view_async_shared() {
#if CUDA_BARRIER_ENABLED
asm volatile (
"{\n\t"
"fence.proxy.async.shared::cta; \n"
"}"
::);
#else
asm volatile ("brkpt;\n" ::);
#endif
}
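A hedged example of the intended composition: after generic (regular) writes stage data in shared memory, fence_view_async_shared() orders those writes before subsequent async-proxy accesses such as a TMA store; the helper and buffer names below are hypothetical:

```cpp
#include <cutlass/arch/barrier.h>

__device__ void publish_tile_for_async(float* smem_tile, float value) {
  smem_tile[threadIdx.x] = value;             // generic-proxy write to shared memory
  cutlass::arch::fence_view_async_shared();   // order it before async-proxy reads of smem
}
```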
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace arch
} // end namespace cutlass

include/cutlass/arch/memory_sm75.h
View File

@ -36,6 +36,7 @@
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cute/arch/util.hpp"
namespace cutlass {
namespace arch {
@ -65,74 +66,13 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
#define CUDA_LDMATRIX_SUPPORTED 1
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1
#endif
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
#endif
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif
*/
#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void *);
}
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
// the previous internal intrinsics if they are available.
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
//
// This NVVM intrinsic converts an address in shared memory to a plain
// unsigned integer. This is necessary to pass to shared memory instructions
// in inline PTX.
//
// In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
//
//__device__ size_t __cvta_generic_to_shared(void* ptr);
/// CUTLASS helper to get SMEM pointer
return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
return __nvvm_get_smem_pointer(ptr);
#elif defined(__CUDA_ARCH__)
uint32_t smem_ptr;
asm(
"{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_ptr) : "l"(ptr));
return smem_ptr;
#else
CUTLASS_UNUSED(ptr);
CUTLASS_NOT_IMPLEMENTED();
return 0;
#endif
return cute::cast_smem_ptr_to_uint(ptr);
}
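For context (illustrative only), the 32-bit shared-memory address produced by cute::cast_smem_ptr_to_uint is the form that inline PTX expects, e.g.:

```cpp
#include <cstdint>
#include <cute/arch/util.hpp>

// Store one 32-bit value to shared memory through inline PTX (hypothetical helper).
__device__ void store_u32_to_smem(void* smem_dst, uint32_t value) {
  unsigned addr = cute::cast_smem_ptr_to_uint(smem_dst);
  asm volatile("st.shared.u32 [%0], %1;\n" :: "r"(addr), "r"(value) : "memory");
}
```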
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
return cutlass_get_smem_pointer(const_cast<void *>(ptr));

Some files were not shown because too many files have changed in this diff.