release 2.11 (#703)
This commit is contained in:
18  .github/labeler.yml  vendored
@@ -1,18 +0,0 @@
-# https://github.com/actions/labeler#common-examples
-
-examples:
-- examples/**
-
-source:
-- cmake/**
-- include/cutlass/**
-
-documentation:
-- docs/**
-- media/**
-
-testing:
-- test/**
-
-tooling:
-- tools/**
3  .github/workflows/labeler.yml  vendored
@@ -5,8 +5,7 @@ on:
jobs:
  triage:
    runs-on: ubuntu-latest
    permissions: read-all|write-all
    steps:
-    - uses: actions/labeler@master
+    - uses: actions/labeler@main
      with:
        repo-token: "${{ secrets.GITHUB_TOKEN }}"
27  CHANGELOG.md
@@ -1,5 +1,27 @@
# NVIDIA CUTLASS Changelog

+## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
+* Stream-K, a new general way to do split-K. It can not only improve performance but also significantly reduce the number of tile sizes that need to be profiled to find the best one.
+* [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for fixed sequence lengths, and the other uses grouped GEMM for variable sequence lengths. Both versions need only one kernel.
+* [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
+* Hopper improves [double-precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
+* [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double-precision matrix multiplication instructions.
+* [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. The B and output matrices are still dense. The block size can be arbitrary.
+* Optimized [Group Conv](/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel count per group be a multiple of the threadblock tile N.
+* [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added:
+  * [kOptimized](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - uses direct convolution to compute instead of implicit GEMM.
+    * The restrictions are: 1) the input and output channel counts and the group number must be multiples of (128 / sizeof(input element)); 2) the input filter size must match the template parameter configuration. (A sketch of this channel-count check follows this release's list.)
+  * [kFixedStrideDilation](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts stride and dilation into templates to further improve performance. In this mode, the kernel keeps some inputs persistent in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended.
+    * The restrictions are: 1) the input and output channel counts and the group number must be multiples of (128 / sizeof(input element)); 2) the input filter size, stride, and dilation must match the template parameter configuration.
+* [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. The implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
+* [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
+* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
+
+* **Deprecation announcement:** CUTLASS plans to deprecate the following:
+  * Maxwell and Pascal GPU architectures
+  * Ubuntu 16.04
+  * CUDA 10.2
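As a concrete illustration of the kOptimized channel restriction above, here is a minimal sketch (an editor's illustration, not part of this commit; the helper name is hypothetical). For fp16 inputs, 128 / sizeof(input element) = 64, so channel and group counts must be multiples of 64:

```python
# Editor's sketch (not part of this commit): checks the kOptimized
# depthwise-conv restriction described in the changelog entry above.
import numpy as np

def depthwise_channels_ok(channels: int, groups: int, element=np.float16) -> bool:
    """Input/output channels and the group count must be multiples of
    128 / sizeof(input element) -- e.g. 64 for fp16, 32 for fp32."""
    multiple = 128 // np.dtype(element).itemsize
    return all(x % multiple == 0 for x in (channels, groups))

assert depthwise_channels_ok(128, 64, np.float16)      # 64-multiples: OK
assert not depthwise_channels_ok(96, 64, np.float16)   # 96 % 64 != 0: rejected
```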

## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, and Grouped GEMM for different data types as well as different epilogue flavours.
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. The threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
@@ -16,11 +38,6 @@
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
-
-* **Deprecation announcement:** CUTLASS plans to deprecate the following:
-  * Maxwell and Pascal GPU architectures
-  * Ubuntu 16.04
-  * CUDA 10.2

## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)

* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
10  CITATION.cff
@@ -73,10 +73,10 @@ abstract: >-
keywords:
  - 'cutlass, tensor cores, cuda'
license: BSD-3-Clause
-license-url: https://github.com/NVIDIA/cutlass/blob/v2.10.0/LICENSE.txt
-version: '2.10.0'
-date-released: '2022-09-15'
+license-url: https://github.com/NVIDIA/cutlass/blob/v2.11.0/LICENSE.txt
+version: '2.11.0'
+date-released: '2022-11-19'
identifiers:
  - type: url
-    value: "https://github.com/NVIDIA/cutlass/tree/v2.10.0"
-    description: The GitHub release URL of tag 2.10.0
+    value: "https://github.com/NVIDIA/cutlass/tree/v2.11.0"
+    description: The GitHub release URL of tag 2.11.0
@@ -38,7 +38,7 @@ endif()

message(STATUS "CMake Version: ${CMAKE_VERSION}")

-project(CUTLASS VERSION 2.10.0 LANGUAGES CXX)
+project(CUTLASS VERSION 2.11.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 10.2)
@@ -87,6 +87,7 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance")

if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
  set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY})
@@ -122,6 +123,9 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
endif()
+if (NOT CUDA_VERSION VERSION_LESS 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90)
+endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")

@@ -569,6 +573,9 @@ install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest)

################################################################################

set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "cuBLAS usage for tests")
set(CUTLASS_ENABLE_CUDNN OFF CACHE BOOL "cuDNN usage for tests")

include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake)

if (CUTLASS_ENABLE_CUBLAS)
@@ -732,7 +739,7 @@ if (CUTLASS_ENABLE_TOOLS)
  add_subdirectory(tools)
  if (CUTLASS_ENABLE_PROFILER)
    add_dependencies(test_all test_profiler)
  endif()
endif()
if (CUTLASS_ENABLE_EXAMPLES)
  add_subdirectory(examples)
@@ -7,10 +7,10 @@

This is the official list of CUTLASS developers and contributors.

## DEVELOPERS
Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
+Pradeep Ramani
Cris Cecka
Vijay Thakkar
@@ -20,52 +20,50 @@ Ethan Yan
Zhaodong Chen
Jack Kosaian
Yujia Zhai
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner
Manikandan Ananth

## CUTLASS Product Manager
Matthew Nicely

## CONTRIBUTORS
Timothy Costa
Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Duane Merrill
Kevin Siu
Markus Tavenrath
John Tran
Vicki Wang
Junkai Wu
Fung Xie
Albert Xu
Jack Yang
Xiuxia Zhang
Nick Zhao

## ACKNOWLEDGEMENTS

Girish Bharambe
Luke Durant
Olivier Giroux
Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Joel McCormack
Kyrylo Perelygin
100  README.md
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

-# CUTLASS 2.10
+# CUTLASS 2.11

-_CUTLASS 2.10 - August 2022_
+_CUTLASS 2.11 - November 2022_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) and related computations at all levels
@@ -36,21 +36,21 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

-# What's New in CUTLASS 2.10
+# What's New in CUTLASS 2.11

-CUTLASS 2.10 is an update to CUTLASS adding:
-- [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, Convolution and Grouped GEMM for different data types as well as different epilogue flavors.
-- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. It can move some scheduling into the host side if applicable.
-- Optimizations for [GEMM+Softmax](examples/35_gemm_softmax).
-- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) is a general MHA that does not require equal sequence lengths in every GEMM.
-- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) can fuse the layernorm into GEMMs before and after.
-- [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can permute the GEMM output before storing.
-- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one.
-- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now.
-- Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
-- [Back-to-back GEMM](examples/13_two_tensor_op_fusion) enhancements.
-- Updates and bugfixes from the community (thanks!)
-- **Deprecation announcement:** CUTLASS plans to deprecate the following:
+CUTLASS 2.11 is an update to CUTLASS adding:
+- Stream-K, a new general way to do split-K. It can not only improve performance but also significantly reduce the number of tile sizes that need to be profiled to find the best one. (A conceptual sketch of the difference follows this list.)
+- [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths.
+- [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel.
+- Hopper improves [double-precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
+- [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double-precision matrix multiplication instructions.
+- [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm).
+- [Optimized Group Conv](/examples/42_ampere_tensorop_group_conv).
+- [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop).
+- [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs.
+- [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
+- Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
+- **Deprecation announcement:** CUTLASS plans to deprecate the following in the next major release:
  - Maxwell and Pascal GPU architectures
  - Ubuntu 16.04
  - CUDA 10.2
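The following is a deliberately simplified, hedged sketch of the split-K vs. Stream-K difference referenced above (an editor's illustration, not part of this commit; the real CUTLASS implementation is far more sophisticated):

```python
# Editor's sketch (not part of this commit): the core idea of Stream-K
# compared to classic split-K work decomposition for GEMM.

def split_k(num_tiles: int, k_iters: int, num_sms: int):
    """Split-K: every output tile's K loop is cut into a fixed number of
    pieces; the split factor is a knob that must be profiled per shape."""
    splits = 2  # tuning knob
    return [(tile, part) for tile in range(num_tiles) for part in range(splits)]

def stream_k(num_tiles: int, k_iters: int, num_sms: int):
    """Stream-K: the total MAC-loop iterations are divided evenly across
    SMs, so each SM takes a contiguous iteration range regardless of how
    the tile count relates to the SM count."""
    total = num_tiles * k_iters
    per_sm = (total + num_sms - 1) // num_sms
    return [(sm, sm * per_sm, min((sm + 1) * per_sm, total)) for sm in range(num_sms)]

# E.g. 5 tiles x 8 K-iterations on 4 SMs -> each SM covers 10 iterations,
# crossing tile boundaries instead of leaving SMs idle on the last wave.
print(stream_k(5, 8, 4))
```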
@@ -80,10 +80,11 @@ as shown in the above figure. Tensor Core operations are still implemented usin

# Compatibility

-CUTLASS requires a C++11 host compiler and
-performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
-It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5.
+CUTLASS requires a C++11 host compiler and performs best when built with the [**CUDA 11.8 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
+
+It is also compatible with CUDA 11.x.

## Operating Systems
We have tested the following environments.

|**Operating System** | **Compiler** |
@@ -93,11 +94,12 @@ We have tested the following environments.
| | Microsoft Visual Studio 2019|
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
-| Ubuntu 21.04 | GCC 11.2.0 |
+| Ubuntu 22.04 | GCC 11.2.0 |

+Additionally, CUTLASS may be built with clang.
+See [these instructions](media/docs/quickstart.md#clang) for more details.

## Hardware
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.

@@ -110,9 +112,7 @@ any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA A10 |8.6|11.1|11.1|
|NVIDIA GeForce 3090|8.6|11.1|11.1|
+|NVIDIA H100 PCIe|9.0|11.8|11.8|

-For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit)
-for best performance.

# Documentation

@@ -133,9 +133,16 @@ CUTLASS is described in the following documents and the accompanying
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilitate rapid development

+# Resources
We have also described the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).

+- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
+- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
+- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
+- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
+- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)

# Building CUTLASS

CUTLASS is a header-only template library and does not need to be built to be used by other
@@ -199,6 +206,8 @@ include/ # client applications should target this directory

  conv/        # code specialized for convolution

  epilogue/    # code specialized for the epilogue of gemm/convolution

  gemm/        # code specialized for general matrix product computations

  layout/      # layout definitions for matrices, tensors, and other mathematical objects in memory
@@ -206,6 +215,8 @@ include/ # client applications should target this directory
  platform/    # CUDA-capable Standard Library components

  reduction/   # bandwidth-limited reduction kernels that do not fit the "gemm" model

  thread/      # simt code that can be performed within a CUDA thread

  transform/   # code specialized for layout, type, and domain transformations

@@ -216,49 +227,6 @@ include/ # client applications should target this directory

[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.

-```
-examples/
-  00_basic_gemm/                  # launches a basic GEMM with single precision inputs and outputs
-  01_cutlass_utilities/           # demonstrates CUTLASS Utilities for allocating and initializing tensors
-  02_dump_reg_smem/               # debugging utilities for printing register and shared memory contents
-  03_visualize_layout/            # utility for visualizing all layout functions in CUTLASS
-  04_tile_iterator/               # example demonstrating an iterator over tiles in memory
-  05_batched_gemm/                # example demonstrating CUTLASS's batched strided GEMM operation
-  06_splitK_gemm/                 # example demonstrating CUTLASS's Split-K parallel reduction kernel
-  07_volta_tensorop_gemm/         # example demonstrating mixed precision GEMM using Volta Tensor Cores
-  08_turing_tensorop_gemm/        # example demonstrating integer GEMM using Turing Tensor Cores
-  09_turing_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores
-  10_planar_complex/              # example demonstrating planar complex GEMM kernels
-  11_planar_complex_array/        # example demonstrating planar complex kernels with batch-specific problem sizes
-  12_gemm_bias_relu/              # example demonstrating GEMM fused with bias and relu
-  13_fused_two_gemms/             # example demonstrating two GEMMs fused in one kernel
-  22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
-  31_basic_syrk                   # example demonstrating Symmetric Rank-K update
-  32_basic_trmm                   # example demonstrating Triangular Matrix-Matrix multiplication
-  33_ampere_3xtf32_tensorop_symm  # example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation
-  35_gemm_softmax                 # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
-  40_cutlass_py                   # example demonstrating CUTLASS with CUDA Python
-```

### Tools
@@ -29,7 +29,6 @@

set(TEST_COMMAND_00 RowMajor --extent=16,16)
-set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)

cutlass_example_add_executable(
  03_visualize_layout
@@ -37,6 +36,5 @@ cutlass_example_add_executable(
  register_layout.cu
  TEST_COMMAND_OPTIONS
  TEST_COMMAND_00
-  TEST_COMMAND_01
)
@@ -64,15 +64,15 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
  // All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling
  // layout implementation with different templates.
  //
-  // BMMA 88128  Interleaved-256
-  // BMMA 168256 Interleaved-256
+  // mma.sync.aligned.m8n8k128.s32.b1.b1.s32  Interleaved-256
+  // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 Interleaved-256
  {"TensorOpMultiplicand<1,256>",
   new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 256>>},
-  // BMMA 88128  TN kblock512
-  // BMMA 168256 TN kblock512
+  // mma.sync.aligned.m8n8k128.s32.b1.b1.s32  TN kblock512
+  // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock512
  {"TensorOpMultiplicand<1,512>",
   new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 512>>},
-  // BMMA 168256 TN kblock1024
+  // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock1024
  {"TensorOpMultiplicand<1,1024>",
   new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 1024>>},
  // Integer matrix multiply.int4 8832 Interleaved-64
@@ -81,7 +81,7 @@ matrix A can be seen as
---------------------------------------
batch 0      |      batch 1
, where batch size is 2, M is 6 and K is 2
-The stride (batch_stride_B) between the first element of two batches is lda * k
+The stride (batch_stride_A) between the first element of two batches is lda * k

matrix B can be seen as
-----------------------------
@@ -94,7 +94,7 @@ matrix B can be seen as
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------
, where the batch size is 2, N is 3 and K is 2
-The stride (batch_stride_C) between the first element of two batches is k
+The stride (batch_stride_B) between the first element of two batches is k

*/
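A minimal sketch (an editor's illustration, not part of this commit) of the stride arithmetic described in the comment above, assuming the column-major layout it depicts:

```python
# Editor's sketch (not part of this commit): verify the batch strides from
# the comment above for a column-major batched GEMM where batches of A are
# stacked along the column dimension and B interleaves batches per column.
m, n, k, batch_count = 6, 3, 2, 2
lda = m  # leading dimension of column-major A

# Batches of A sit one after another along columns, so the element distance
# between the first entries of consecutive batches is lda * k.
batch_stride_A = lda * k   # = 12 for this shape

# In B, columns of different batches are interleaved, so the distance
# between the first entries of consecutive batches is just k.
batch_stride_B = k         # = 2

# Flat column-major index of element (i, j) in batch b of A:
def a_index(b, i, j):
    return b * batch_stride_A + j * lda + i

assert a_index(1, 0, 0) == 12  # first element of batch 1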

@@ -426,7 +426,7 @@ Result profile(Options const &options) {

  // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
  // the instantiated CUTLASS kernel
-  typename Gemm::Arguments arguments{
+  typename Gemm::Arguments arguments(
    mode,
    options.problem_size,
    batch_count,
@@ -445,8 +445,7 @@ Result profile(Options const &options) {
    tensor_b.layout().stride(0),
    tensor_c.layout().stride(0),
    tensor_d.layout().stride(0),
-    tensor_reduction.layout().stride(0)
-  };
+    tensor_reduction.layout().stride(0));

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;
@@ -515,15 +514,14 @@ Result profile(Options const &options) {

  cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_nullptr_tensorref(nullptr, splitk_vector_layout);

-  typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments{
+  typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments(
    cutlass::MatrixCoord(1, reduce_vector_length),
    batch_count,
    size_t(reduce_vector_length),
    workspace_vector_tensorref,
    tensor_reduction_tensorref,
    tensor_nullptr_tensorref,
-    {1.0f, 0.0f}
-  };
+    {1.0f, 0.0f});

  ReduceVectorSplitK reduce_vector_splitk_op;
@@ -531,17 +531,17 @@ Result profile_convolution(Options const &options) {
    // Reduction input
    {
      reinterpret_cast<ElementAccumulator*> (workspace.get()),
-      ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+      ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
    },
    // Destination
    {
      tensor_d.device_data(),
-      ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+      ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
    },
    // Source
    {
      tensor_c.device_data(),
-      ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
+      ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
    },
    {options.alpha, options.beta}
  );
@@ -367,12 +367,6 @@ public:
    return can_implement(args.problem_size);
  }

-  static size_t get_extra_workspace_size(Arguments const &args,
-    cutlass::gemm::GemmCoord const &grid_tiled_shape) {
-
-    return 0;
-  }

#define SPLIT_K_ENABLED 1

  /// Executes one GEMM

@@ -309,12 +309,6 @@ public:
    return can_implement(args.problem_size);
  }

-  static size_t get_extra_workspace_size(Arguments const &args,
-    cutlass::gemm::GemmCoord const &grid_tiled_shape) {
-
-    return 0;
-  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {
@@ -690,7 +690,7 @@ public:
  // Initialize the GEMM object
  GemmBatched gemm;

-  result.status = gemm.initialize(arguments);
+  result.status = gemm.initialize(arguments, nullptr);

  if (result.status != cutlass::Status::kSuccess) {
    std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl;
@@ -854,7 +854,7 @@ public:
  // Initialize the GEMM object
  GemmPermute gemm_normal;

-  result.status = gemm_normal.initialize(arguments);
+  result.status = gemm_normal.initialize(arguments, nullptr);

  if (result.status != cutlass::Status::kSuccess) {
    std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl;
@@ -1,230 +1,23 @@
-# CUTLASS Python Interface Example
+# CUTLASS Python Interface Examples
This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples:
* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations
* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel

## Setting up the Python interface
Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API.

## Running examples
Each of the basic examples can be run as follows:
```shell
# Run the GEMM example
python gemm.py

# Run the Conv2d example
python conv2d.py

# Run the grouped GEMM example
python gemm_grouped.py
```

## Using Docker
You can run PyCUTLASS in an NGC PyTorch container:
```shell
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3
```
PyCUTLASS requires the Boost C++ library as an additional dependency, which can be installed with
```bash
apt-get update
apt-get -y install libboost-all-dev
```

## Install the Python Interface
The source code for the Python interface is located at `tools/library/script/pycutlass`. It requires two environment variables:
* `CUTLASS_PATH`: the root directory of CUTLASS
* `CUDA_INSTALL_PATH`: the directory where the CUDA Toolkit is installed

After setting these two environment variables, PyCUTLASS can be installed with
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```
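Before running `build.sh`, the two environment variables can be sanity-checked with a small script like the following (an editor's sketch, not part of this commit):

```python
# Editor's sketch (not part of this commit): verify the environment
# variables PyCUTLASS expects before running build.sh.
import os

for var in ("CUTLASS_PATH", "CUDA_INSTALL_PATH"):
    path = os.environ.get(var)
    if path is None:
        raise SystemExit(f"{var} is not set")
    if not os.path.isdir(path):
        raise SystemExit(f"{var}={path} is not a directory")
print("PyCUTLASS environment looks OK")
```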
***

## Troubleshooting

### Issue 1: permission denied
Building PyCUTLASS requires installing dependencies into Python, so conda could be an option if you don't have permission.

### Issue 2: rmm: module not found
PyCUTLASS manages device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pulls the [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from GitHub and builds it from source. RMM is located at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm` and requires `cmake > 3.20.1`. If the build fails, it can be fixed manually with the following steps:
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm

cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python
python setup.py build_ext --inplace
python setup.py install
```
To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues.

***
For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message.
## GEMM Examples
The GEMM examples use numpy to create input tensors and verify the results.
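As a hedged sketch (an editor's illustration, not part of this commit), the reference computation such a numpy check performs is essentially:

```python
# Editor's sketch (not part of this commit): the kind of numpy reference
# computation used to verify D = alpha * (A @ B) + beta * C.
import numpy as np

m, n, k = 512, 256, 128
alpha, beta = 1.0, 0.5
A = np.random.randn(m, k)
B = np.random.randn(k, n)
C = np.random.randn(m, n)

D_ref = alpha * (A @ B) + beta * C  # reference result to compare with the kernel output
assert np.allclose(D_ref, alpha * A.dot(B) + beta * C)
```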
### GEMM F64 Example
Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```

### GEMM F32 Example
Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
```
Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
```

### GEMM F16 Example
Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
```

### GEMM BF16 Example
Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
```

### GEMM Int8 Example
Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```

### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.

Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
The Conv2d examples use PyTorch to create input tensors and verify the results. PyTorch can be installed by following the [official website](https://pytorch.org/#:~:text=Aid%20to%20Ukraine.-,INSTALL%20PYTORCH,-Select%20your%20preferences).
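As a hedged sketch (an editor's illustration, not part of this commit), a PyTorch reference for a forward convolution can be computed with `torch.nn.functional.conv2d`; the shapes below mirror the first fprop example:

```python
# Editor's sketch (not part of this commit): the kind of PyTorch reference
# computation used to verify a forward convolution (fprop).
import torch
import torch.nn.functional as F

n, h, w, c = 1, 13, 17, 8   # input size in NHWC terms
k, r, s = 24, 3, 3          # filter size in KRSC terms
x = torch.randn(n, c, h, w)       # PyTorch uses NCHW layout
wgt = torch.randn(k, c, r, s)

y_ref = F.conv2d(x, wgt, stride=2, padding=0)  # reference fprop output
print(y_ref.shape)  # torch.Size([1, 24, 6, 8])
```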
### Conv2d F32 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
```
Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
```
### Conv2d F32 Wgrad
Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F32 Dgrad
Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```

### Conv2d F16 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```

## Epilogue
### Bias
To replace C with a bias vector, add the `-bias` flag.
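A minimal sketch of what the bias epilogue computes (an editor's illustration, not part of this commit; it assumes the common convention of broadcasting one bias value per output column):

```python
# Editor's sketch (not part of this commit): with -bias, C acts as a
# broadcast bias vector, i.e. D = activation(alpha * (A @ B) + beta * bias).
import numpy as np

m, n, k = 4, 3, 2
alpha, beta = 1.0, 0.5
A = np.random.randn(m, k)
B = np.random.randn(k, n)
bias = np.random.randn(n)   # one value per output column, broadcast over rows

D = alpha * (A @ B) + beta * bias   # numpy broadcasting applies the bias row-wise
print(D.shape)  # (4, 3)
```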
### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```
To run the customizable examples, refer to the README in the [customizable](customizable) directory.
@@ -29,290 +29,133 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
-import pycutlass
-from pycutlass import *
-from pycutlass.conv2d_operation import *
-from pycutlass.utils import reference_model
-import torch.nn.functional as F
+"""
+Basic example of using the CUTLASS Python interface to run a 2d convolution
+"""

import argparse
import torch
import numpy as np
import sys

-# parse the arguments
-parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from python")
-
-# Operation description
-# math instruction description
-parser.add_argument("-i", "--instruction_shape",
-                    default=[1, 1, 1], nargs=3, type=int,
-                    help="This option describes the size of MMA op")
-parser.add_argument("-ta", "--element_a", default="float32", type=str,
-                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
-                    help='Data type of elements in input tensor A')
-parser.add_argument("-tb", "--element_b", default="float32", type=str,
-                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
-                    help='Data type of elements in input tensor B')
-parser.add_argument("-tc", "--element_c", default="float32", type=str,
-                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
-                    help='Data type of elements in input tensor C and output tensor D')
-parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
-                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
-                    help='Data type of accumulator')
-parser.add_argument('-m', "--math", default="multiply_add",
-                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
-parser.add_argument('-op', "--opcode", default="simt", type=str,
-                    choices=["Simt", 'TensorOp'],
-                    help='This option describes whether you want to use tensor \
-                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
-# tile description
-parser.add_argument("-b", "--threadblock_shape",
-                    default=[128, 128, 8], nargs=3, type=int,
-                    help="This option describes the tile size a thread block will compute")
-parser.add_argument("-s", "--stages", default=4,
-                    type=int, help="Number of pipelines you want to use")
-parser.add_argument("-w", "--warp_count", default=[
-                    4, 2, 1], nargs=3, type=int,
-                    help="This option describes the number of warps along M, N, and K of the threadblock")
-parser.add_argument("-cc", "--compute_capability", default=80,
-                    type=int, help="This option describes CUDA SM architecture number")
-# A
-parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[
-                    "TensorNHWC", "TensorNC32HW32"],
-                    help="Memory layout of input tensor A")
-parser.add_argument('-aa', '--alignment_a', default=1,
-                    type=int, help="Memory alignment of input tensor A")
-# B
-parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
-                    "TensorNHWC", "TensorC32RSK32"],
-                    help="Memory layout of input tensor B")
-parser.add_argument('-ab', '--alignment_b', default=1,
-                    type=int, help="Memory alignment of input tensor B")
-# C
-parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[
-                    "TensorNHWC", "TensorNC32HW32"],
-                    help="Memory layout of input tensor C and output tensor D")
-parser.add_argument('-ac', '--alignment_c', default=1,
-                    type=int, help="Memory alignment of input tensor C and output tensor D")
-# epilogue
-parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
-                    choices=['float64', 'float32', 'float16', 'bfloat16'],
-                    help='Data type of computation in the epilogue')
-parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
-                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
-                    help="This option describes the epilogue part of the kernel")
-# swizzling
-parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
-                    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8",
-                    "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4",
-                    "StridedDgradHorizontalSwizzle"],
-                    help="This option describes how thread blocks are scheduled on GPU")
-# conv related
-parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'],
-                    help="The type of convolution: forward propagation (fprop), \
-                    gradient of activation (dgrad), gradient of weight (wgrad)")
-parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"],
-                    )
-parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str,
-                    choices=["analytic", "optimized", "fixed_channels", "few_channels"],
-                    help="This option describes iterator algorithm")
-
-# arguments
-parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"],
-                    help="Split K Mode. Serial is used for non-splitK or serial-splitK.\
-                    Parallel is used for parallel splitK.")
-parser.add_argument('-k', '--split_k_slices', default=1,
-                    type=int, help="Number of split-k partitions. (default 1)")
-parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)")
-parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)")
-parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)")
-parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)")
-parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
-parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
-parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
-parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
-# Activation function
-parser.add_argument("-activ", "--activation_function", default="identity",
-                    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
-parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
-                    help="additional arguments for activation")
+import cutlass
+import pycutlass
+from pycutlass import *
+import util

-parser.add_argument('--print_cuda', action="store_true",
-                    help="print the underlying CUDA kernel")
+parser = argparse.ArgumentParser(
+    description=("Launch a 2d convolution kernel from Python. "
+                 "See https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#convo-intro for notation."))
+parser.add_argument("--n", default=1, type=int, help="N dimension of the convolution")
+parser.add_argument("--c", default=64, type=int, help="C dimension of the convolution")
+parser.add_argument("--h", default=32, type=int, help="H dimension of the convolution")
+parser.add_argument("--w", default=32, type=int, help="W dimension of the convolution")
+parser.add_argument("--k", default=32, type=int, help="K dimension of the convolution")
+parser.add_argument("--r", default=3, type=int, help="R dimension of the convolution")
+parser.add_argument("--s", default=3, type=int, help="S dimension of the convolution")
+parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

-pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
+# Check that the device is of a sufficient compute capability
+cc = util.get_device_cc()
+assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70."

alignment = 1

np.random.seed(0)

-element_a = getattr(cutlass, args.element_a)
-element_b = getattr(cutlass, args.element_b)
-element_c = getattr(cutlass, args.element_c)
-element_acc = getattr(cutlass, args.element_acc)
-math_operation = getattr(MathOperation, args.math)
-opclass = getattr(cutlass.OpClass, args.opcode)
+# Allocate a pool of device memory to be used by the kernel
+pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

+# Set the compiler to use to NVCC
+pycutlass.compiler.nvcc()

+# Set up A, B, C and accumulator
+A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
+B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
+C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment)
+element_acc = cutlass.float32
+element_epilogue = cutlass.float32

math_inst = MathInstruction(
-    args.instruction_shape, element_a, element_b,
-    element_acc, opclass, math_operation
+    [16, 8, 8],         # Shape of the Tensor Core instruction
+    A.element, B.element, element_acc,
+    cutlass.OpClass.TensorOp,
+    MathOperation.multiply_add
)

tile_description = TileDescription(
-    args.threadblock_shape, args.stages, args.warp_count,
+    [128, 128, 32],     # Threadblock shape
+    2,                  # Number of stages
+    [2, 2, 1],          # Number of warps within each dimension of the threadblock shape
    math_inst
)

-layout_a = getattr(cutlass, args.layout_a)
-layout_b = getattr(cutlass, args.layout_b)
-layout_c = getattr(cutlass, args.layout_c)
-
-A = TensorDescription(
-    element_a, layout_a, args.alignment_a
-)
-
-B = TensorDescription(
-    element_b, layout_b, args.alignment_b
-)
-
-C = TensorDescription(
-    element_c, layout_c, args.alignment_c
-)
-
-element_epilogue = getattr(cutlass, args.element_epilogue)
-if (args.activation_function == "identity"
-        or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
-    #
-    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
else:
|
||||
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
|
||||
getattr(pycutlass, args.activation_function)(element_epilogue),
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
|
||||
swizzling_functor = getattr(cutlass, args.swizzling_functor)
|
||||
stride_support = getattr(StrideSupport, args.stride_support)
|
||||
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
|
||||
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
|
||||
arch=args.compute_capability, tile_description=tile_description,
|
||||
A=A, B=B, C=C, stride_support=stride_support,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
conv_kind=cutlass.conv.Operator.fprop,
|
||||
iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
arch=cc, tile_description=tile_description,
|
||||
A=A, B=B, C=C, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor
|
||||
)
|
||||
|
||||
if args.print_cuda:
|
||||
print(operation.rt_module.emit())
|
||||
|
||||
operations = [operation,]
|
||||
|
||||
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
|
||||
if (args.activation_function == "identity"):
|
||||
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
else:
|
||||
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
|
||||
getattr(pycutlass, args.activation_function)(element_epilogue),
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
reduction_operation = ReductionOperation(
|
||||
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
|
||||
C=C, element_accumulator=element_acc,
|
||||
element_compute=element_epilogue,
|
||||
epilogue_functor=epilogue_functor_reduction,
|
||||
count=C.alignment
|
||||
)
|
||||
operations.append(reduction_operation)
|
||||
operations = [operation, ]
|
||||
|
||||
# Compile the operation
|
||||
pycutlass.compiler.add_module(operations)
|
||||
|
||||
# Randomly initialize tensors
|
||||
|
||||
problem_size = cutlass.conv.Conv2dProblemSize(
|
||||
cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
|
||||
cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
|
||||
cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
|
||||
cutlass.MatrixCoord(args.stride[0], args.stride[1]),
|
||||
cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
|
||||
cutlass.Tensor4DCoord(args.n, args.h, args.c, args.w),
|
||||
cutlass.Tensor4DCoord(args.k, args.r, args.s, args.c),
|
||||
cutlass.Tensor4DCoord(0, 0, 0, 0), # Padding
|
||||
cutlass.MatrixCoord(1, 1), # Strides
|
||||
cutlass.MatrixCoord(1, 1), # Dilation
|
||||
cutlass.conv.Mode.cross_correlation,
|
||||
args.split_k_slices, 1
|
||||
1, # Split k slices
|
||||
1 # Groups
|
||||
)
|
||||
|
||||
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size)
|
||||
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size)
|
||||
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size)
|
||||
|
||||
# User-provide inputs
|
||||
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
if args.bias:
|
||||
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
|
||||
conv_kind, problem_size
|
||||
).at(3)
|
||||
else:
|
||||
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
|
||||
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
|
||||
tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5))
|
||||
tensor_D = torch.ones(size=(tensor_C_size,), dtype=torch.float32, device="cuda")
|
||||
|
||||
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
|
||||
if args.element_a != "int8":
|
||||
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
|
||||
else:
|
||||
tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2)
|
||||
|
||||
if args.element_b != "int8":
|
||||
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5))
|
||||
else:
|
||||
tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2)
|
||||
|
||||
if args.element_c != "int8":
|
||||
tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5))
|
||||
else:
|
||||
tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)
|
||||
|
||||
tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")
|
||||
alpha = 1.
|
||||
beta = 0.
|
||||
|
||||
arguments = Conv2dArguments(
|
||||
operation=operation, problem_size=problem_size, A=tensor_A,
|
||||
B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
|
||||
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
|
||||
split_k_slices=problem_size.split_k_slices
|
||||
operation=operation, problem_size=problem_size,
|
||||
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op=operation.epilogue_type(alpha, beta)
|
||||
)
|
||||
|
||||
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
|
||||
implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
|
||||
reduction_arguments = ReductionArguments(
|
||||
reduction_operation,
|
||||
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
|
||||
partitions=problem_size.split_k_slices,
|
||||
workspace=arguments.ptr_D,
|
||||
destination=tensor_D,
|
||||
source=tensor_C,
|
||||
output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
|
||||
bias = arguments.bias
|
||||
)
|
||||
|
||||
# Run the operation
|
||||
operation.run(arguments)
|
||||
arguments.sync()
|
||||
|
||||
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
|
||||
reduction_operation.run(reduction_arguments)
|
||||
reduction_arguments.sync()
|
||||
else:
|
||||
arguments.sync()
|
||||
|
||||
reference_model = Conv2dReferenceModule(A, B, C, conv_kind)
|
||||
|
||||
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
|
||||
if (args.activation_function != "identity"):
|
||||
tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))
|
||||
# Run the host reference module and compare to the CUTLASS result
|
||||
reference = Conv2dReferenceModule(A, B, C, operation.conv_kind)
|
||||
tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta)
|
||||
|
||||
try:
|
||||
assert torch.equal(tensor_D, tensor_D_ref)
|
||||
except:
|
||||
assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
|
||||
|
||||
print("Passed.")
|
||||
|
||||
192
examples/40_cutlass_py/customizable/README.md
Normal file
@ -0,0 +1,192 @@
# Customizable Python Interface Examples
This directory contains examples of using the CUTLASS Python interface with a variety of configurations for kernels.

For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message.
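
For instance (an illustrative invocation that reuses the first F64 command below), appending the flag prints the generated kernel source before the test runs:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 --print_cuda
```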

## GEMM Examples
The GEMM examples use numpy to create input tensors and verify the results.
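
As a rough sketch of what that verification amounts to (illustrative only; the shapes follow the `-p 512 256 128` problem size used in the commands below, not the exact code in `gemm.py`):
```python
import numpy as np

M, N, K = 512, 256, 128          # -p M N K
alpha, beta = 1.0, 0.5           # -alpha / -beta

A = np.random.uniform(-2, 2, size=(M, K))
B = np.random.uniform(-2, 2, size=(K, N))
C = np.random.uniform(-2, 2, size=(M, N))

D_ref = alpha * (A @ B) + beta * C   # host reference
# The device result D computed by the CUTLASS kernel is then checked with
# something like: assert np.allclose(D, D_ref)
```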
### GEMM F64 Example
Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```

### GEMM F32 Example
Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
```
Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
```

### GEMM F16 Example
Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
```

### GEMM BF16 Example
Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
```

### GEMM Int8 Example
Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```

### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.

Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
The Conv2d examples use PyTorch to create input tensors and verify the results. PyTorch can be installed by following the instructions on the [official website](https://pytorch.org/get-started/locally/).
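
As a rough sketch of that verification (illustrative only; the sizes follow the `-nhwc 1 13 17 8 -krsc 24 3 3 8 -stride 2 2` arguments of Example 1 below, viewed through PyTorch's NCHW/KCRS conventions):
```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 13, 17, device="cuda")   # input, NCHW view of the NHWC 1x13x17x8 tensor
w = torch.rand(24, 8, 3, 3, device="cuda")    # filter, KCRS view of the KRSC 24x3x3x8 tensor

# fprop reference; the examples compare the kernel output against this
# kind of result with torch.allclose (rtol=1e-2 in conv2d.py).
y_ref = F.conv2d(x, w, stride=(2, 2), padding=(0, 0), dilation=(1, 1))
```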
### Conv2d F32 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
```
Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
```
### Conv2d F32 Wgrad
Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F32 Dgrad
Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```

### Conv2d F16 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```

## Epilogue
### Bias
To replace C with a bias vector, add the `-bias` flag.
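
With `-bias`, C is treated as a vector with one entry per output column, broadcast across the rows of the output. A schematic of the resulting epilogue (a sketch, not the library code):
```python
import numpy as np

M, N = 4, 3
accum = np.ones((M, N))                  # accumulator tile from the GEMM/conv mainloop
bias = np.arange(N, dtype=np.float64)    # one value per output column
alpha, beta = 1.0, 0.5

D = alpha * accum + beta * bias          # numpy broadcasting mirrors the bias epilogue
```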
### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
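
The `-epv` option attaches an epilogue visitor tree to the kernel. As a numpy sketch of what, for example, the `RowBroadcast` visitor defined in `gemm.py` computes per output tile (schematic only, not the device code):
```python
import numpy as np

M, N = 4, 3
accum = np.ones((M, N))                  # GEMM accumulator
vector = np.arange(N, dtype=np.float64)  # row vector broadcast to every row
c = np.zeros((M, N))                     # source tensor C
alpha, beta = 1.0, 0.5

T = accum + vector                       # broadcast add
Z = np.maximum(alpha * T + beta * c, 0)  # relu, as in RowBroadcast_.__call__
# The kernel writes Z as D and also stores T as an auxiliary output.
```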
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```
320
examples/40_cutlass_py/customizable/conv2d.py
Normal file
@ -0,0 +1,320 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import cutlass  # explicit import; cutlass.* types are used throughout
import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
import sys
import torch  # explicit import for the torch.* tensor routines below
import torch.nn.functional as F

import argparse

# parse the arguments
parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from Python")

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,
                    choices=["Simt", 'TensorOp'],
                    help='This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block will compute")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[
                    "TensorNHWC", "TensorNC32HW32"],
                    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
                    "TensorNHWC", "TensorC32RSK32"],
                    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[
                    "TensorNHWC", "TensorNC32HW32"],
                    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'],
                    help='Data type of computation in the epilogue')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
                    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8",
                    "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4",
                    "StridedDgradHorizontalSwizzle"],
                    help="This option describes how thread blocks are scheduled on GPU")
# conv related
parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'],
                    help="The type of convolution: forward propagation (fprop), \
                    gradient of activation (dgrad), gradient of weight (wgrad)")
parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"],
                    help="Whether the kernel supports arbitrary strides (Strided) or unit stride only (Unity)")
parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str,
                    choices=["analytic", "optimized", "fixed_channels", "few_channels"],
                    help="This option describes the iterator algorithm")

# arguments
parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"],
                    help="Split K Mode. Serial is used for non-splitK or serial-splitK.\
                    Parallel is used for parallel splitK.")
parser.add_argument('-k', '--split_k_slices', default=1,
                    type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)")
parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)")
parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)")
parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)")
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
                    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
                    help="additional arguments for activation")


parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

np.random.seed(0)

element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    math_inst
)

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
        or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
    #
    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
        getattr(pycutlass, args.activation_function)(element_epilogue),
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)

operation = Conv2dOperation(
    conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
    arch=args.compute_capability, tile_description=tile_description,
    A=A, B=B, C=C, stride_support=stride_support,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

if args.print_cuda:
    print(operation.rt_module.emit())

operations = [operation,]

if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
    if (args.activation_function == "identity"):
        epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
    else:
        epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
            getattr(pycutlass, args.activation_function)(element_epilogue),
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        epilogue_functor=epilogue_functor_reduction,
        count=C.alignment
    )
    operations.append(reduction_operation)

pycutlass.compiler.add_module(operations)

problem_size = cutlass.conv.Conv2dProblemSize(
    cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
    cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
    cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
    cutlass.MatrixCoord(args.stride[0], args.stride[1]),
    cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
    cutlass.conv.Mode.cross_correlation,
    args.split_k_slices, 1
)


# User-provided inputs
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
    conv_kind, problem_size
)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
    conv_kind, problem_size
)
if args.bias:
    tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
        conv_kind, problem_size
    ).at(3)
else:
    tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
        conv_kind, problem_size
    )

tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
    conv_kind, problem_size
)

if args.element_a != "int8":
    tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
else:
    tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2)

if args.element_b != "int8":
    tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5))
else:
    tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2)

if args.element_c != "int8":
    tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5))
else:
    tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)

tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")

arguments = Conv2dArguments(
    operation=operation, problem_size=problem_size, A=tensor_A,
    B=tensor_B, C=tensor_C, D=tensor_D,
    output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
    split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
    split_k_slices=problem_size.split_k_slices
)

if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
    implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
    reduction_arguments = ReductionArguments(
        reduction_operation,
        problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
        partitions=problem_size.split_k_slices,
        workspace=arguments.ptr_D,
        destination=tensor_D,
        source=tensor_C,
        output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
        bias = arguments.bias
    )

operation.run(arguments)

if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
    reduction_operation.run(reduction_arguments)
    reduction_arguments.sync()
else:
    arguments.sync()

reference_model = Conv2dReferenceModule(A, B, C, conv_kind)

tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
if (args.activation_function != "identity"):
    tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))

try:
    assert torch.equal(tensor_D, tensor_D_ref)
except:
    assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
print("Passed.")

445
examples/40_cutlass_py/customizable/gemm.py
Normal file
@ -0,0 +1,445 @@
|
||||
################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
import numpy as np
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
import cutlass
|
||||
from bfloat16 import bfloat16
|
||||
import sys
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
# parse the arguments
|
||||
parser = argparse.ArgumentParser(description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'")
|
||||
|
||||
# Operation description
|
||||
# math instruction description
|
||||
parser.add_argument("-i", "--instruction_shape",
|
||||
default=[1, 1, 1], nargs=3, type=int,
|
||||
help="This option describes the size of MMA op")
|
||||
parser.add_argument("-ta", "--element_a", default="float32", type=str,
|
||||
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
|
||||
help='Data type of elements in input tensor A')
|
||||
parser.add_argument("-tb", "--element_b", default="float32", type=str,
|
||||
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
|
||||
help='Data type of elements in input tensor B')
|
||||
parser.add_argument("-tc", "--element_c", default="float32", type=str,
|
||||
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
|
||||
help='Data type of elements in input tensor C and output tensor D')
|
||||
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
|
||||
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
|
||||
help='Data type of accumulator')
|
||||
parser.add_argument('-m', "--math", default="multiply_add",
|
||||
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
|
||||
parser.add_argument('-op', "--opcode", default="simt", type=str,
|
||||
choices=["Simt", 'TensorOp'],
|
||||
help="This option describes whether you want to use tensor \
|
||||
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
|
||||
# tile description
|
||||
parser.add_argument("-b", "--threadblock_shape",
|
||||
default=[128, 128, 8], nargs=3, type=int,
|
||||
help="This option describes the tile size a thread block with compute")
|
||||
parser.add_argument("-s", "--stages", default=4,
|
||||
type=int, help="Number of pipelines you want to use")
|
||||
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
|
||||
help="This option describes the number of warps along M, N, and K of the threadblock")
|
||||
parser.add_argument("-cc", "--compute_capability", default=80,
|
||||
type=int, help="This option describes CUDA SM architecture number")
|
||||
# A
|
||||
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
help="Memory layout of input tensor A")
|
||||
parser.add_argument('-aa', '--alignment_a', default=1,
|
||||
type=int, help="Memory alignement of input tensor A")
|
||||
# B
|
||||
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
help="Memory layout of input tensor B")
|
||||
parser.add_argument('-ab', '--alignment_b', default=1,
|
||||
type=int, help="Memory alignment of input tensor B")
|
||||
# C
|
||||
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
help="Memory layout of input tensor C and output tensor D")
|
||||
parser.add_argument('-ac', '--alignment_c', default=1,
|
||||
type=int, help="Memory alignment of input tensor C and output tensor D")
|
||||
# epilogue
|
||||
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
|
||||
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
|
||||
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
|
||||
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
|
||||
help="This option describes the epilogue part of the kernel")
|
||||
parser.add_argument("-epv", "--epilogue_visitor", default=None,
|
||||
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
|
||||
# swizzling
|
||||
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
|
||||
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
|
||||
help="This option describes how thread blocks are scheduled on GPU")
|
||||
|
||||
# Argument
|
||||
parser.add_argument("-p", "--problem_size",
|
||||
default=[128, 128, 128], nargs=3, type=int,
|
||||
help="GEMM problem size M, N, K")
|
||||
parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
|
||||
help="Scaling factor of A * B")
|
||||
parser.add_argument("-beta", "--beta", default=0.0, type=float,
|
||||
help="Scaling factor of C")
|
||||
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
|
||||
choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
|
||||
help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
|
||||
GemmSplitKParallel is used for parallel splitK")
|
||||
parser.add_argument('-k', '--split_k_slices', default=1,
|
||||
type=int, help="Number of split-k partitions. (default 1)")
|
||||
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
|
||||
parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")
|
||||
|
||||
# Activation function
|
||||
parser.add_argument("-activ", "--activation_function", default="identity",
|
||||
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
|
||||
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
|
||||
help="addition arguments for activation")
|
||||
parser.add_argument('--print_cuda', action="store_true",
|
||||
help="print the underlying CUDA kernel")
|
||||
|
||||
|
||||
try:
|
||||
args = parser.parse_args()
|
||||
except:
|
||||
sys.exit(0)
|
||||
|
||||
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
pycutlass.compiler.nvcc()
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
element_a = getattr(cutlass, args.element_a)
|
||||
element_b = getattr(cutlass, args.element_b)
|
||||
element_c = getattr(cutlass, args.element_c)
|
||||
element_acc = getattr(cutlass, args.element_acc)
|
||||
math_operation = getattr(MathOperation, args.math)
|
||||
opclass = getattr(cutlass.OpClass, args.opcode)
|
||||
|
||||
math_inst = MathInstruction(
|
||||
args.instruction_shape, element_a, element_b,
|
||||
element_acc, opclass, math_operation
|
||||
)
|
||||
|
||||
tile_description = TileDescription(
|
||||
args.threadblock_shape, args.stages, args.warp_count,
|
||||
math_inst
|
||||
)
|
||||
|
||||
layout_a = getattr(cutlass, args.layout_a)
|
||||
layout_b = getattr(cutlass, args.layout_b)
|
||||
layout_c = getattr(cutlass, args.layout_c)
|
||||
|
||||
A = TensorDescription(
|
||||
element_a, layout_a, args.alignment_a
|
||||
)
|
||||
|
||||
B = TensorDescription(
|
||||
element_b, layout_b, args.alignment_b
|
||||
)
|
||||
|
||||
C = TensorDescription(
|
||||
element_c, layout_c, args.alignment_c
|
||||
)

element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
        or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
        getattr(pycutlass, args.activation_function)(element_epilogue),
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

swizzling_functor = getattr(cutlass, args.swizzling_functor)

visitor = args.epilogue_visitor is not None

if args.epilogue_visitor == "ColumnReduction":
    class ColumnReduction_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                alpha: 'scalar', beta: 'scalar'):
            D = alpha * accum + beta * c
            reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
            return D, reduction
    epilogue_functor = ColumnReduction_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
    class RowReduction_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                alpha: 'scalar', beta: 'scalar'):
            D = alpha * accum + tanh.numpy(beta * c)
            reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
            return D, reduction
    epilogue_functor = RowReduction_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()

elif args.epilogue_visitor == "RowBroadcast":
    class RowBroadcast_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                vector: 'row', alpha: 'scalar', beta: 'scalar'):
            T = accum + vector
            scale_T = alpha * T
            Z = relu.numpy(scale_T + beta * c)
            return Z, T
    epilogue_functor = RowBroadcast_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
    class ColumnBroadcast_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                vector: 'column', alpha: 'scalar', beta: 'scalar'):
            T = accum + vector
            scale_T = leaky_relu.numpy(alpha * T, 0.2)
            Z = scale_T + beta * c
            return Z, T
    epilogue_functor = ColumnBroadcast_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()
else:
    # No visitor requested: keep the plain epilogue functor
    pass
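
For reference, the `reduction_op` used by the visitors above describes a per-threadblock partial reduction, which is why the host-side reduction buffer allocated later has one entry per (row, CTA) pair. A NumPy sketch with hypothetical shapes (assuming N divides evenly by the CTA tile; the kernel handles ragged tails):

```python
import numpy as np

batch, M, N, cta_n = 1, 256, 256, 128
D = np.random.rand(batch, M, N).astype(np.float32)

# "row" reduction with extent cta_n: each threadblock-wide column strip
# contributes one partial row sum.
num_cta_n = (N + cta_n - 1) // cta_n
partial = D.reshape(batch, M, num_cta_n, cta_n).sum(axis=-1)
print(partial.shape)  # (1, 256, 2)
```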

operation = GemmOperationUniversal(
    arch=args.compute_capability, tile_description=tile_description,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
    visitor=visitor
)

if args.print_cuda:
    print(operation.rt_module.emit())

operations = [operation, ]

if args.gemm_mode == "GemmSplitKParallel":
    if args.activation_function == "identity":
        epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
    else:
        epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
            getattr(pycutlass, args.activation_function)(element_epilogue),
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        epilogue_functor=epilogue_functor_reduction,
        count=C.alignment
    )
    operations.append(reduction_operation)

pycutlass.compiler.add_module(operations)

# User-provided inputs

problem_size = cutlass.gemm.GemmCoord(
    args.problem_size[0], args.problem_size[1], args.problem_size[2])

tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
    if args.element_a == "bfloat16":
        tensor_A = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
        ).astype(bfloat16)
    else:
        tensor_A = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
        ).astype(getattr(np, args.element_a))
else:
    tensor_A = np.random.uniform(
        low=-2, high=2, size=(tensor_a_size,)
    ).astype(getattr(np, args.element_a))
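
Why `np.ceil` over `U(-8.5, 7.5)`? The result is a whole number in [-8, 8], and products and short sums of such values are exactly representable even in float16, so the device output can be compared against the NumPy reference without accumulated rounding error. A small sketch of that property:

```python
import numpy as np

sample = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(4,)))
# Whole numbers of this magnitude round-trip exactly through float16.
assert np.all(sample == sample.astype(np.float16).astype(np.float64))
```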

tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
    if args.element_b == "bfloat16":
        tensor_B = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
        ).astype(bfloat16)
    else:
        tensor_B = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
        ).astype(getattr(np, args.element_b))
else:
    tensor_B = np.random.uniform(
        low=-2, high=2, size=(tensor_b_size,)
    ).astype(getattr(np, args.element_b))

if args.element_c != "int8":
    if args.bias:
        if args.layout_c == "RowMajor":
            tensor_c_size = args.batch * problem_size.n()
        elif args.layout_c == "ColumnMajor":
            tensor_c_size = args.batch * problem_size.m()
        else:
            raise ValueError(args.layout_c)
    else:
        tensor_c_size = args.batch * problem_size.m() * problem_size.n()
    if args.element_c == "bfloat16":
        tensor_C = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
        ).astype(bfloat16)
    else:
        tensor_C = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
        ).astype(getattr(np, args.element_c))
else:
    tensor_C = np.random.uniform(
        low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
    ).astype(getattr(np, args.element_c))

tensor_D = np.zeros(
    shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))

if args.epilogue_visitor == "RowReduction":
    cta_n = args.threadblock_shape[1]
    num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
    reduction = np.zeros(
        shape=(args.batch * problem_size.m() * num_cta_n,),
        dtype=getattr(np, args.element_c))
    output_op = operation.epilogue_type(
        D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C,
        reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
    )
elif args.epilogue_visitor == "ColumnReduction":
    cta_m = args.threadblock_shape[0]
    num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
    reduction = np.zeros(
        shape=(args.batch * problem_size.n() * num_cta_m,),
        dtype=getattr(np, args.element_c))
    output_op = operation.epilogue_type(
        D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C,
        reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
    )
elif args.epilogue_visitor == "RowBroadcast":
    vector = np.ceil(
        np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
    ).astype(getattr(np, args.element_c))
    tensor_t = np.empty_like(tensor_D)
    output_op = operation.epilogue_type(
        c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta,
        Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
    )
elif args.epilogue_visitor == "ColumnBroadcast":
    vector = np.ceil(
        np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
    ).astype(getattr(np, args.element_c))
    tensor_t = np.empty_like(tensor_D)
    output_op = operation.epilogue_type(
        c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta,
        Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
    )
else:
    output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
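
The positional splat in the default branch means any extra activation parameters ride along after `alpha` and `beta`. An illustrative pairing (the 0.2 slope matches the one the ColumnBroadcast visitor above passes to `leaky_relu`):

```python
# Equivalent to running with: -activ leaky_relu -activ_arg 0.2
# The epilogue then computes D = leaky_relu(alpha * accum + beta * C, 0.2).
output_op = operation.epilogue_type(*([1.0, 0.0] + [0.2]))
```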

arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=output_op,
    gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
    split_k_slices=args.split_k_slices, batch=args.batch
)

if args.gemm_mode == "GemmSplitKParallel":
    reduction_arguments = ReductionArguments(
        operation=reduction_operation,
        problem_size=[problem_size.m(), problem_size.n()],
        partitions=args.split_k_slices, workspace=arguments.ptr_D,
        destination=tensor_D, source=tensor_C,
        output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
        bias=arguments.bias
    )

operation.run(arguments)

if args.gemm_mode == "GemmSplitKParallel":
    reduction_operation.run(reduction_arguments)
    reduction_arguments.sync()
else:
    arguments.sync()

# run the host reference module
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
    tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)

if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
    tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))

if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
    output_op.sync()
    accum_ref = reference.run(
        tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
    tensor_D_ref, reduction_ref = epilogue_functor(
        accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
        tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
        args.alpha, args.beta
    )
    tensor_D_ref = tensor_D_ref.flatten()
    reduction_ref = reduction_ref.flatten()
    assert np.allclose(reduction_ref, reduction, atol=1e-2)

elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
    output_op.sync()
    accum_ref = reference.run(
        tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)

    tensor_D_ref, tensor_T_ref = epilogue_functor(
        accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
        tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
        vector, args.alpha, args.beta)

    tensor_D_ref = tensor_D_ref.flatten()
    tensor_T_ref = tensor_T_ref.flatten()

    assert np.array_equal(tensor_t, tensor_T_ref)

try:
    assert np.array_equal(tensor_D, tensor_D_ref)
except:
    assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)
print("Passed.")
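
As an illustrative invocation of this customizable script (one valid flag combination for SM80 Tensor Cores, not the only one):

```
python gemm.py -i 16 8 8 -ta float16 -tb float16 -tc float32 -tacc float32 \
    -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 \
    -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 8 \
    -p 512 256 128 -alpha 1.0 -beta 0.5 -epv RowReduction
```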
287
examples/40_cutlass_py/customizable/gemm_grouped.py
Normal file
@@ -0,0 +1,287 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import csv
import sys

import argparse

# parse the arguments
parser = argparse.ArgumentParser(
    description="Launch CUTLASS grouped GEMM kernels from Python")

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"],
                    help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,
                    choices=["Simt", 'TensorOp'],
                    help='This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block will compute")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes the CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
    help="This option describes how thread blocks are scheduled on GPU. \
    NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
    This parameter is passed in at present to match the APIs of other kernels. The parameter \
    is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
                    default="Device", type=str, choices=["Host", "Device"],
                    help="Grouped GEMM scheduling on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
                    help="path to the CSV file containing the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is a bias vector")

# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
                    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"],
                    help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
                    help="additional arguments for the activation function")
parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

np.random.seed(0)

element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    math_inst
)

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
if args.activation_function == "identity":
    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
        getattr(pycutlass, args.activation_function)(element_epilogue),
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)

operation = GemmOperationGrouped(
    arch=args.compute_capability, tile_description=tile_description,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
    precompute_mode=precompute_mode
)

if args.print_cuda:
    print(operation.rt_module.emit())

pycutlass.compiler.add_module([operation, ])

reference_module = ReferenceModule(A, B, C)

# get problems
problem_sizes = []
with open(args.problem_size_dir) as csv_file:
    reader = csv.reader(csv_file)
    for row in reader:
        problem_sizes.append(
            cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
        )

problem_count = len(problem_sizes)
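
As the reader loop above shows, the problem-size file is plain CSV with one `M,N,K` triple per row. A hypothetical `problem_sizes.csv` might look like:

```
128,128,64
512,256,128
64,64,32
```

A run would then pass it via `-p problem_sizes.csv`.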

tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
problem_sizes_coord = []
tensor_D_refs = []

for problem_size in problem_sizes:
    if args.element_a != "int8":
        if args.element_a == "bfloat16":
            tensor_A = np.ceil(np.random.uniform(
                low=-8.5, high=7.5, size=(problem_size.m() * problem_size.k(),))
            ).astype(bfloat16)
        else:
            tensor_A = np.ceil(np.random.uniform(
                low=-8.5, high=7.5, size=(problem_size.m() * problem_size.k(),))
            ).astype(getattr(np, args.element_a))
    else:
        tensor_A = np.random.uniform(
            low=-2, high=2, size=(problem_size.m() * problem_size.k(),)
        ).astype(getattr(np, args.element_a))

    if args.element_b != "int8":
        if args.element_b == "bfloat16":
            tensor_B = np.ceil(np.random.uniform(
                low=-8.5, high=7.5, size=(problem_size.k() * problem_size.n(),))
            ).astype(bfloat16)
        else:
            tensor_B = np.ceil(np.random.uniform(
                low=-8.5, high=7.5, size=(problem_size.k() * problem_size.n(),))
            ).astype(getattr(np, args.element_b))
    else:
        tensor_B = np.random.uniform(
            low=-2, high=2, size=(problem_size.k() * problem_size.n(),)
        ).astype(getattr(np, args.element_b))

    if args.element_c != "int8":
        if args.bias:
            if args.layout_c == "RowMajor":
                c_size = problem_size.n()
            elif args.layout_c == "ColumnMajor":
                c_size = problem_size.m()
            else:
                raise ValueError(args.layout_c)
        else:
            c_size = problem_size.m() * problem_size.n()
        if args.element_c == "bfloat16":
            tensor_C = np.ceil(
                np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
            ).astype(bfloat16)
        else:
            tensor_C = np.ceil(
                np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
            ).astype(getattr(np, args.element_c))
    else:
        tensor_C = np.random.uniform(
            low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
        ).astype(getattr(np, args.element_c))
    tensor_D = np.zeros(
        shape=(problem_size.m() * problem_size.n(),)
    ).astype(getattr(np, args.element_c))

    tensor_As.append(tensor_A)
    tensor_Bs.append(tensor_B)
    tensor_Cs.append(tensor_C)
    tensor_Ds.append(tensor_D)
    tensor_D_ref = reference_module.run(
        tensor_A, tensor_B, tensor_C, problem_size,
        args.alpha, args.beta, args.bias)
    tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
    tensor_D_refs.append(tensor_D_ref)
    problem_sizes_coord.append(problem_size)

arguments = GemmGroupedArguments(
    operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
    output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
)

operation.run(arguments)

arguments.sync()

for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
    try:
        assert np.array_equal(tensor_d, tensor_d_ref)
    except:
        assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5)

print("Passed.")
@@ -29,417 +29,110 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import cutlass
from bfloat16 import bfloat16
"""
Basic example of using the CUTLASS Python interface to run a GEMM
"""

import argparse
import numpy as np
import sys

import cutlass
import pycutlass
from pycutlass import *
import util

# parse the arguments
parser = argparse.ArgumentParser(
    description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'")

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"],
                    help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,
                    choices=["Simt", 'TensorOp'],
                    help="This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block will compute")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes the CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
parser.add_argument("-epv", "--epilogue_visitor", default=None,
                    type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'],
                    help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
    help="This option describes how thread blocks are scheduled on GPU")

# Argument
parser.add_argument("-p", "--problem_size",
                    default=[128, 128, 128], nargs=3, type=int,
                    help="GEMM problem size M, N, K")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
                    help="Scaling factor of A * B")
parser.add_argument("-beta", "--beta", default=0.0, type=float,
                    help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
                    choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
                    help="GEMM mode. Gemm is used for non-split-K or serial split-K. \
                    GemmSplitKParallel is used for parallel split-K")
parser.add_argument('-k', '--split_k_slices', default=1,
                    type=int, help="Number of split-K partitions (default 1)")
parser.add_argument('-bias', '--bias', action='store_true', help="C is a bias vector")
parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")

# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
                    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"],
                    help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
                    help="additional arguments for the activation function")
parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")

parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'")
parser.add_argument("--m", default=128, type=int, help="M dimension of the GEMM")
parser.add_argument("--n", default=128, type=int, help="N dimension of the GEMM")
parser.add_argument("--k", default=128, type=int, help="K dimension of the GEMM")
parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
# Check that the device is of a sufficient compute capability
cc = util.get_device_cc()
assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70."

alignment = 8
assert args.m % alignment == 0, "M dimension of size {} is not divisible by alignment of {}".format(args.m, alignment)
assert args.n % alignment == 0, "N dimension of size {} is not divisible by alignment of {}".format(args.n, alignment)
assert args.k % alignment == 0, "K dimension of size {} is not divisible by alignment of {}".format(args.k, alignment)

np.random.seed(0)

element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
# Allocate a pool of device memory to be used by the kernel
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

# Set the compiler to use to NVCC
pycutlass.compiler.nvcc()

# Set up A, B, C and accumulator
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
    [16, 8, 8],  # Shape of the Tensor Core instruction
    A.element, B.element, element_acc,
    cutlass.OpClass.TensorOp,
    MathOperation.multiply_add
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    [128, 128, 32],  # Threadblock shape
    2,  # Number of stages
    [2, 2, 1],  # Number of warps within each dimension of the threadblock shape
    math_inst
)
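
A quick sanity check of the numbers in the new configuration (plain arithmetic, not a CUTLASS API): a 128x128 threadblock tile split across [2, 2, 1] warps gives each warp a 64x64 tile, which is then covered by 16x8x8 Tensor Core instructions.

```python
tb_m, tb_n, tb_k = 128, 128, 32
warps_m, warps_n = 2, 2
warp_m, warp_n = tb_m // warps_m, tb_n // warps_n   # 64 x 64 per warp
inst_m, inst_n, inst_k = 16, 8, 8
print(warp_m // inst_m, warp_n // inst_n, tb_k // inst_k)  # 4 8 4 MMA iterations
```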

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
        or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
        getattr(pycutlass, args.activation_function)(element_epilogue),
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

swizzling_functor = getattr(cutlass, args.swizzling_functor)

visitor = args.epilogue_visitor is not None

if args.epilogue_visitor == "ColumnReduction":
    class ColumnReduction_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                alpha: 'scalar', beta: 'scalar'):
            D = alpha * accum + beta * c
            reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
            return D, reduction
    epilogue_functor = ColumnReduction_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
    class RowReduction_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                alpha: 'scalar', beta: 'scalar'):
            D = alpha * accum + tanh.numpy(beta * c)
            reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
            return D, reduction
    epilogue_functor = RowReduction_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()

elif args.epilogue_visitor == "RowBroadcast":
    class RowBroadcast_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                vector: 'row', alpha: 'scalar', beta: 'scalar'):
            T = accum + vector
            scale_T = alpha * T
            Z = relu.numpy(scale_T + beta * c)
            return Z, T
    epilogue_functor = RowBroadcast_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
    class ColumnBroadcast_(EpilogueVisitTree):
        def __call__(
                self, accum: 'tensor', c: 'tensor',
                vector: 'column', alpha: 'scalar', beta: 'scalar'):
            T = accum + vector
            scale_T = leaky_relu.numpy(alpha * T, 0.2)
            Z = scale_T + beta * c
            return Z, T
    epilogue_functor = ColumnBroadcast_(
        epilogue_functor, tile_description, math_inst.element_accumulator,
        C.alignment, element_epilogue, C.element)
    epilogue_functor.initialize()
else:
    # No visitor requested: keep the plain epilogue functor
    pass
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)

operation = GemmOperationUniversal(
    arch=args.compute_capability, tile_description=tile_description,
    arch=cc, tile_description=tile_description,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
    visitor=visitor
)
    epilogue_functor=epilogue_functor)

if args.print_cuda:
    print(operation.rt_module.emit())

operations = [operation, ]

if args.gemm_mode == "GemmSplitKParallel":
    if args.activation_function == "identity":
        epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
    else:
        epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
            getattr(pycutlass, args.activation_function)(element_epilogue),
            C.element, C.alignment, math_inst.element_accumulator, element_epilogue)

    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        epilogue_functor=epilogue_functor_reduction,
        count=C.alignment
    )
    operations.append(reduction_operation)

# Compile the operation
pycutlass.compiler.add_module(operations)

# User-provided inputs
# Randomly initialize tensors
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.k,))).astype(np.float16)
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,))).astype(np.float16)
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32)
tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32)

problem_size = cutlass.gemm.GemmCoord(
    args.problem_size[0], args.problem_size[1], args.problem_size[2])

tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
    if args.element_a == "bfloat16":
        tensor_A = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
        ).astype(bfloat16)
    else:
        tensor_A = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
        ).astype(getattr(np, args.element_a))
else:
    tensor_A = np.random.uniform(
        low=-2, high=2, size=(tensor_a_size,)
    ).astype(getattr(np, args.element_a))

tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
    if args.element_b == "bfloat16":
        tensor_B = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
        ).astype(bfloat16)
    else:
        tensor_B = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
        ).astype(getattr(np, args.element_b))
else:
    tensor_B = np.random.uniform(
        low=-2, high=2, size=(tensor_b_size,)
    ).astype(getattr(np, args.element_b))

if args.element_c != "int8":
    if args.bias:
        if args.layout_c == "RowMajor":
            tensor_c_size = args.batch * problem_size.n()
        elif args.layout_c == "ColumnMajor":
            tensor_c_size = args.batch * problem_size.m()
        else:
            raise ValueError(args.layout_c)
    else:
        tensor_c_size = args.batch * problem_size.m() * problem_size.n()
    if args.element_c == "bfloat16":
        tensor_C = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
        ).astype(bfloat16)
    else:
        tensor_C = np.ceil(
            np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
        ).astype(getattr(np, args.element_c))
else:
    tensor_C = np.random.uniform(
        low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
    ).astype(getattr(np, args.element_c))

tensor_D = np.zeros(
    shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))

if args.epilogue_visitor == "RowReduction":
|
||||
cta_n = args.threadblock_shape[1]
|
||||
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
|
||||
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
|
||||
output_op = operation.epilogue_type(
|
||||
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
elif args.epilogue_visitor == "ColumnReduction":
|
||||
cta_m = args.threadblock_shape[0]
|
||||
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
|
||||
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
|
||||
output_op = operation.epilogue_type(
|
||||
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
elif args.epilogue_visitor == "RowBroadcast":
|
||||
vector = np.ceil(
|
||||
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
|
||||
).astype(getattr(np, args.element_c))
|
||||
tensor_t = np.empty_like(tensor_D)
|
||||
output_op = operation.epilogue_type(
|
||||
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
elif args.epilogue_visitor == "ColumnBroadcast":
|
||||
vector = np.ceil(
|
||||
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
|
||||
).astype(getattr(np, args.element_c))
|
||||
tensor_t = np.empty_like(tensor_D)
|
||||
output_op = operation.epilogue_type(
|
||||
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
else:
|
||||
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
|
||||
problem_size = cutlass.gemm.GemmCoord(args.m, args.n, args.k)
|
||||
alpha = 1.
|
||||
beta = 0.

arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=output_op,
    gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
    split_k_slices=args.split_k_slices, batch=args.batch
)

if args.gemm_mode == "GemmSplitKParallel":
    reduction_arguments = ReductionArguments(
        operation=reduction_operation,
        problem_size=[problem_size.m(), problem_size.n()],
        partitions=args.split_k_slices, workspace=arguments.ptr_D,
        destination=tensor_D, source=tensor_C,
        output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
        bias=arguments.bias
    )
    output_op=operation.epilogue_type(alpha, beta))

# Run the operation
operation.run(arguments)
arguments.sync()

if args.gemm_mode == "GemmSplitKParallel":
    reduction_operation.run(reduction_arguments)
    reduction_arguments.sync()
else:
    arguments.sync()

# run the host reference module
# Run the host reference module and compare to the CUTLASS result
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
    tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)

if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
    tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))

if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
    output_op.sync()
    accum_ref = reference.run(
        tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
    tensor_D_ref, reduction_ref = epilogue_functor(
        accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
        tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
        args.alpha, args.beta
    )
    tensor_D_ref = tensor_D_ref.flatten()
    reduction_ref = reduction_ref.flatten()
    assert np.allclose(reduction_ref, reduction, atol=1e-2)

elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
    output_op.sync()
    accum_ref = reference.run(
        tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)

    tensor_D_ref, tensor_T_ref = epilogue_functor(
        accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
        tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
        vector, args.alpha, args.beta)

    tensor_D_ref = tensor_D_ref.flatten()
    tensor_T_ref = tensor_T_ref.flatten()

    assert np.array_equal(tensor_t, tensor_T_ref)
tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta)

try:
    assert np.array_equal(tensor_D, tensor_D_ref)
except:
    assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)

print("Passed.")
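
The simplified example can then be launched with just the problem dimensions (an illustrative invocation; each dimension must be a multiple of the alignment of 8):

```
python gemm.py --m 256 --n 256 --k 512 --print_cuda
```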
@@ -29,253 +29,125 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
import csv
"""
Basic example of using the CUTLASS Python interface to run a grouped GEMM
"""

import argparse
import numpy as np
import sys

# parse the arguments
parser = argparse.ArgumentParser(
    description="Launch CUTLASS grouped GEMM kernels from Python")
import cutlass
import pycutlass
from pycutlass import *
import util

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"],
                    help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,
                    choices=["Simt", 'TensorOp'],
                    help='This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block will compute")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes the CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
    "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
    "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
    help="This option describes how thread blocks are scheduled on GPU. \
    NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
    This parameter is passed in at present to match the APIs of other kernels. The parameter \
    is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
                    default="Device", type=str, choices=["Host", "Device"],
                    help="Grouped GEMM scheduling on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
                    help="path to the CSV file containing the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is a bias vector")

# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
                    choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"],
                    help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
                    help="additional arguments for the activation function")
parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")
parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python")
parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
# Check that the device is of a sufficient compute capability
cc = util.get_device_cc()
assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70."

np.random.seed(0)

element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
# Allocate a pool of device memory to be used by the kernel
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

# Set the compiler to use to NVCC
pycutlass.compiler.nvcc()

# Set up A, B, C and accumulator
alignment = 1
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
    [16, 8, 8],  # Shape of the Tensor Core instruction
    A.element, B.element, element_acc,
    cutlass.OpClass.TensorOp,
    MathOperation.multiply_add
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    [128, 128, 32],  # Threadblock shape
    2,  # Number of stages
    [2, 2, 1],  # Number of warps within each dimension of the threadblock shape
    math_inst
)

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
if args.activation_function == "identity":
    epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
    epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
        getattr(pycutlass, args.activation_function)(element_epilogue),
        C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)

operation = GemmOperationGrouped(
    arch=args.compute_capability, tile_description=tile_description,
    arch=cc, tile_description=tile_description,
    A=A, B=B, C=C,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
    precompute_mode=precompute_mode
)
    epilogue_functor=epilogue_functor,
    precompute_mode=SchedulerMode.Device)

if args.print_cuda:
    print(operation.rt_module.emit())

pycutlass.compiler.add_module([operation, ])
operations = [operation, ]

reference_module = ReferenceModule(A, B, C)

# get problems
problem_sizes = []
with open(args.problem_size_dir) as csv_file:
    reader = csv.reader(csv_file)
    for row in reader:
        problem_sizes.append(
            cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
        )
# Compile the operation
pycutlass.compiler.add_module(operations)

# Initialize tensors for each problem in the group
problem_sizes = [
    cutlass.gemm.GemmCoord(128, 128, 64),
    cutlass.gemm.GemmCoord(512, 256, 128)
]
problem_count = len(problem_sizes)
|
||||
|
||||
alpha = 1.
|
||||
beta = 0.
|
||||
|
||||
tensor_As = []
|
||||
tensor_Bs = []
|
||||
tensor_Cs = []
|
||||
tensor_Ds = []
|
||||
problem_sizes_coord = []
|
||||
tensor_D_refs = []
|
||||
|
||||
reference = ReferenceModule(A, B, C)
|
||||
|
||||
for problem_size in problem_sizes:
|
||||
if args.element_a != "int8":
|
||||
if args.element_a == "bfloat16":
|
||||
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
|
||||
* problem_size.k(),))).astype(bfloat16)
|
||||
else:
|
||||
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
|
||||
* problem_size.k(),))).astype(getattr(np, args.element_a))
|
||||
else:
|
||||
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
|
||||
* problem_size.k(),)).astype(getattr(np, args.element_a))
|
||||
|
||||
if args.element_b != "int8":
|
||||
if args.element_b == "bfloat16":
|
||||
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
|
||||
* problem_size.n(),))).astype(bfloat16)
|
||||
else:
|
||||
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
|
||||
* problem_size.n(),))).astype(getattr(np, args.element_b))
|
||||
else:
|
||||
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
|
||||
* problem_size.n(),)).astype(getattr(np, args.element_b))
|
||||
|
||||
if args.element_c != "int8":
|
||||
if args.bias:
|
||||
if args.layout_c == "RowMajor":
|
||||
c_size = problem_size.n()
|
||||
elif args.layout_c == "ColumnMajor":
|
||||
c_size = problem_size.m()
|
||||
else:
|
||||
raise ValueError(args.layout_c)
|
||||
else:
|
||||
c_size = problem_size.m() * problem_size.n()
|
||||
if args.element_c == "bfloat16":
|
||||
tensor_C = np.ceil(
|
||||
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
|
||||
).astype(bfloat16)
|
||||
else:
|
||||
tensor_C = np.ceil(
|
||||
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
|
||||
).astype(getattr(np, args.element_c))
|
||||
else:
|
||||
tensor_C = np.random.uniform(
|
||||
low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
|
||||
).astype(getattr(np, args.element_c))
|
||||
tensor_D = np.zeros(
|
||||
shape=(problem_size.m() * problem_size.n(),)
|
||||
).astype(getattr(np, args.element_c))
|
||||
# Randomly initialize tensors
|
||||
m = problem_size.m()
|
||||
n = problem_size.n()
|
||||
k = problem_size.k()
|
||||
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16)
|
||||
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16)
|
||||
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32)
|
||||
tensor_D = np.zeros(shape=(m * n,)).astype(np.float32)
|
||||
|
||||
tensor_As.append(tensor_A)
|
||||
tensor_Bs.append(tensor_B)
|
||||
tensor_Cs.append(tensor_C)
|
||||
tensor_Ds.append(tensor_D)
|
||||
tensor_D_ref = reference_module.run(
|
||||
tensor_A, tensor_B, tensor_C, problem_size,
|
||||
args.alpha, args.beta, args.bias)
|
||||
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
|
||||
|
||||
# Run the reference GEMM
|
||||
tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta)
|
||||
tensor_D_refs.append(tensor_D_ref)
|
||||
problem_sizes_coord.append(problem_size)
|
||||
|
||||
arguments = GemmGroupedArguments(
|
||||
operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
|
||||
output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
|
||||
operation, problem_sizes, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
|
||||
output_op=operation.epilogue_type(alpha, beta)
|
||||
)
|
||||
|
||||
# Run the operation
|
||||
operation.run(arguments)
|
||||
|
||||
arguments.sync()
|
||||
|
||||
# Compare the CUTLASS result to the host reference result
|
||||
for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
|
||||
try:
|
||||
assert np.array_equal(tensor_d, tensor_d_ref)
|
||||
|
||||
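For context, the host-side check above amounts to the standard GEMM epilogue D = alpha * (A @ B) + beta * C. Below is a minimal NumPy sketch of that computation for one problem in the group, assuming the column-major A / row-major B / column-major C layouts configured above; this is illustrative only and is not what `ReferenceModule` literally executes.

import numpy as np

# Illustrative only: recompute one problem of the group on the host.
m, n, k = 128, 128, 64
alpha, beta = 1., 0.

A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16)
B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16)
C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32)

A_mat = A.reshape((m, k), order='F')   # column-major, per the TensorDescription above
B_mat = B.reshape((k, n))              # row-major
C_mat = C.reshape((m, n), order='F')   # column-major

# Accumulate in float32, matching element_acc = cutlass.float32 above
D_mat = alpha * (A_mat.astype(np.float32) @ B_mat.astype(np.float32)) + beta * C_mat
D = D_mat.reshape(m * n, order='F')    # flatten back in column-major order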
examples/40_cutlass_py/util.py (new file, +60 lines)
@@ -0,0 +1,60 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Utility functions for interacting with device
"""

from cuda import cudart


# Raises an exception if `result` returned an error. Otherwise returns the result.
def check_cuda_errors(result: list):
    # `result` is of the format: (cudaError_t, result...)
    err = result[0]
    if err.value:
        raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))

    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]


# Returns the integer representation of the device compute capability
def get_device_cc(device: int = 0):
    deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
    major = str(deviceProp.major)
    minor = str(deviceProp.minor)
    return int(major + minor)
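A short usage sketch for the helpers above (illustrative; it assumes the `cuda-python` package is installed and at least one CUDA device is visible):

from cuda import cudart
import util

# check_cuda_errors unpacks the (cudaError_t, value...) tuples that cuda-python returns
device_count = util.check_cuda_errors(cudart.cudaGetDeviceCount())

# get_device_cc concatenates major and minor, e.g. major=8, minor=0 -> 80
cc = util.get_device_cc(device=0)
assert cc >= 70, "Tensor Core examples in this commit target SM70 and newer"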
examples/41_fused_multi_head_attention/CMakeLists.txt (new file, +44 lines)
@@ -0,0 +1,44 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# [BSD-3-Clause license text identical to the header in util.py above]

cutlass_example_add_executable(
  41_fused_multi_head_attention_fixed_seqlen
  fused_multihead_attention_fixed_seqlen.cu
)

cutlass_example_add_executable(
  41_fused_multi_head_attention_variable_seqlen
  fused_multihead_attention_variable_seqlen.cu
)

add_custom_target(41_fused_multi_head_attention
  DEPENDS 41_fused_multi_head_attention_fixed_seqlen
          41_fused_multi_head_attention_variable_seqlen
)
@@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * [BSD-3-Clause license text identical to the header in util.py above]
 **************************************************************************************************/

#pragma once

#include "cutlass/functional.h"
@@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * [BSD-3-Clause license text identical to the header in util.py above]
 **************************************************************************************************/

#pragma once
#include <float.h>
#include <stdio.h>
examples/41_fused_multi_head_attention/default_fmha_grouped.h (new file, +284 lines)
@@ -0,0 +1,284 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * [BSD-3-Clause license text identical to the header in util.py above]
 **************************************************************************************************/

/*! \file
    \brief
      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
      the appropriate threadblock-scoped epilogue.

      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
      specializations here choose 'device::GemmTransposed' to implement this functionality.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
#include "find_default_mma.h"
#include "attention_scaling_coefs_updater.h"
#include "mma_from_smem.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The datatype of Q/K/V
    typename scalar_t_,
    // Architecture we are targeting (eg `cutlass::arch::Sm80`)
    typename ArchTag_,
    // If Q/K/V are correctly aligned in memory and we can run a fast kernel
    bool isAligned_,
    int kQueriesPerBlock,
    int kKeysPerBlock,
    bool kSingleValueIteration,
    GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
>
struct DefaultFMHAGrouped {
  using scalar_t = scalar_t_;
  using accum_t = float;
  using output_t = scalar_t;

  // Accumulator between 2 iterations
  // Using `accum_t` improves perf on f16 at the cost of
  // numerical errors
  using output_accum_t = accum_t;

  using ArchTag = ArchTag_;
  static bool const kIsAligned = isAligned_;
  static int const kWarpSize = 32;
  static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);

  struct MM0 {
    /*
      In this first matmul, we compute a block of `Q @ K.T`.
      While the calculation result is still hot in registers, we update
      `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
      into a shared-memory ("AccumulatorSharedStorage") that is used later as
      operand A for the second matmul (see MM1)
    */

    using GemmType = gemm_kernel_utils::DefaultGemmType<ArchTag, scalar_t>;
    using OpClass = typename GemmType::OpClass;

    using ElementA = scalar_t;
    using ElementB = scalar_t;
    using ElementC = scalar_t;
    using ElementAccumulator = accum_t;

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using LayoutC = cutlass::layout::RowMajor;

    using DefaultConfig =
        typename cutlass::gemm::device::DefaultGemmConfiguration<
            OpClass,
            ArchTag,
            ElementA,
            ElementB,
            ElementC,
            ElementAccumulator
        >;

    static int const kAlignmentA =
        kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
    static int const kAlignmentB =
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;

    using ThreadblockShape = cutlass::gemm::GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using InstructionShape = typename GemmType::InstructionShape;

    static int const kStages = DefaultConfig::kStages;
    using Operator = typename GemmType::Operator;

    using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
        ElementA,
        LayoutA,
        kAlignmentA,
        ElementB,
        LayoutB,
        kAlignmentB,
        ElementAccumulator,
        LayoutC,
        OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        kStages,
        Operator
    >::DefaultMma;

    using MmaCore = typename DefaultMma::MmaCore;
    using IteratorA = typename DefaultMma::IteratorA;
    using IteratorB = typename DefaultMma::IteratorB;
    using Mma = typename DefaultMma::ThreadblockMma;
    using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
        typename Mma::Operator::IteratorC,
        ElementAccumulator,
        kWarpSize>::Updater;

    static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, "");

    // Epilogue to store to shared-memory in a format that we can use later for
    // the second matmul
    using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
        typename Mma::Operator::IteratorC,
        typename Mma::Operator,
        scalar_t,
        WarpShape,
        ThreadblockShape>;
    using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  };

  struct MM1 {
    /*
      Second matmul: perform `attn @ V` where `attn` is the attention (not
      normalized) and stored in shared memory
    */

    using GemmType = typename MM0::GemmType;
    using OpClass = typename GemmType::OpClass;

    using ElementA = scalar_t;
    using ElementB = scalar_t;
    using ElementC = output_accum_t;
    using ElementAccumulator = accum_t;

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::RowMajor;
    using LayoutC = cutlass::layout::RowMajor;

    using DefaultConfig =
        typename cutlass::gemm::device::DefaultGemmConfiguration<
            OpClass,
            ArchTag,
            ElementA,
            ElementB,
            ElementC,
            ElementAccumulator
        >;

    static int const kAlignmentA = DefaultConfig::kAlignmentA;
    static int const kAlignmentB =
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;

    using ThreadblockShape = typename MM0::ThreadblockShape;
    using WarpShape = typename MM0::WarpShape;
    using InstructionShape = typename MM0::InstructionShape;

    using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp;

    static int const kStages = DefaultConfig::kStages;
    using Operator = typename GemmType::Operator;

    using ThreadblockSwizzle = void;  // Swizzling is unused
    static bool const kSplitKSerial = false;

    using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
        ElementA,
        LayoutA,
        kAlignmentA,
        ElementB,
        LayoutB,
        kAlignmentB,
        ElementC,
        LayoutC,
        ElementAccumulator,
        OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EpilogueOutputOp,
        ThreadblockSwizzle,
        kStages,
        kSplitKSerial,
        Operator>;

    using DefaultMmaFromSmem =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            typename MM0::AccumulatorSharedStorage>;

    using Mma = typename DefaultMmaFromSmem::Mma;
    using IteratorB = typename Mma::IteratorB;
    using WarpCount = typename Mma::WarpCount;
    static_assert(WarpCount::kCount == kNumWarpsPerBlock, "");

    using DefaultEpilogue = typename DefaultGemm::Epilogue;
    using OutputTileIterator =
        typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_t>;
    using OutputTileIteratorAccum =
        typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_accum_t>;

    struct SharedStorageMM1 {
      typename Mma::SharedStorage mm;
    };
  };

  /// Define the kernel in terms of the default kernel
  using FMHAKernel = kernel::FMHAGrouped<
      MM0,
      MM1,
      scalar_t,
      accum_t,
      output_t,
      output_accum_t,
      kSingleValueIteration,
      GroupScheduleMode_
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
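The warp-count arithmetic in `DefaultFMHAGrouped` pairs a fixed 32x32 warp tile (`WarpShape` above) with the kQueriesPerBlock x kKeysPerBlock threadblock tile. A quick sanity check of that formula in Python, using an assumed 64x64 block configuration (the concrete block shape is not taken from this diff):

# Illustrative check of kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize)
kWarpSize = 32
kQueriesPerBlock, kKeysPerBlock = 64, 64   # assumed example configuration
kNumWarpsPerBlock = (kQueriesPerBlock * kKeysPerBlock) // (kWarpSize * kWarpSize)
assert kNumWarpsPerBlock == 4              # four 32x32 warp tiles cover a 64x64 threadblock tile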
@@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * [BSD-3-Clause license text identical to the header in util.py above]
 **************************************************************************************************/

/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
@@ -1,8 +1,39 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * [BSD-3-Clause license text identical to the header in util.py above]
 **************************************************************************************************/

/*! \file
  \brief Cutlass provides helper template functions to figure out the right
  datastructures to instantiate to run a GEMM with various parameters (see
  `cutlass/gemm/threadblock/default_mma.h`). However, due to template
  instantiation priority rules, it will only create an MmaMultiStage with
  kStages=3 (otherwise creates an MmaPipelined - which is not compatible with
  FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
  so we just copy-pasted some code from `default_mma.h` and
examples/41_fused_multi_head_attention/fmha_grouped.h (new file, +839 lines)
@@ -0,0 +1,839 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * [BSD-3-Clause license text identical to the header in util.py above]
 **************************************************************************************************/

/*! \file
    \brief Grouped FMHA kernel
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"

#include "fmha_grouped_problem_visitor.h"
#include "gemm_kernel_utils.h"
#include "epilogue_rescale_output.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename MM0_,                        ///! Structure for computing P = Q @ K
  typename MM1_,                        ///! Structure for computing O = P @ V
  typename scalar_t_,
  typename accum_t_,
  typename output_t_,
  typename output_accum_t_,
  bool kKeepOutputInRF,                 ///! Whether the intermediate output from MM0_ should be kept in the register file
  GroupScheduleMode GroupScheduleMode_  ///! Type of scheduling to perform
>
struct FMHAGrouped {
public:
  using MM0 = MM0_;
  using MM1 = MM1_;

  using scalar_t = scalar_t_;
  using accum_t = accum_t_;
  using output_t = output_t_;
  using output_accum_t = output_accum_t_;

  static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;

  static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
      !cutlass::platform::is_same<output_accum_t, output_t>::value;

  // Parameters to satisfy BaseGrouped
  using ElementA = scalar_t;
  using ElementB = scalar_t;
  using ElementC = accum_t;
  using LayoutA = typename MM0::LayoutA;
  using LayoutB = typename MM0::LayoutB;
  using LayoutC = typename MM1::LayoutC;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
static int const kAlignmentA = MM0::kAlignmentA;
|
||||
static int const kAlignmentB = MM0::kAlignmentB;
|
||||
static int const kAlignmentC = 1;
|
||||
using Mma = typename MM1::Mma;
|
||||
using EpilogueOutputOp = typename MM1::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = void;
|
||||
using Operator = typename MM1::Operator;
|
||||
using WarpShape = typename MM1::WarpShape;
|
||||
using InstructionShape = typename MM1::InstructionShape;
|
||||
|
||||
using ElementQ = scalar_t;
|
||||
using ElementK = scalar_t;
|
||||
using ElementP = accum_t;
|
||||
using ElementV = scalar_t;
|
||||
using ElementO = output_t;
|
||||
using ElementOAccum = output_accum_t;
|
||||
using ElementAccumulator = accum_t;
|
||||
|
||||
using LayoutQ = typename MM0::LayoutA;
|
||||
using LayoutK = typename MM0::LayoutB;
|
||||
using LayoutP = typename MM0::LayoutC;
|
||||
using LayoutV = typename MM1::LayoutB;
|
||||
using LayoutO = typename MM1::LayoutC;
|
||||
|
||||
static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 &&
|
||||
cutlass::sizeof_bits<ElementV>::value == 16);
|
||||
|
||||
static int const kAlignmentQ = MM0::kAlignmentA;
|
||||
static int const kAlignmentK = MM0::kAlignmentB;
|
||||
static int const kAlignmentV = 1;
|
||||
|
||||
using ThreadblockShape = typename MM0::ThreadblockShape;
|
||||
|
||||
static int const kQueriesPerBlock = ThreadblockShape::kM;
|
||||
static int const kKeysPerBlock = ThreadblockShape::kN;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename MM1::WarpCount;
|
||||
static int const kThreadsPerWarp = 32;
|
||||
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor = FMHAGroupedProblemVisitor<
|
||||
ThreadblockShape,
|
||||
kGroupScheduleMode,
|
||||
kThreadCount,
|
||||
kThreadCount>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord *problem_sizes0;
|
||||
GemmCoord *problem_sizes1;
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
|
||||
ElementQ ** ptr_Q;
|
||||
ElementK ** ptr_K;
|
||||
ElementP ** ptr_P;
|
||||
ElementV ** ptr_V;
|
||||
ElementO ** ptr_O;
|
||||
ElementOAccum ** ptr_O_accum;
|
||||
|
||||
typename LayoutQ::Stride::LongIndex *ldq;
|
||||
typename LayoutK::Stride::LongIndex *ldk;
|
||||
typename LayoutP::Stride::LongIndex *ldv;
|
||||
typename LayoutO::Stride::LongIndex *ldo;
|
||||
|
||||
// Whether causal masking is to be performed
|
||||
bool causal;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord *host_problem_sizes;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
problem_count(0),
|
||||
threadblock_count(0),
|
||||
ptr_Q(nullptr),
|
||||
ptr_K(nullptr),
|
||||
ptr_P(nullptr),
|
||||
ptr_V(nullptr),
|
||||
ptr_O(nullptr),
|
||||
ptr_O_accum(nullptr),
|
||||
ldq(nullptr),
|
||||
ldk(nullptr),
|
||||
ldv(nullptr),
|
||||
ldo(nullptr),
|
||||
causal(false),
|
||||
host_problem_sizes(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmCoord *problem_sizes0,
|
||||
GemmCoord *problem_sizes1,
|
||||
int problem_count,
|
||||
int threadblock_count,
|
||||
ElementQ ** ptr_Q,
|
||||
ElementK ** ptr_K,
|
||||
ElementP ** ptr_P,
|
||||
ElementV ** ptr_V,
|
||||
ElementO ** ptr_O,
|
||||
ElementOAccum ** ptr_O_accum,
|
||||
typename LayoutQ::Stride::LongIndex *ldq,
|
||||
typename LayoutK::Stride::LongIndex *ldk,
|
||||
typename LayoutP::Stride::LongIndex *ldp,
|
||||
typename LayoutV::Stride::LongIndex *ldv,
|
||||
typename LayoutO::Stride::LongIndex *ldo,
|
||||
bool causal,
|
||||
GemmCoord *host_problem_sizes=nullptr
|
||||
):
|
||||
problem_sizes0(problem_sizes0),
|
||||
problem_sizes1(problem_sizes1),
|
||||
problem_count(problem_count),
|
||||
threadblock_count(threadblock_count),
|
||||
ptr_Q(ptr_Q),
|
||||
ptr_K(ptr_K),
|
||||
ptr_P(ptr_P),
|
||||
ptr_V(ptr_V),
|
||||
ptr_O(ptr_O),
|
||||
ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O),
|
||||
ldq(ldq),
|
||||
ldk(ldk),
|
||||
ldv(ldv),
|
||||
ldo(ldo),
|
||||
causal(causal),
|
||||
host_problem_sizes(host_problem_sizes)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
bool __host__ check_supported() {
|
||||
CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ);
|
||||
CHECK_ALIGNED_PTR(ptr_K, kAlignmentK);
|
||||
CHECK_ALIGNED_PTR(ptr_V, kAlignmentV);
|
||||
XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned");
|
||||
XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned");
|
||||
XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned");
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
|
||||
ElementQ ** ptr_Q;
|
||||
ElementK ** ptr_K;
|
||||
ElementP ** ptr_P;
|
||||
ElementV ** ptr_V;
|
||||
ElementO ** ptr_O;
|
||||
ElementOAccum ** ptr_O_accum;
|
||||
|
||||
typename LayoutQ::Stride::LongIndex *ldq;
|
||||
typename LayoutK::Stride::LongIndex *ldk;
|
||||
typename LayoutP::Stride::LongIndex *ldv;
|
||||
typename LayoutO::Stride::LongIndex *ldo;
|
||||
|
||||
bool causal;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
ptr_Q(nullptr),
|
||||
ptr_K(nullptr),
|
||||
ptr_P(nullptr),
|
||||
ptr_V(nullptr),
|
||||
ptr_O(nullptr),
|
||||
ptr_O_accum(nullptr),
|
||||
ldq(nullptr),
|
||||
ldk(nullptr),
|
||||
ldv(nullptr),
|
||||
ldo(nullptr),
|
||||
causal(false)
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
int tile_count = 0):
|
||||
problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count),
|
||||
threadblock_count(args.threadblock_count),
|
||||
ptr_Q(args.ptr_Q),
|
||||
ptr_K(args.ptr_K),
|
||||
ptr_P(args.ptr_P),
|
||||
ptr_V(args.ptr_V),
|
||||
ptr_O(args.ptr_O),
|
||||
ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O),
|
||||
ldq(args.ldq),
|
||||
ldk(args.ldk),
|
||||
ldv(args.ldv),
|
||||
ldo(args.ldo),
|
||||
causal(args.causal)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
int tile_count = 0) {
|
||||
|
||||
problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0,
|
||||
args.problem_sizes1,
|
||||
args.problem_count,
|
||||
workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
ptr_Q = args.ptr_Q;
|
||||
ptr_K = args.ptr_K;
|
||||
ptr_P = args.ptr_P;
|
||||
ptr_V = args.ptr_V;
|
||||
ptr_O = args.ptr_O;
|
||||
ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O;
|
||||
ldq = args.ldq;
|
||||
ldk = args.ldk;
|
||||
ldv = args.ldv;
|
||||
ldo = args.ldo;
|
||||
causal = args.causal;
|
||||
}
|
||||
};
|
||||
|
||||
// Shared storage - depends on kernel params
|
||||
struct ScalingCoefs {
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
};
|
||||
|
||||
union {
|
||||
typename MM0::Mma::SharedStorage mm0;
|
||||
SharedStorageAfterMM0 after_mm0;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
||||
epilogue_shared_storage() {
|
||||
return epilogue;
|
||||
}
|
||||
|
||||
// ProblemVisitor shared storage can't be overlapped with others
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueInLoop : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
union {
|
||||
typename MM0::Mma::SharedStorage mm0;
|
||||
SharedStorageAfterMM0 after_mm0;
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
||||
epilogue_shared_storage() {
|
||||
return after_mm0.epilogue;
|
||||
}
|
||||
|
||||
// ProblemVisitor shared storage can't be overlapped with others
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
};
|
||||
|
||||
using SharedStorage = typename cutlass::platform::conditional<
|
||||
kKeepOutputInRF,
|
||||
SharedStorageEpilogueAtEnd,
|
||||
SharedStorageEpilogueInLoop>::type;
|
||||
|
||||
private:
|
||||
|
||||
// Parameters to be used by an individual tile
|
||||
struct TileParams {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int query_start(int threadblock_idx) {
|
||||
return threadblock_idx * kQueriesPerBlock;
|
||||
}
|
||||
|
||||
// Returns whether this threadblock computes within the number of queries,
|
||||
// which is determined by the M dimension of problem 0
|
||||
CUTLASS_HOST_DEVICE
|
||||
static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) {
|
||||
return query_start(threadblock_idx) < problem_size0.m();
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) {
|
||||
return problem_size0.m() - query_start(threadblock_idx);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) {
|
||||
int nk = problem_size0.n();
|
||||
if (causal) {
|
||||
nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk);
|
||||
}
|
||||
return nk;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
FMHAGrouped() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const &args) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static CUTLASS_DEVICE int16_t thread_id() {
|
||||
return threadIdx.x;
|
||||
}
|
||||
|
||||
static CUTLASS_DEVICE int8_t warp_id() {
|
||||
return threadIdx.x / kThreadsPerWarp;
|
||||
}
|
||||
|
||||
static CUTLASS_DEVICE int8_t lane_id() {
|
||||
return threadIdx.x % kThreadsPerWarp;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
auto& si = shared_storage.after_mm0.si;
|
||||
auto& mi = shared_storage.mi;
|
||||
|
||||
ProblemVisitor problem_visitor(
|
||||
params.problem_visitor,
|
||||
shared_storage.problem_visitor,
|
||||
blockIdx.x);
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (problem_visitor.next_tile()) {
|
||||
|
||||
GemmCoord problem_size0 = problem_visitor.problem_size0();
|
||||
GemmCoord problem_size1 = problem_visitor.problem_size1();
|
||||
const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
if (!TileParams::can_compute(threadblock_idx, problem_size0)) {
|
||||
problem_visitor.advance(gridDim.x);
|
||||
continue;
|
||||
}
|
||||
|
||||
const int32_t problem_idx = problem_visitor.problem_index();
|
||||
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = ElementAccumulator(0);
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
|
||||
}
|
||||
|
||||
ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
|
||||
ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
|
||||
const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0);
|
||||
|
||||
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
|
||||
using OutputTileIterator = typename MM1::OutputTileIterator;
|
||||
return OutputTileIterator(
|
||||
typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]},
|
||||
ptr_O,
|
||||
typename OutputTileIterator::TensorCoord{
|
||||
num_queries, problem_size1.n()},
|
||||
thread_id(),
|
||||
{0, col});
|
||||
};
|
||||
|
||||
auto createOutputAccumIter = [&](int col) ->
|
||||
typename MM1::OutputTileIteratorAccum {
|
||||
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
|
||||
return OutputTileIteratorAccum(
|
||||
typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]},
|
||||
ptr_O_accum,
|
||||
typename OutputTileIteratorAccum::TensorCoord{
|
||||
num_queries, problem_size1.n()},
|
||||
thread_id(),
|
||||
{0, col});
|
||||
};
|
||||
|
||||
typename MM1::Mma::FragmentC accum_o;
|
||||
accum_o.clear();
|
||||
|
||||
const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal);
|
||||
|
||||
for (int32_t iter_key_start = 0; iter_key_start < num_keys;
|
||||
iter_key_start += kKeysPerBlock) {
|
||||
int32_t problem_size_0_m =
|
||||
cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries);
|
||||
int32_t problem_size_0_n = cutlass::fast_min(
|
||||
(int32_t)kKeysPerBlock, num_keys - iter_key_start);
|
||||
int32_t const& problem_size_0_k = problem_size0.k();
|
||||
int32_t const& problem_size_1_n = problem_size1.n();
|
||||
int32_t const& problem_size_1_k = problem_size_0_n;
|
||||
|
||||
auto prologueV = [&](int blockN) {
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
|
||||
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
|
||||
MM1::Mma::prologue(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
iterator_V,
|
||||
thread_id(),
|
||||
problem_size_1_k);
|
||||
};
|
||||
|
||||
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
|
||||
// updated from end of prev iter
|
||||
|
||||
//
|
||||
// MATMUL: Q.K_t
|
||||
//
|
||||
// Computes the block-matrix product of:
|
||||
// (a) query[query_start:query_end, :]
|
||||
// with
|
||||
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
|
||||
// and stores that into `shared_storage.si`
|
||||
//
|
||||
|
||||
ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx];
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename MM0::IteratorA iterator_A(
|
||||
typename MM0::IteratorA::Params(
|
||||
typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])),
|
||||
ptr_Q,
|
||||
{problem_size_0_m, problem_size_0_k},
|
||||
thread_id(),
|
||||
{0, 0});
|
||||
|
||||
typename MM0::IteratorB iterator_B(
|
||||
typename MM0::IteratorB::Params(
|
||||
typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])),
|
||||
params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx],
|
||||
{problem_size_0_k, problem_size_0_n},
|
||||
thread_id(),
|
||||
{0, 0});
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
typename MM0::Mma mma(
|
||||
shared_storage.mm0, thread_id(), warp_id(), lane_id());
|
||||
|
||||
typename MM0::Mma::FragmentC accum;
|
||||
|
||||
accum.clear();
|
||||
|
||||
auto gemm_k_iterations =
|
||||
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
|
||||
__syncthreads();
|
||||
|
||||
if (kPreloadV) {
|
||||
prologueV(0);
|
||||
}
|
||||
|
||||
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
||||
iteratorC_tile_offset = {
|
||||
(warp_id() % MM0::Mma::WarpCount::kM),
|
||||
(warp_id() / MM0::Mma::WarpCount::kM)
|
||||
};
|
||||
|
||||
// Mask out last if causal
|
||||
if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
|
||||
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
int32_t last_col;
|
||||
MM0::ScalingCoefsUpdater::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (accum_n > last_col) {
|
||||
accum[idx] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
num_keys - iter_key_start >= kKeysPerBlock,
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also updates `accum` with accum[i] <-
|
||||
// exp(accum[i] * scale
|
||||
// - mi)
|
||||
MM0::ScalingCoefsUpdater::update<
|
||||
kQueriesPerBlock,
|
||||
kFullColumns,
|
||||
kIsFirst,
|
||||
kKeepOutputInRF>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
1.0f / cutlass::fast_sqrt(float(problem_size0.k())));
|
||||
}));
|
||||
}));
|
||||
      // Output results to shared-memory
      int warp_idx_mn_0 = warp_id() %
          (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
      auto output_tile_coords = cutlass::MatrixCoord{
          warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
          warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

      MM0::B2bGemm::accumToSmem(
          shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords);

      __syncthreads();

      //
      // MATMUL: Attn . V
      // Run the matmul `attn @ V` for a block of attn and V.
      // `attn` is read from shared memory (in `shared_storage_si`)
      // `V` is read from global memory (with iterator_B)
      //

      const int64_t nBlockN = kKeepOutputInRF ? 1
                                              : ceil_div(
                                                    (int64_t)problem_size_1_n,
                                                    int64_t(MM1::ThreadblockShape::kN));

      // Iterate over the N dimension of GEMM1
      for (int blockN = 0; blockN < nBlockN; ++blockN) {
        int gemm_k_iterations =
            (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add and store it in accum
        // (in registers)
        if (!kPreloadV) {
          __syncthreads(); // we share shmem between mma and epilogue
        }

        typename MM1::Mma::IteratorB iterator_V(
            typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
            params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
            {problem_size_1_k, problem_size_1_n},
            thread_id(),
            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

        typename MM1::Mma mma_pv(
            shared_storage.after_mm0.mm1.mm,
            shared_storage.after_mm0.si,
            (int)thread_id(),
            (int)warp_id(),
            (int)lane_id(),
            (int)problem_size_1_k);

        mma_pv.set_prologue_done(kPreloadV);
        if (!kKeepOutputInRF) {
          accum_o.clear();
        }

        mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
        __syncthreads();

        if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) {
          prologueV(blockN + 1);
        }

        if (!kKeepOutputInRF) {
          DISPATCH_BOOL(
              iter_key_start == 0, kIsFirst, ([&] {
                DISPATCH_BOOL(
                    (iter_key_start + kKeysPerBlock) >= num_keys,
                    kIsLast,
                    ([&] {
                      using DefaultEpilogue = typename MM1::DefaultEpilogue;
                      using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
                      using ElementCompute = typename DefaultOp::ElementCompute;
                      using EpilogueOutputOp = typename cutlass::epilogue::
                          thread::MemoryEfficientAttentionNormalize<
                              typename cutlass::platform::conditional<
                                  kIsLast,
                                  output_t,
                                  output_accum_t>::type,
                              output_accum_t,
                              DefaultOp::kCount,
                              typename DefaultOp::ElementAccumulator,
                              output_accum_t,
                              kIsFirst,
                              kIsLast,
                              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                      using Epilogue = typename cutlass::epilogue::threadblock::
                          EpiloguePipelined<
                              typename DefaultEpilogue::Shape,
                              typename MM1::Mma::Operator,
                              DefaultEpilogue::kPartitionsK,
                              typename cutlass::platform::conditional<
                                  kIsLast,
                                  typename MM1::OutputTileIterator,
                                  typename MM1::OutputTileIteratorAccum>::type,
                              typename DefaultEpilogue::AccumulatorFragmentIterator,
                              typename DefaultEpilogue::WarpTileIterator,
                              typename DefaultEpilogue::SharedLoadIterator,
                              EpilogueOutputOp,
                              typename DefaultEpilogue::Padding,
                              DefaultEpilogue::kFragmentsPerIteration,
                              true, // IterationsUnroll
                              typename MM1::OutputTileIteratorAccum // Read iterator
                              >;

                      int col = blockN * MM1::Mma::Shape::kN;
                      auto source_iter = createOutputAccumIter(col);
                      auto dest_iter = gemm_kernel_utils::call_conditional<
                          kIsLast,
                          decltype(createOutputIter),
                          decltype(createOutputAccumIter)>::
                          apply(createOutputIter, createOutputAccumIter, col);
                      EpilogueOutputOp rescale(s_prime, m_prime);
                      Epilogue epilogue(
                          shared_storage.epilogue_shared_storage(),
                          thread_id(),
                          warp_id(),
                          lane_id());
                      epilogue(rescale, dest_iter, accum_o, source_iter);
                    }));
              }));
          if (!kKeepOutputInRF) {
            __syncthreads();
          }
        }
      }
      __syncthreads(); // we modify `m_prime` after
    }

    if (kKeepOutputInRF) {
      const bool kIsFirst = true;
      const bool kIsLast = true;
      using DefaultEpilogue = typename MM1::DefaultEpilogue;
      using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
      using ElementCompute = typename DefaultOp::ElementCompute;
      using EpilogueOutputOp =
          typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
              output_t,       // output
              output_accum_t, // source
              DefaultOp::kCount,
              typename DefaultOp::ElementAccumulator, // accum
              output_accum_t,                         // compute
              kIsFirst,
              kIsLast,
              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
      using Epilogue =
          typename cutlass::epilogue::threadblock::EpiloguePipelined<
              typename DefaultEpilogue::Shape,
              typename MM1::Mma::Operator,
              DefaultEpilogue::kPartitionsK,
              typename MM1::OutputTileIterator, // destination
              typename DefaultEpilogue::AccumulatorFragmentIterator,
              typename DefaultEpilogue::WarpTileIterator,
              typename DefaultEpilogue::SharedLoadIterator,
              EpilogueOutputOp,
              typename DefaultEpilogue::Padding,
              DefaultEpilogue::kFragmentsPerIteration,
              true, // IterationsUnroll
              typename MM1::OutputTileIteratorAccum // source tile
              >;
      auto dest_iter = createOutputIter(0);
      EpilogueOutputOp rescale(s_prime, m_prime);
      Epilogue epilogue(
          shared_storage.epilogue_shared_storage(),
          thread_id(),
          warp_id(),
          lane_id());
      epilogue(rescale, dest_iter, accum_o);
    }

    // Next tile
    problem_visitor.advance(gridDim.x);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
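The `DISPATCH_BOOL` construct used in the epilogue selection above turns a runtime boolean into a compile-time constant, so that both the first-iteration and last-iteration epilogue specializations are instantiated once at compile time and chosen at runtime. A minimal sketch of the idiom, assuming a generic-lambda formulation (the exact macro is defined in the example's utility header and may differ in detail):

// Sketch of the boolean-dispatch idiom: each branch binds the runtime value
// to a constexpr name, so the lambda body can use it as a template argument.
#define DISPATCH_BOOL(BOOL_VALUE, BOOL_NAME, FUNC) \
  {                                                \
    if (BOOL_VALUE) {                              \
      constexpr bool BOOL_NAME = true;             \
      FUNC();                                      \
    } else {                                       \
      constexpr bool BOOL_NAME = false;            \
      FUNC();                                      \
    }                                              \
  }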
@ -0,0 +1,178 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Scheduler for grouped FMHA
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

// Helper for correctly representing problem sizes in grouped kernels
template <typename ThreadblockShape>
struct FMHAGroupedProblemSizeHelper {

  CUTLASS_HOST_DEVICE
  static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) {
    // FMHA only partitions tiles across the M dimension.
    return cutlass::gemm::GemmCoord(
      ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1);
  }

  CUTLASS_HOST_DEVICE
  static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {}

  CUTLASS_HOST_DEVICE
  static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) {
    return grid.m() * grid.n();
  }
};

} // namespace detail

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount,
          bool Transposed = false>
struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor<
                                       detail::FMHAGroupedProblemSizeHelper<ThreadblockShape>,
                                       ThreadblockShape,
                                       GroupScheduleMode_,
                                       PrefetchTileCount,
                                       ThreadCount> {

  using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper<ThreadblockShape>;
  using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
  using BaseParams = typename Base::Params;
  using SharedStorage = typename Base::SharedStorage;

  cutlass::gemm::GemmCoord const *problem_sizes0;
  cutlass::gemm::GemmCoord const *problem_sizes1;

  struct Params {
    cutlass::gemm::GemmCoord const *problem_sizes0;
    cutlass::gemm::GemmCoord const *problem_sizes1;
    int32_t problem_count;
    void const *workspace;
    int32_t tile_count;

    //
    // Methods
    //

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
              problem_count(0), workspace(nullptr), tile_count(0) { }

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const *problem_sizes0,
      cutlass::gemm::GemmCoord const *problem_sizes1,
      int32_t problem_count,
      void const *workspace = nullptr,
      int32_t tile_count = 0
    ):
      problem_sizes0(problem_sizes0),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      workspace(workspace),
      tile_count(tile_count)
    {}

    /// Convert the FMHA-specific parameters to those used by the base class
    CUTLASS_HOST_DEVICE
    BaseParams to_base() const {
      return BaseParams(// Set problem_sizes as problem_sizes1 because these determine
                        // the shape of the final output of FMHA
                        problem_sizes1,
                        problem_count,
                        workspace,
                        tile_count);
    }

  };

  //
  // Methods
  //
  CUTLASS_DEVICE
  FMHAGroupedProblemVisitor(
    Params const &params_,
    SharedStorage &shared_storage_,
    int32_t block_idx
  ): Base (
       params_.to_base(),
       shared_storage_, block_idx),
     problem_sizes0(params_.problem_sizes0),
     problem_sizes1(params_.problem_sizes1)
  {}

  /// Returns the problem size 0 for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size0() const {
    GemmCoord problem = problem_sizes0[this->problem_idx];
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }

  /// Returns the problem size 1 for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size1() const {
    GemmCoord problem = problem_sizes1[this->problem_idx];
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
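For intuition on the scheduler above: `grid_shape` partitions only the M (query) dimension, so with `ThreadblockShape::kM = 64` a group of problems with M = 512, 70, and 128 contributes ceil(512/64) + ceil(70/64) + ceil(128/64) = 8 + 2 + 2 = 12 tiles in total, regardless of N and K. A host-side sketch, assuming the header above is on the include path and using made-up problem sizes:

// Counts threadblock tiles the way FMHAGroupedProblemSizeHelper does;
// the problem sizes here are illustrative only.
#include <cstdio>
#include "cutlass/gemm/gemm.h"
// plus the grouped FMHA problem-visitor header shown above

int main() {
  using Helper = cutlass::gemm::kernel::detail::
      FMHAGroupedProblemSizeHelper<cutlass::gemm::GemmShape<64, 64, 32>>;
  cutlass::gemm::GemmCoord problems[] = {
      {512, 64, 64}, {70, 64, 64}, {128, 64, 64}};
  int total = 0;
  for (auto const &p : problems) {
    total += Helper::tile_count(Helper::grid_shape(p));  // ceil(m / 64) each
  }
  std::printf("total tiles = %d\n", total);  // prints 12
  return 0;
}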
@ -77,21 +77,17 @@
Examples:

# Run an attention example with default setup
-$ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention
+$ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen

# Run an attention example with custom setup
-$ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
+$ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true

Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers).
*/

/////////////////////////////////////////////////////////////////////////////////////////////////

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <map>
#include <unordered_map>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
@ -241,8 +237,8 @@ struct Options {

    for (int i = 0; i < batch_size; ++i) {
      // problems belonging to the same batch share the same seq len
-     int m_real = seq_length; // (rand() % seq_length);
-     int mkv_real = seq_length_kv; // (rand() % seq_length_kv);
+     int m_real = seq_length;
+     int mkv_real = seq_length_kv;
      int m = (m_real + alignment - 1) / alignment * alignment;
      int mkv = (mkv_real + alignment - 1) / alignment * alignment;
      int k0 = head_size;
@ -260,7 +256,6 @@ struct Options {
        problem_sizes0_real.push_back(problem0_real);
        problem_sizes1_real.push_back(problem1_real);
      }
-
      }
    }
  }
@ -268,7 +263,7 @@ struct Options {
  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

-   out << "42_fused_multi_head_attention\n\n"
+   out << "41_fused_multi_head_attention_fixed_seqlen\n\n"
      << "Options:\n\n"
      << "  --help                     If specified, displays this usage statement.\n\n"
      << "  --head_number=<int>        Head number in multi-head attention (default: --head_number=12)\n"
@ -276,7 +271,7 @@ struct Options {
      << "  --head_size=<int>          Head size in multi-head attention (default: --head_size=64)\n"
      << "  --head_size_v=<int>        Head size in multi-head attention for V (default: --head_size_v=head_size)\n"
      << "  --seq_length=<int>         Sequence length in multi-head attention for Q (default: --seq_length=1024)\n"
-     << "  --seq_length_kv=<int>      Sequence length in multi-head attention for K/V(default: --seq_length_kv=seq_length)\n"
+     << "  --seq_length_kv=<int>      Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n"
      << "  --use_mask=<bool>          If true, performs padding-like masking in softmax.\n"
      << "  --iterations=<int>         Number of profiling iterations to perform.\n"
      << "  --reference-check=<bool>   If true, performs reference check.\n"
@ -342,8 +337,7 @@ public:
  using ElementSoftmaxCompute = typename Attention::accum_t;

  using LayoutQ = cutlass::layout::RowMajor;
- using LayoutK = cutlass::layout::RowMajor;
- using LayoutK_T = cutlass::layout::ColumnMajor; // transposed
+ using LayoutK = cutlass::layout::ColumnMajor;
  using LayoutP = cutlass::layout::RowMajor;
  using LayoutV = cutlass::layout::RowMajor;
  using LayoutO = cutlass::layout::RowMajor;
@ -516,7 +510,7 @@ private:
      auto problem1 = options.problem_sizes1.at(i);

      ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0);
-     ldk_host.at(i) = LayoutK::packed({problem0.n(), problem0.k()}).stride(0);
+     ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0);
      ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
      ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0);
      ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0);
@ -541,7 +535,6 @@ private:
      total_elements_P += elements_P;
      total_elements_V += elements_V;
      total_elements_O += elements_O;
-
    }

    problem_sizes_device0.reset(problem_count());
@ -641,7 +634,7 @@ private:
      float abs_diff = fabs(diff);
      float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f);
      float relative_diff = abs_diff / abs_ref;
-     if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) {
+     if ( (isnan(vector_Input_Ref.at(i)) || isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) {
        printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i)));
        return false;
      }
@ -661,7 +654,7 @@ private:
      cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i);

      LayoutQ layout_Q(ldq_host.at(i));
-     LayoutK_T layout_K(ldk_host.at(i));
+     LayoutK layout_K(ldk_host.at(i));
      LayoutP layout_P(ldp_host.at(i));
      LayoutV layout_V(ldv_host.at(i));
      LayoutO layout_O(ldo_host.at(i));
@ -673,7 +666,7 @@ private:
      MatrixCoord extent_O{problem1.m(), problem1.k()};

      cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
-     cutlass::TensorView<ElementK, LayoutK_T> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
+     cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
      cutlass::TensorView<ElementP, LayoutP> view_P(block_P.get() + offset_P.at(i), layout_P, extent_P);
      cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);

@ -686,7 +679,7 @@ private:
      // Reference GEMM
      cutlass::reference::device::GemmComplex<
        ElementQ, LayoutQ,
-       ElementK, LayoutK_T,
+       ElementK, LayoutK,
        ElementP, LayoutP,
        ElementCompute, ElementAccumulator
      >(
@ -988,6 +981,40 @@ public:

///////////////////////////////////////////////////////////////////////////////////////////////////

template <
  int kQueriesPerBlock,
  int kKeysPerBlock,
  bool kSingleValueIteration
>
int run_attention(Options& options) {
  using Attention = AttentionKernel<
    cutlass::half_t,      // scalar_t
    cutlass::arch::Sm80,  // ArchTag
    true,                 // Memory is aligned
    kQueriesPerBlock,
    kKeysPerBlock,
    kSingleValueIteration
  >;

  //
  // Test and profile
  //

  TestbedAttention<Attention> testbed(options);

  Result result = testbed.profile_grouped();
  if (!result.passed) {
    std::cout << "Profiling CUTLASS attention has failed.\n";
    std::cout << "\nFailed\n";
    return -1;
  }

  std::cout << "\nPassed\n";
  return 0;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  //
@ -1041,52 +1068,25 @@ int main(int argc, char const **args) {
    std::cerr << "--alignment=1 is the only supported value\n";
    return -2;
  }
- using ArchTag = cutlass::arch::Sm80;
-
- constexpr bool kIs64x64 = true;
- // Set grid size
- constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32;
- constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128;
- if (kIs64x64 && options.head_size_v > kKeysPerBlock) {
-   std::cerr << "WARNING: you will get better performance with `kIs64x64=false`\n";
+ // Determine kernel configuration based on head size.
+ // If head size is less than or equal to 64, each block operates over 64 queries and
+ // 64 keys, and partial results can be stored in the register file.
+ // If head size is greater than 64, each block operates over 32 queries and 128 keys,
+ // and partial results are stored in shared memory.
+ if (options.head_size_v > 64) {
+   static int const kQueriesPerBlock = 32;
+   static int const kKeysPerBlock = 128;
+   if (options.head_size_v <= kKeysPerBlock) {
+     return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+   } else {
+     return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
+   }
+ } else {
+   static int const kQueriesPerBlock = 64;
+   static int const kKeysPerBlock = 64;
+   return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
  }
-
- constexpr bool kSingleValueIteration = true;
- if (kSingleValueIteration && options.head_size_v > kKeysPerBlock) {
-   std::cerr << "ERROR : Use kSingleValueIteration to keep output in RF. " \
-     "This requires to have `head_size <= kKeysPerBlock` " \
-     "but head_size_v=" << options.head_size_v << " and kKeysPerBlock=" << kKeysPerBlock << "\n";
-   return -2;
- }
- if (!kSingleValueIteration && options.head_size_v <= kKeysPerBlock) {
-   std::cerr << "WARNING: you will get better performance with `kSingleValueIteration=true` (keeps the output in RF rather than GMEM)\n";
- }
-
- using Attention = AttentionKernel<
-   cutlass::half_t,  // scalar_t
-   ArchTag,
-   true,             // memory is aligned
-   kQueriesPerBlock,
-   kKeysPerBlock,
-   kSingleValueIteration
- >;
-
- //
- // Test and profile
- //
-
- TestbedAttention<Attention> testbed(options);
-
- Result result = testbed.profile_grouped();
- if (!result.passed) {
-   std::cout << "Profiling CUTLASS attention has failed.\n";
-   std::cout << "\nFailed\n";
-   return -1;
- }
-
- std::cout << "\nPassed\n";
-
- return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
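To make the new head-size dispatch concrete, the mapping implied by the code above is:

// Worked examples of the head-size dispatch in main() (illustrative values):
//   head_size_v = 40  -> run_attention<64, 64, true>(options)    // output kept in RF
//   head_size_v = 96  -> run_attention<32, 128, true>(options)   // output kept in RF
//   head_size_v = 256 -> run_attention<32, 128, false>(options)  // output staged via a GMEM accumulator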
File diff suppressed because it is too large
@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "custom_mma_multistage.h"
@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/arch/mma.h"
@ -0,0 +1,97 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"

namespace cutlass {
namespace transform {
namespace threadblock {

template <typename BaseIterator>
struct MakeIteratorResidualLast;

template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    int AccessSize,
    bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<
    Shape, Element, Layout, AdvanceRank, ThreadMap, AccessSize, Gather>> {
  using Iterator = PredicatedTileIteratorResidualLast<
      Shape, Element, Layout, AdvanceRank, ThreadMap, AccessSize, Gather>;
};

template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    typename AccessType,
    bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<
    Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType, Gather>> {
  using Iterator = PredicatedTileAccessIteratorResidualLast<
      Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType, Gather>;
};

} // namespace threadblock
} // namespace transform
} // namespace cutlass
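A hedged usage sketch of the trait above: it maps a predicated iterator type onto its `ResidualLast` counterpart without respelling the template arguments. The tile shape and thread map below are arbitrary choices made up for illustration:

#include <type_traits>
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
// plus the make-iterator header shown above

using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
    cutlass::layout::PitchLinearShape<64, 64>, 32>;
using SrcIterator = cutlass::transform::threadblock::PredicatedTileIterator<
    cutlass::MatrixShape<64, 64>, cutlass::half_t,
    cutlass::layout::RowMajor, 1, ThreadMap>;
// The trait swaps in the residual-last variant with identical parameters.
using ResidualLastIterator = typename cutlass::transform::threadblock::
    MakeIteratorResidualLast<SrcIterator>::Iterator;
static_assert(
    !std::is_same<SrcIterator, ResidualLastIterator>::value,
    "the mapped iterator is a distinct type");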
@ -1,3 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#ifdef HAS_PYTORCH
@ -1,626 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Defines the FusedMultiHeadAttention Class

    The class contains the following:
        1) GEMM0 with epilogue fusion,
        2) GEMM1 with mainloop fusion, and
        3) A lightweight full softmax reduction kernel.
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include <cmath>
#include <iostream>
#include <vector>
#include <limits>

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
#include "gemm_grouped_with_softmax_visitor.h"

namespace cutlass {

template <
  typename ElementQ_,
  typename LayoutQ_,
  typename ElementK_,
  typename LayoutK_,
  typename ElementP_,
  typename LayoutP_,
  typename ElementCompute_,
  typename OperatorClass_,
  typename ArchTag_,
  typename ThreadblockShape0_,
  typename ThreadblockShape1_,
  typename WarpShape0_,
  typename WarpShape1_,
  typename InstructionShape_,
  int kStages0_,
  int kStages1_,
  bool UseMasking_ = false,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  int Alignment = 128 / cutlass::sizeof_bits<ElementQ_>::value,
  typename ElementSoftmax_ = ElementP_
>
class FusedMultiHeadAttention {
public:

  using ElementQ = ElementQ_;
  using ElementK = ElementK_;
  using ElementP = ElementP_;
  using ElementV = ElementK;
  using ElementOutput = ElementP;
  using ElementAccumulator = ElementCompute_;

  using LayoutQ = LayoutQ_;
  using LayoutK = LayoutK_;
  using LayoutP = LayoutP_;
  using LayoutV = LayoutK;
  using LayoutO = LayoutP;

  using ElementNorm = cutlass::half_t;
  using ElementSum = cutlass::half_t;
  using ElementSoftmaxCompute = float;
  using LayoutNorm = cutlass::layout::RowMajor;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;

  using ThreadblockShape0 = ThreadblockShape0_;
  using WarpShape0 = WarpShape0_;

  using ThreadblockShape1 = ThreadblockShape1_;
  using WarpShape1 = WarpShape1_;

  static int const Stages0 = kStages0_;
  static int const Stages1 = kStages1_;

  using InstructionShape = InstructionShape_;

  using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

  using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>;

  using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration<
    OperatorClass, ArchTag, ElementQ, ElementK, ElementP,
    ElementAccumulator>::Operator;
  static bool const kInternalTranspose = cutlass::platform::is_same<LayoutP, cutlass::layout::ColumnMajor>::value;

  static bool const kUseMasking = UseMasking_;

  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_;
  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_;

  using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
    ElementQ, LayoutQ, cutlass::ComplexTransform::kNone, 8,
    ElementK, LayoutK, cutlass::ComplexTransform::kNone, 8,
    LayoutP, kInternalTranspose
  >;

  using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
    typename MapArguments::ElementA,
    typename MapArguments::LayoutA,
    MapArguments::kAlignmentA,
    typename MapArguments::ElementB,
    typename MapArguments::LayoutB,
    MapArguments::kAlignmentB,
    ElementP,
    typename MapArguments::LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    EpilogueOutputOp0,
    ThreadblockSwizzle,
    Stages0,
    true,
    Operator,
    cutlass::gemm::SharedMemoryClearOption::kNone
  >::GemmKernel;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
    ThreadblockShape0,
    DefaultGemmKernel::kThreadCount,
    typename DefaultGemmKernel::Epilogue::OutputTileIterator,
    typename EpilogueOutputOp0::ElementCompute,
    ElementNorm,
    ElementSum,
    ElementSoftmaxCompute,
    EpilogueOutputOp0,
    kUseMasking
  >;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
    EpilogueVisitor,
    typename DefaultGemmKernel::Epilogue
  >::Epilogue;

  using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor<
    typename DefaultGemmKernel::Mma,
    Epilogue,
    ThreadblockSwizzle,
    kGroupScheduleMode0,
    kInternalTranspose,
    kUseMasking
  >;

  using GemmGrouped0 = cutlass::gemm::device::GemmGrouped<GemmKernel0>;

  using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
    ElementNorm,
    ElementSum,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape,
    true
  >;

  using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion<
    ElementP, LayoutP, cutlass::ComplexTransform::kNone, 128 / cutlass::sizeof_bits<ElementQ>::value,
    ElementV, LayoutV, cutlass::ComplexTransform::kNone, 128 / cutlass::sizeof_bits<ElementK>::value,
    ElementNorm, LayoutNorm,
    ElementOutput, LayoutO,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp1,
    ThreadblockSwizzle,
    Stages1,
    kGroupScheduleMode1
  >::GemmKernel;

  using GemmGrouped1 = cutlass::gemm::device::GemmGrouped<GemmKernel1>;

public:

  /// Arguments class
  struct Arguments {
    cutlass::gemm::GemmCoord *problem_sizes0;
    cutlass::gemm::GemmCoord *problem_sizes0_real;
    cutlass::gemm::GemmCoord *problem_sizes1;
    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    ElementP *block_P;
    ElementNorm *block_Norm;
    ElementSum *block_Sum;
    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;
    cutlass::gemm::GemmCoord *problem_sizes1_host;

    ElementAccumulator alpha0;
    ElementAccumulator alpha1;
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Arguments reduction;

    //
    // Methods
    //
    Arguments():
      problem_count(0), threadblock_count(0),
      ptr_Q(nullptr), ptr_K(nullptr), ptr_P(nullptr), ptr_V(nullptr), ptr_O(nullptr),
      ptr_Max(nullptr), ptr_Sum(nullptr),
      block_P(nullptr), block_Norm(nullptr), block_Sum(nullptr),
      offset_P(nullptr), offset_Norm_Device(nullptr), offset_Sum_Device(nullptr),
      ldq(nullptr), ldk(nullptr), ldp(nullptr), ldv(nullptr), ldo(nullptr),
      head_number(0), batch_size(0), seq_length(0)
    { }

    Arguments(
      cutlass::gemm::GemmCoord *problem_sizes0,
      cutlass::gemm::GemmCoord *problem_sizes1,
      int problem_count,
      int threadblock_count,
      ElementQ ** ptr_Q,
      ElementK ** ptr_K,
      ElementP ** ptr_P,
      ElementP ** ptr_V,
      ElementP ** ptr_O,
      ElementNorm **ptr_Max,
      ElementSum **ptr_Sum,
      ElementP *block_P,
      ElementNorm *block_Norm,
      ElementSum *block_Sum,
      int64_t *offset_P,
      int64_t *offset_Norm_Device,
      int64_t *offset_Sum_Device,
      typename LayoutQ::Stride::LongIndex *ldq,
      typename LayoutK::Stride::LongIndex *ldk,
      typename LayoutP::Stride::LongIndex *ldp,
      typename LayoutP::Stride::LongIndex *ldv,
      typename LayoutP::Stride::LongIndex *ldo,
      ElementAccumulator alpha0,
      ElementAccumulator alpha1,
      ElementAccumulator beta,
      int head_number,
      int batch_size,
      int seq_length,
      cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr
    ):
      problem_sizes0(problem_sizes0), problem_sizes1(problem_sizes1),
      problem_count(problem_count), threadblock_count(threadblock_count),
      ptr_Q(ptr_Q), ptr_K(ptr_K), ptr_P(ptr_P), ptr_V(ptr_V), ptr_O(ptr_O),
      ptr_Max(ptr_Max), ptr_Sum(ptr_Sum),
      block_P(block_P), block_Norm(block_Norm), block_Sum(block_Sum),
      offset_P(offset_P), offset_Norm_Device(offset_Norm_Device), offset_Sum_Device(offset_Sum_Device),
      ldq(ldq), ldk(ldk), ldp(ldp), ldv(ldv), ldo(ldo),
      alpha0(alpha0), alpha1(alpha1), beta(beta),
      head_number(head_number), batch_size(batch_size), seq_length(seq_length),
      problem_sizes0_host(problem_sizes0_host), problem_sizes1_host(problem_sizes1_host),
      problem_sizes0_real(problem_sizes0_real),
      reduction(
        problem_sizes0,
        block_Norm,
        block_Sum,
        offset_Norm_Device,
        offset_Sum_Device
      )
    { }

  };

  struct Params {
    cutlass::gemm::GemmCoord *problem_sizes0;
    cutlass::gemm::GemmCoord *problem_sizes0_real;
    cutlass::gemm::GemmCoord *problem_sizes1;
    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    ElementP *block_P;
    ElementNorm *block_Norm;
    ElementSum *block_Sum;
    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;
    cutlass::gemm::GemmCoord *problem_sizes1_host;

    ElementAccumulator alpha0;
    ElementAccumulator alpha1;
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Params reduction;

    Params():
      problem_count(0), threadblock_count(0),
      ptr_Q(nullptr), ptr_K(nullptr), ptr_P(nullptr), ptr_V(nullptr), ptr_O(nullptr),
      ptr_Max(nullptr), ptr_Sum(nullptr),
      block_P(nullptr), block_Norm(nullptr), block_Sum(nullptr),
      offset_P(nullptr), offset_Norm_Device(nullptr), offset_Sum_Device(nullptr),
      ldq(nullptr), ldk(nullptr), ldp(nullptr), ldv(nullptr), ldo(nullptr),
      problem_sizes0(nullptr), problem_sizes1(nullptr), problem_sizes0_real(nullptr),
      head_number(0), batch_size(0), seq_length(0)
    { }

    Params(Arguments const &args, void *workspace = nullptr):
      problem_sizes0(args.problem_sizes0),
      problem_sizes1(args.problem_sizes1),
      problem_count(args.problem_count),
      threadblock_count(args.threadblock_count),
      ptr_Q(args.ptr_Q), ptr_K(args.ptr_K), ptr_P(args.ptr_P), ptr_V(args.ptr_V), ptr_O(args.ptr_O),
      ptr_Max(args.ptr_Max), ptr_Sum(args.ptr_Sum),
      block_P(args.block_P), block_Norm(args.block_Norm), block_Sum(args.block_Sum),
      offset_P(args.offset_P), offset_Norm_Device(args.offset_Norm_Device), offset_Sum_Device(args.offset_Sum_Device),
      ldq(args.ldq), ldk(args.ldk), ldp(args.ldp), ldv(args.ldv), ldo(args.ldo),
      problem_sizes0_host(args.problem_sizes0_host),
      problem_sizes1_host(args.problem_sizes1_host),
      problem_sizes0_real(args.problem_sizes0_real),
      alpha0(args.alpha0), alpha1(args.alpha1), beta(args.beta),
      head_number(args.head_number), batch_size(args.batch_size), seq_length(args.seq_length),
      reduction(args.reduction)
    { }
  };

private:

  Params params_;
  GemmGrouped0 gemm_grouped0;
  GemmGrouped1 gemm_grouped1;

public:

  /// Ctor
  FusedMultiHeadAttention() { }

  /// Initialize
  Status initialize(Arguments const &args,
                    void *workspace0 = nullptr,
                    void *workspace1 = nullptr) {

    params_ = Params(args);

    typename GemmGrouped0::Arguments args_gemm0(
      params_.problem_sizes0,
      params_.problem_count,
      params_.threadblock_count,
      params_.ptr_Q,
      params_.ptr_K,
      params_.ptr_P,
      params_.ptr_P,
      params_.ptr_Max,
      params_.ptr_Sum,
      params_.ldq,
      params_.ldk,
      params_.ldp,
      params_.ldp,
      typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments(
        {
          params_.alpha0,
          params_.beta
        }
      ),
      params_.problem_sizes0_host,
      params_.problem_sizes0_real
    );

    Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0);

    typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta);

    typename GemmGrouped1::Arguments args_gemm1(
      params_.problem_sizes1,
      params_.problem_count,
      params_.threadblock_count,
      epilogue_op1,
      params_.ptr_P,
      params_.ptr_V,
      params_.ptr_O,
      params_.ptr_O,
      (void**)params_.ptr_Max,
      (void**)params_.ptr_Sum,
      params_.ldp,
      params_.ldv,
      params_.ldo,
      params_.ldo,
      params_.problem_sizes1_host
    );

    Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1);

    if ((result0 == cutlass::Status::kSuccess) && (result1 == cutlass::Status::kSuccess)) {
      return cutlass::Status::kSuccess;
    } else {
      if (result0 != cutlass::Status::kSuccess) {
        return result0;
      } else {
        return result1;
      }
    }
  }

  /// Run
  Status run(cudaStream_t stream = nullptr) {

    Status result = gemm_grouped0.run();
    cudaError_t error_info;

    if (result != cutlass::Status::kSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    int thread_per_block = 1024;

    dim3 final_reduction_grid(params_.head_number * params_.batch_size);
    dim3 final_reduction_block(thread_per_block);

    cutlass::Kernel<ApplyFinalReductionDevice><<<
      final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream
    >>>(params_.reduction);

    error_info = cudaGetLastError();

    if (error_info != cudaSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    result = gemm_grouped1.run();

    if (result != cutlass::Status::kSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    return cutlass::Status::kSuccess;
  }

  /// Function call operator
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }
};

} // namespace cutlass
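Before its removal, this class was driven in three stages, exactly as `run()` shows: grouped GEMM0 (Q x K with a softmax-statistics epilogue), a lightweight final softmax reduction kernel, then grouped GEMM1 (softmax(P) x V with mainloop fusion). A hedged host-side sketch of the call sequence; the tile shapes are illustrative and the long pointer/stride argument list, which the old example code prepared on the host, is deliberately elided:

// Minimal call-sequence sketch; every pointer/stride array behind the ellipsis
// is assumed to be allocated and filled by the caller.
using FMHA = cutlass::FusedMultiHeadAttention<
    cutlass::half_t, cutlass::layout::RowMajor,     // Q
    cutlass::half_t, cutlass::layout::ColumnMajor,  // K
    cutlass::half_t, cutlass::layout::RowMajor,     // P
    float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,  // ThreadblockShape0 (illustrative)
    cutlass::gemm::GemmShape<128, 128, 32>,  // ThreadblockShape1 (illustrative)
    cutlass::gemm::GemmShape<64, 64, 32>,    // WarpShape0
    cutlass::gemm::GemmShape<64, 64, 32>,    // WarpShape1
    cutlass::gemm::GemmShape<16, 8, 16>,     // InstructionShape
    4, 4>;                                   // Stages0, Stages1

FMHA fmha;
FMHA::Arguments args(/* problem sizes, pointers, strides, scalars ... */);
if (fmha.initialize(args) == cutlass::Status::kSuccess) {
  fmha(stream);  // GEMM0 -> softmax final reduction -> GEMM1
}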
@ -1,522 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped GEMM kernel with epilogue visitor customized for softmax
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
|
||||
bool Transposed_ = false,
|
||||
bool UseMask_ = false
|
||||
>
|
||||
struct GemmGroupedWithEpilogueVistor {
|
||||
public:
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using EpilogueOutputOp = typename EpilogueVisitor::ElementwiseFunctor;
|
||||
static bool const kTransposed = Transposed_;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<
|
||||
typename Mma::IteratorA::Element,
|
||||
typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA,
|
||||
Mma::IteratorA::AccessType::kElements,
|
||||
typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout,
|
||||
Mma::kTransformB,
|
||||
Mma::IteratorB::AccessType::kElements,
|
||||
typename Mma::LayoutC,
|
||||
kTransposed
|
||||
>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
|
||||
using ElementNorm = typename EpilogueVisitor::ElementNorm;
|
||||
using ElementSum = typename EpilogueVisitor::ElementSum;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor = GemmGroupedProblemVisitor<
|
||||
ThreadblockShape,
|
||||
kGroupScheduleMode,
|
||||
kThreadCount,
|
||||
kThreadCount,
|
||||
kTransposed>;
|
||||
|
||||
  //
  // Structures
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord *problem_sizes;
    // When a mask is used, the real problem sizes may not be aligned;
    // the unpadded elements then need to be masked out in the softmax.
    GemmCoord *problem_sizes_real;
    int problem_count;
    int threadblock_count;

    ElementA ** ptr_A;
    ElementB ** ptr_B;
    ElementC ** ptr_C;
    ElementC ** ptr_D;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    typename LayoutA::Stride::LongIndex *lda;
    typename LayoutB::Stride::LongIndex *ldb;
    typename LayoutC::Stride::LongIndex *ldc;
    typename LayoutC::Stride::LongIndex *ldd;

    typename EpilogueVisitor::Arguments epilogue_visitor;

    // Only used by device-level operator
    GemmCoord *host_problem_sizes;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments():
      problem_sizes(nullptr),
      problem_sizes_real(nullptr),
      problem_count(0),
      threadblock_count(0),
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_C(nullptr),
      ptr_D(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      lda(nullptr),
      ldb(nullptr),
      ldc(nullptr),
      ldd(nullptr),
      host_problem_sizes(nullptr)
    { }

    /// Ctor
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord *problem_sizes,
      int problem_count,
      int threadblock_count,
      ElementA ** ptr_A,
      ElementB ** ptr_B,
      ElementC ** ptr_C,
      ElementC ** ptr_D,
      ElementNorm **ptr_Max,
      ElementSum **ptr_Sum,
      typename LayoutA::Stride::LongIndex *lda,
      typename LayoutB::Stride::LongIndex *ldb,
      typename LayoutC::Stride::LongIndex *ldc,
      typename LayoutC::Stride::LongIndex *ldd,
      typename EpilogueVisitor::Arguments epilogue_visitor_,
      GemmCoord *host_problem_sizes=nullptr,
      GemmCoord *problem_sizes_real=nullptr
    ):
      problem_sizes(problem_sizes),
      problem_sizes_real(problem_sizes_real),
      problem_count(problem_count),
      threadblock_count(threadblock_count),
      ptr_A(ptr_A),
      ptr_B(ptr_B),
      ptr_C(ptr_C),
      ptr_D(ptr_D),
      ptr_Max(ptr_Max),
      ptr_Sum(ptr_Sum),
      lda(lda),
      ldb(ldb),
      ldc(ldc),
      ldd(ldd),
      epilogue_visitor(epilogue_visitor_),
      host_problem_sizes(host_problem_sizes)
    { }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {

    typename ProblemVisitor::Params problem_visitor;
    GemmCoord *problem_sizes_real;
    int threadblock_count;

    ElementA ** ptr_A;
    ElementB ** ptr_B;
    ElementC ** ptr_C;
    ElementC ** ptr_D;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    typename LayoutA::Stride::LongIndex *lda;
    typename LayoutB::Stride::LongIndex *ldb;
    typename LayoutC::Stride::LongIndex *ldc;
    typename LayoutC::Stride::LongIndex *ldd;

    typename EpilogueVisitor::Params epilogue_visitor;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      problem_sizes_real(nullptr),
      threadblock_count(0),
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_C(nullptr),
      ptr_D(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      lda(nullptr),
      ldb(nullptr),
      ldc(nullptr),
      ldd(nullptr)
    { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args, void *workspace = nullptr, int32_t tile_count = 0):
      problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count),
      problem_sizes_real(args.problem_sizes_real),
      threadblock_count(args.threadblock_count),
      ptr_A(args.ptr_A),
      ptr_B(args.ptr_B),
      ptr_C(args.ptr_C),
      ptr_D(args.ptr_D),
      ptr_Max(args.ptr_Max),
      ptr_Sum(args.ptr_Sum),
      lda(args.lda),
      ldb(args.ldb),
      ldc(args.ldc),
      ldd(args.ldd),
      epilogue_visitor(args.epilogue_visitor)
    { }

    CUTLASS_HOST_DEVICE
    void update(
      Arguments const &args,
      void *workspace = nullptr,
      int32_t tile_count = -1) {

      problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
      threadblock_count = args.threadblock_count;
      ptr_A = args.ptr_A;
      ptr_B = args.ptr_B;
      ptr_C = args.ptr_C;
      ptr_D = args.ptr_D;
      ptr_Max = args.ptr_Max;
      ptr_Sum = args.ptr_Sum;
      lda = args.lda;
      ldb = args.ldb;
      ldc = args.ldc;
      ldd = args.ldd;
      problem_sizes_real = args.problem_sizes_real;
    }
  };

  /// Shared memory storage structure
  struct SharedStorage {
    union {
      typename Mma::SharedStorage main_loop;
      struct {
        typename Epilogue::SharedStorage epilogue;
        typename EpilogueVisitor::SharedStorage visitor;
      } epilogue;
    } kernel;

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };


public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmGroupedWithEpilogueVistor() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return Status::kSuccess;
  }

  static size_t get_extra_workspace_size(
    Arguments const &args,
    cutlass::gemm::GemmCoord const &grid_tiled_shape) {

    return 0;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    //
    // These types shadow the type-level definitions and support the ability to implement
    // a 'transposed' GEMM that computes the transposed problems.
    //
    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename EpilogueVisitor::ElementOutput;
    using LayoutC = typename Mma::LayoutC;

    //
    // Problem visitor.
    //
    ProblemVisitor problem_visitor(
      params.problem_visitor,
      shared_storage.problem_visitor,
      blockIdx.x);

    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {

      GemmCoord problem_size = problem_visitor.problem_size();
      int32_t problem_idx = problem_visitor.problem_index();
      int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

      GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

      cutlass::gemm::GemmCoord threadblock_offset(
        int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
        int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN,
        0);

      // Load element pointers. Exchange pointers and strides if working on the transpose
      ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
      typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);

      ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
      typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_offset.m(),
        0,
      };

      cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_offset.n()
      };

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        LayoutA(ldm_A),
        ptr_A,
        {problem_size.m(), problem_size.k()},
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        LayoutB(ldm_B),
        ptr_B,
        {problem_size.k(), problem_size.n()},
        thread_idx,
        tb_offset_B);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Broadcast the warp_id computed by lane 0 to ensure dependent code
      // is compiled as warp-uniform.
      int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

      int lane_idx = threadIdx.x % 32;

      //
      // Matrix multiply phase
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);

      // Compute threadblock-scoped matrix multiply-add
      int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Wait for all threads to finish their epilogue phases from the previous tile.
      __syncthreads();

      // Compute threadblock-scoped matrix multiply-add
      mma(
        gemm_k_iterations,
        accumulators,
        iterator_A,
        iterator_B,
        accumulators);

      ElementC *ptr_C = params.ptr_C[problem_idx];
      ElementC *ptr_D = params.ptr_D[problem_idx];

      ElementNorm *ptr_Max = params.ptr_Max[problem_idx];
      ElementSum *ptr_Sum = params.ptr_Sum[problem_idx];

      LayoutC layout_C(params.ldc[problem_idx]);
      LayoutC layout_D(params.ldd[problem_idx]);

      int column_offset = (threadblock_offset.n() / ThreadblockShape::kN) * problem_size.m();

      typename EpilogueVisitor::OutputTileIterator::Params params_C(layout_C);
      typename EpilogueVisitor::OutputTileIterator::Params params_D(layout_D);

      //
      // Construct the epilogue visitor
      //

      EpilogueVisitor epilogue_visitor(
        params.epilogue_visitor,
        shared_storage.kernel.epilogue.visitor,
        problem_size.mn(),
        thread_idx,
        warp_idx,
        lane_idx,
        params_C,
        params_D,
        ptr_C,
        ptr_D,
        ptr_Max,
        ptr_Sum,
        threadblock_offset.mn(),
        column_offset,
        params.problem_sizes_real[problem_idx].mn()
      );

      // Construct the epilogue
      Epilogue epilogue(
        shared_storage.kernel.epilogue.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // Execute the epilogue operator to update the destination tensor
      epilogue(epilogue_visitor, accumulators);

      // Next tile
      problem_visitor.advance(gridDim.x);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -30,7 +30,7 @@

cutlass_example_add_executable(
  42_fused_multi_head_attention
  fused_multihead_attention.cu
  42_ampere_tensorop_group_conv
  ampere_tensorop_group_conv.cu
)
@ -0,0 +1,706 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/**
  This example shows how to run group convolution kernels using functions and data structures
  provided by CUTLASS, using tensor cores on an NVIDIA Ampere GPU.

  There are two group conv modes:
  1. cutlass::conv::GroupMode::kSingleGroup
     This mode is for large K problem sizes: k_per_group (K/groups) is equal to or larger than
     threadblock_tile_N. One or multiple threadblocks compute the data of one group.
  2. cutlass::conv::GroupMode::kMultipleGroup
     This mode is for small K problem sizes: k_per_group (K/groups) is smaller than
     threadblock_tile_N. One threadblock computes data from more than one group.

  The function profile_convolution_selecter() shows how to choose the kernel with the appropriate
  group mode according to the problem size and the threadblock_tile size; a short sketch of the
  selection rule follows below.
*/

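// A minimal sketch, not part of this commit, of the group-mode selection rule described
// above. `select_group_mode` is a hypothetical helper name; the only assumption is that
// `threadblock_n` carries the threadblock tile N extent (ThreadblockShape::kN below).
//
//   inline cutlass::conv::GroupMode select_group_mode(int k, int groups, int threadblock_n) {
//     int k_per_group = k / groups;                          // output channels per group
//     return (k_per_group >= threadblock_n)
//                ? cutlass::conv::GroupMode::kSingleGroup    // tile N fits inside one group
//                : cutlass::conv::GroupMode::kMultipleGroup; // one CTA spans several groups
//   }
//
// With the 64x64x64 threadblock tile used below: k = 128, groups = 2 gives k_per_group = 64 >= 64,
// hence kSingleGroup; k = 128, groups = 8 gives k_per_group = 16 < 64, hence kMultipleGroup.
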
#include <iostream>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_conv2d_group_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/device/convolution.h"

#include "helper.h"

// The code section below describes the data types for the input and output tensors and the
// computation between elements
using ElementAccumulator = float;                  // Data type of accumulator
using ElementComputeEpilogue = float;              // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t;             // Data type of elements in input tensor
using ElementInputB = cutlass::half_t;             // Data type of elements in input tensor
using ElementOutput = float;                       // Data type of elements in output tensor

using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;  // Threadblock tile shape

// This code section describes the tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>;         // Warp tile shape

// This code section describes the size of the MMA op
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;   // TensorCore instruction shape

// This code section describes how threadblocks are scheduled on the GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// Number of pipeline stages to use
constexpr int NumStages = 3;

// This code section describes the epilogue part of the kernel; we use the default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                     // Data type of output matrix.
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // The number of elements per vectorized
                                                       // memory access. This becomes the vector width of
                                                       // math instructions in the epilogue too.
    ElementAccumulator,                                // Data type of accumulator
    ElementComputeEpilogue>;                           // Data type for alpha/beta in linear combination

// Analytic kernel and operation for single-group problem sizes
using AnalyticSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
  ElementInputA, LayoutInputA,
  ElementInputB, LayoutInputB,
  ElementOutput, LayoutOutput,
  ElementAccumulator,
  MMAOp,
  SmArch,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::GroupMode::kSingleGroup,
  cutlass::conv::IteratorAlgorithm::kAnalytic
>::Kernel;
using AnalyticSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution<AnalyticSingleGroupKernel>;

// Analytic kernel and operation for multiple-group problem sizes
using AnalyticMultipleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
  ElementInputA, LayoutInputA,
  ElementInputB, LayoutInputB,
  ElementOutput, LayoutOutput,
  ElementAccumulator,
  MMAOp,
  SmArch,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::GroupMode::kMultipleGroup,
  cutlass::conv::IteratorAlgorithm::kAnalytic
>::Kernel;
using AnalyticMultipleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution<AnalyticMultipleGroupKernel>;

// Optimized kernel and operation for single-group problem sizes
using OptimizedSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
  ElementInputA, LayoutInputA,
  ElementInputB, LayoutInputB,
  ElementOutput, LayoutOutput,
  ElementAccumulator,
  MMAOp,
  SmArch,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::GroupMode::kSingleGroup,
  cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;
using OptimizedSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution<OptimizedSingleGroupKernel>;

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;
  cutlass::Tensor4DCoord input_size;
  cutlass::Tensor4DCoord filter_size;
  cutlass::Tensor4DCoord padding;
  cutlass::MatrixCoord conv_stride;
  cutlass::MatrixCoord dilation;
  int groups;
  bool reference_check;
  bool measure_performance;
  int iterations;
  ElementComputeEpilogue alpha;
  ElementComputeEpilogue beta;
  bool optimized;
  std::string tag;

  Options():
    help(false),
    input_size(1, 32, 32, 32),
    filter_size(32, 3, 3, 32),
    padding(1, 1, 1, 1),
    conv_stride(1, 1),
    dilation(1, 1),
    groups(1),
    reference_check(false),
    measure_performance(false),
    iterations(20),
    alpha(1),
    beta(0),
    optimized(false) { }

  // Verify the problem size is compatible with the CUTLASS Convolution implementation.
  bool valid() {

    //
    // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
    // all pointers, strides, and tensor extents must be divisible by 8 elements.
    //
    int const kAlignment = 8;

    if ((input_size.c() % kAlignment) ||
        (filter_size.n() % kAlignment)) {

      // misaligned tensors
      return false;
    }

    // Invalid padding
    if ((padding.h() != filter_size.h() / 2) ||
        (padding.w() != filter_size.w() / 2)) {

      return false;
    }

    return true;
  }

  /// Updates input and filter sizes
  void update(
    cutlass::Tensor4DCoord input_size,
    cutlass::Tensor4DCoord filter_size) {

    this->input_size = input_size;
    this->filter_size = filter_size;

    padding.n() = filter_size.h() / 2;
    padding.h() = filter_size.h() / 2;
    padding.w() = filter_size.w() / 2;
    padding.c() = filter_size.w() / 2;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    if (cmd.check_cmd_line_flag("ref-check")) {
      reference_check = true;
    }

    if (cmd.check_cmd_line_flag("perf-check")) {
      measure_performance = true;
    }

    if (cmd.check_cmd_line_flag("optimized")) {
      optimized = true;
    }

    cmd.get_cmd_line_argument("n", input_size.n());
    cmd.get_cmd_line_argument("h", input_size.h());
    cmd.get_cmd_line_argument("w", input_size.w());
    cmd.get_cmd_line_argument("c", input_size.c());

    cmd.get_cmd_line_argument("k", filter_size.n());
    cmd.get_cmd_line_argument("r", filter_size.h());
    cmd.get_cmd_line_argument("s", filter_size.w());

    cmd.get_cmd_line_argument("g", groups);
    filter_size.c() = input_size.c() / groups;

    cmd.get_cmd_line_argument("u", conv_stride.row());
    cmd.get_cmd_line_argument("v", conv_stride.column());

    cmd.get_cmd_line_argument("alpha", alpha);
    cmd.get_cmd_line_argument("beta", beta);

    cmd.get_cmd_line_argument("iterations", iterations);
    cmd.get_cmd_line_argument("tag", tag);

    if (filter_size.h() == 3 && filter_size.w() == 3) {
      padding = {1, 1, 1, 1};
    }
    else {
      filter_size.h() = 1;
      filter_size.w() = 1;
      padding = {0, 0, 0, 0};
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "42_ampere_tensorop_group_conv example\n\n"
      << "  This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
      << "  forward grouped convolution on tensors of layout NHWC.\n\n"
      << "Options:\n\n"
      << "  --help               If specified, displays this usage statement.\n\n"
      << "  --n=<int>            Input tensor extent N\n"
      << "  --h=<int>            Input tensor extent H\n"
      << "  --w=<int>            Input tensor extent W\n"
      << "  --c=<int>            Input tensor extent C\n"
      << "  --k=<int>            Filter extent K\n"
      << "  --r=<int>            Filter extent R\n"
      << "  --s=<int>            Filter extent S\n\n"
      << "  --g=<int>            Conv groups G\n\n"
      << "  --u=<int>            Conv stride_h\n\n"
      << "  --v=<int>            Conv stride_w\n\n"
      << "  --alpha=<float>      Epilogue scalar alpha\n"
      << "  --beta=<float>       Epilogue scalar beta\n\n"
      << "  --ref-check          If set (true), reference check is computed\n"
      << "  --perf-check         If set (true), performance is measured.\n"
      << "  --optimized          If set (true), use optimized kernel, otherwise use analytic kernel.\n"
      << "  --iterations=<int>   Number of profiling iterations to perform.\n"
      << "  --tag=<string>       String to replicate across the first column in the results table\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n"
      << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n"
      << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n";

    return out;
  }

  /// Computes the output tensor size (NPQK)
  cutlass::Tensor4DCoord output_size() const {
    return cutlass::Tensor4DCoord(
      input_size.n(),
      (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
      (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
      filter_size.n());
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of multiply-adds = NPQK * CRS
    int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

struct Result {
  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cutlass::Status reference_check;
  cudaError_t error;

  Result():
    runtime_ms(0),
    gflops(0),
    status(cutlass::Status::kSuccess),
    reference_check(cutlass::Status::kInvalid),
    error(cudaSuccess) { }

  static std::ostream & print_header(std::ostream &out, Options const &options) {

    if (!options.tag.empty()) {
      out << "Name,";
    }

    out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs";

    return out;
  }

  std::ostream & print(std::ostream &out, int idx, Options const &options) {

    if (!options.tag.empty()) {
      out << options.tag << ",";
    }

    out
      << "conv_" << idx << ","
      << options.input_size.n() << ","
      << options.input_size.h() << ","
      << options.input_size.w() << ","
      << options.input_size.c() << ","
      << options.filter_size.n() << ","
      << options.filter_size.h() << ","
      << options.filter_size.w() << ","
      << options.groups << ","
      << runtime_ms << ","
      << gflops;

    return out;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Runs one benchmark
template <typename Conv2dOperation>
Result profile_convolution(Options const &options) {

  Result result;

  //
  // Allocate host-device tensors using the CUTLASS Utilities.
  //

  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());

  //
  // Initialize tensors
  //

  // Fill tensor A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(7),
      ElementInputA(-8),
      0);

  // Fill tensor B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(7),
      ElementInputB(-8),
      0);

  // Fill tensor C on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      ElementOutput(7),
      ElementOutput(-8),
      0);

  // Fill tensor D on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_d.host_view());

  // Fill tensor D for reference on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  //
  // Define arguments for CUTLASS Convolution
  //

  cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;

  // Split K dimension into 1 partition
  int split_k_slices = 1;

  // Construct Conv2dProblemSize with user-defined output size
  cutlass::conv::Conv2dProblemSize problem_size(
      options.input_size,
      options.filter_size,
      options.padding,
      options.conv_stride,
      options.dilation,
      options.output_size(),
      mode,
      split_k_slices,
      options.groups
  );

  // Construct Conv2dOperation::Arguments structure with conv2d
  // problem size, data pointers, and epilogue values
  typename Conv2dOperation::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),
    tensor_b.device_ref(),
    tensor_c.device_ref(),
    tensor_d.device_ref(),
    {options.alpha, options.beta},
  };

  //
  // Initialize CUTLASS Convolution
  //

  Conv2dOperation implicit_gemm_op;

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  result.status = implicit_gemm_op.can_implement(arguments);
  CUTLASS_CHECK(result.status);

  result.status = implicit_gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(result.status);

  //
  // Launch initialized CUTLASS kernel
  //
  result.status = implicit_gemm_op();

  CUTLASS_CHECK(result.status);

  //
  // Optional reference check
  //

  if (options.reference_check) {
    std::cout << "Verification on device...\n";

    // Compute with reference implementation
    cutlass::reference::device::Conv2dFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementComputeEpilogue,
      ElementAccumulator,
      cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
    >(
      problem_size,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),
      tensor_ref_d.device_ref(),
      options.alpha,
      options.beta
    );

    tensor_ref_d.sync_host();

    // Check if outputs from the CUTLASS kernel and the reference kernel are equal
    tensor_d.sync_host();

    bool passed = cutlass::reference::host::TensorEquals(
      tensor_d.host_view(),
      tensor_ref_d.host_view());

    if (!passed) {
      result.reference_check = cutlass::Status::kErrorInternal;
      std::cout << "ERROR - results miscompared.\n";
    } else {
      result.reference_check = cutlass::Status::kSuccess;
      std::cout << "Passed.\n";
    }
  } else {
    result.reference_check = cutlass::Status::kInvalid;
  }

  //
  // Performance measurement
  //

  if (options.measure_performance) {

    cudaEvent_t events[2];

    for (auto & event : events) {
      result.error = cudaEventCreate(&event);
      if (result.error != cudaSuccess) {
        std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
        return result;
      }
    }

    // Record an event at the start of a series of convolution operations.
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Launch a sequence of implicit GEMM operations on the device
    for (int iteration = 0; iteration < options.iterations; ++iteration) {
      result.status = implicit_gemm_op();
      CUTLASS_CHECK(result.status);
    }

    // Record an event when the convolutions have been launched.
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Print average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    for (auto event : events) {
      (void)cudaEventDestroy(event);
    }
  }

  return result;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

Result profile_convolution_selecter(Options const &options) {
  int k_per_group = options.filter_size.n() / options.groups;

  // In group conv, if k_per_group < threadblock_N, one threadblock will calculate multiple groups
  if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode
    if (options.optimized) {
      std::cerr << "Invalid problem: the optimized group conv kernel doesn't support MultipleGroup (one CTA calculates multiple groups) mode" << std::endl;
      exit(-1);
    } else {
      std::cout << "Select AnalyticMultipleGroupOperation\n";
      return profile_convolution<AnalyticMultipleGroupOperation>(options);
    }
  } else { // SingleGroup mode
    if (options.optimized) {
      std::cout << "Select OptimizedSingleGroupOperation\n";
      return profile_convolution<OptimizedSingleGroupOperation>(options);
    } else {
      std::cout << "Select AnalyticSingleGroupOperation\n";
      return profile_convolution<AnalyticSingleGroupOperation>(options);
    }
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  bool notSupported = false;

  // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
  //
  // CUTLASS must be compiled with the CUDA 11 Toolkit to run these Conv2dFprop examples.
  if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
    std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
    notSupported = true;
  }

  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;
  }

  if (notSupported) {
    return 0;
  }

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  // Execute one problem size
  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

  Result result = profile_convolution_selecter(options);

  Result::print_header(std::cout, options) << std::endl;
  result.print(std::cout, 1, options) << std::endl;

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -1,66 +0,0 @@
#pragma once

#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"

namespace cutlass {
namespace transform {
namespace threadblock {

template <typename BaseIterator>
struct MakeIteratorResidualLast;

template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    int AccessSize,
    bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<
    Shape,
    Element,
    Layout,
    AdvanceRank,
    ThreadMap,
    AccessSize,
    Gather>> {
  using Iterator = PredicatedTileIteratorResidualLast<
      Shape,
      Element,
      Layout,
      AdvanceRank,
      ThreadMap,
      AccessSize,
      Gather>;
};

template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    typename AccessType,
    bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<
    Shape,
    Element,
    Layout,
    AdvanceRank,
    ThreadMap,
    AccessType,
    Gather>> {
  using Iterator = PredicatedTileAccessIteratorResidualLast<
      Shape,
      Element,
      Layout,
      AdvanceRank,
      ThreadMap,
      AccessType,
      Gather>;
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass
@ -1,4 +1,3 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
@ -28,9 +27,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


cutlass_example_add_executable(
  41_multi_head_attention
  fused_multihead_attention.cu
  43_ell_block_sparse_gemm
  ell_block_sparse_gemm.cu
)

740
examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu
Normal file
@ -0,0 +1,740 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Block-Ell sparse gemm example.

    This example performs a sparse-matrix dense-matrix multiplication (SpMM) operation.
    Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format, whereas matrix B
    is a dense matrix. Details about the Blocked-Ellpack (Blocked-ELL) storage format can be
    found here:
    https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell

    The Blocked-Ellpack or Blocked-ELL storage format comprises two matrices.
    The first is a packed matrix (the ellValue matrix) that stores non-zero values in
    consecutive blocks, represented by tensor_a in this example. The second is a matrix of
    indices (the ellColInd matrix), represented by tensor_ell_idx in this example, that
    represents the column indices of the corresponding non-zero blocks. All rows in the
    matrices must have the same number of blocks. ellColInd can contain -1 values to indicate
    empty blocks. These matrices store elements in row-major order.

    Description of the parameters and tensors used to represent the Blocked-Ellpack (ELL)
    format for this example:
        a_rows              - Rows in the sparse matrix.
        a_cols              - Columns in the sparse matrix.
        a_ell_blocksize     - Size of the ELL-Blocks.
        a_ell_num_columns   - Number of columns in the Blocked-Ellpack format (ellValue columns)
        tensor_a            - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
        tensor_ell_idx      - Blocked-ELL column indices (ellColInd), whose size is
                              (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize)
        tensor_b            - Input dense matrix, whose size is (a_cols * n)
        tensor_c/tensor_d   - Output dense matrices, whose size is (a_rows * n)
        {a_rows, n, a_cols} - Problem size
*/

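// For concreteness, a worked sketch (not part of this commit) of the extents the description
// above implies, using the defaults configured in Options below
// (a_rows = a_cols = n = 1024, a_ell_num_columns = 512, a_ell_blocksize = 16):
//
//   tensor_a       : a_rows x a_ell_num_columns                                   = 1024 x 512
//   tensor_ell_idx : (a_rows / a_ell_blocksize) x (a_ell_num_columns / a_ell_blocksize) = 64 x 32
//   tensor_b       : a_cols x n                                                   = 1024 x 1024
//   tensor_c/d     : a_rows x n                                                   = 1024 x 1024
//
// Each tensor_ell_idx entry names a block column in [0, a_cols / a_ell_blocksize), i.e. [0, 64),
// or is -1 for an empty block.
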
/////////////////////////////////////////////////////////////////////////////////////////////////

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <unordered_map>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/ell_gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_uncompress.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result {

  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  //
  // Methods
  //

  Result(
    double runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess
  ):
    runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;
  bool reference_check;
  int iterations;
  int cuda_streams;
  int a_rows, n, a_cols;
  int a_ell_num_columns;
  int a_ell_blocksize;
  int a_base;
  float alpha;
  float beta;

  //
  // Methods
  //

  Options():
    help(false),
    reference_check(true),
    iterations(20),
    cuda_streams(0),
    a_rows(1024),
    n(1024),
    a_cols(1024),
    a_ell_num_columns(512),
    a_ell_blocksize(16),
    a_base(0),
    alpha(1),
    beta()
  { }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
    cmd.get_cmd_line_argument("beta", beta, 0.0f);
    cmd.get_cmd_line_argument("iterations", iterations, 20);
    cmd.get_cmd_line_argument("streams", cuda_streams, 0);
    cmd.get_cmd_line_argument("reference-check", reference_check, true);

    cmd.get_cmd_line_argument("a_rows", a_rows, 1024);
    cmd.get_cmd_line_argument("n", n, 1024);
    cmd.get_cmd_line_argument("a_cols", a_cols, 1024);

    cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512);
    cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16);
    cmd.get_cmd_line_argument("a_base", a_base, 0);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "43_ell_block_sparse_gemm\n\n"
      << "  This example profiles the performance of an ELL block sparse GEMM kernel.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --a_rows=<int>              Sets the number of rows of the sparse matrix.\n"
      << "  --n=<int>                   Sets the N dimension.\n"
      << "  --a_cols=<int>              Sets the number of columns of the sparse matrix.\n"
      << "  --a_ell_num_columns=<int>   Sets the actual number of columns of the Blocked-Ellpack format.\n"
      << "  --a_ell_blocksize=<int>     Sets the size of the ELL-Block.\n"
      << "  --a_base=<int>              Sets the base index.\n"
      << "  --alpha=<f32>               Epilogue scalar alpha (real part)\n"
      << "  --beta=<f32>                Epilogue scalar beta (real part)\n\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n"
      << "  --reference-check=<bool>    If true, performs reference check.\n";

    out << "\n\nExamples:\n\n"

      << "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and an actual 512 non-zero columns in the A operand\n"
      << "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds
    int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Gemm>
class Testbed {
public:

  //
  // Type definitions
  //

  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementC = typename Gemm::ElementC;
  using ElementAccumulator = typename Gemm::ElementAccumulator;

  using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
  using ElementCompute = typename EpilogueOutputOp::ElementCompute;

  using LayoutA = typename Gemm::LayoutA;
  using LayoutB = typename Gemm::LayoutB;
  using LayoutC = typename Gemm::LayoutC;

  using MatrixCoord = typename LayoutC::TensorCoord;

private:

  //
  // Data members
  //

  Options options;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_ELL;
  uint32_t seed;

  cutlass::HostTensor<ElementA, LayoutA> tensor_a;
  cutlass::HostTensor<ElementB, LayoutB> tensor_b;
  cutlass::HostTensor<ElementC, LayoutC> tensor_c;
  cutlass::HostTensor<ElementC, LayoutC> tensor_d;

  cutlass::HostTensor<ElementA, LayoutA> tensor_a_uncompressed;
  cutlass::HostTensor<ElementC, LayoutC> reference_d;

  cutlass::HostTensor<int32_t, LayoutA> tensor_ell_idx;

public:

  //
  // Methods
  //

  Testbed(
    Options const &options_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_ELL_ = cutlass::Distribution::Uniform,
    uint32_t seed_ = 3080
  ):
    options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { }

private:

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  void initialize_tensor_(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint32_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      Element scope_max, scope_min;
      int bits_input = cutlass::sizeof_bits<Element>::value;
      int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;

      if (bits_input == 1) {
        scope_max = 2;
        scope_min = 0;
      } else if (bits_input <= 8) {
        scope_max = 2;
        scope_min = -2;
      } else if (bits_output == 16) {
        if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
          scope_max = 5;
          scope_min = -5;
        }
        else {
          scope_max = 8;
          scope_min = -8;
        }
      } else {
        scope_max = 8;
        scope_min = -8;
      }

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, scope_max, scope_min, 0);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(
        view, seed, Element(), Element(0.5f));
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      // Fill with increasing elements
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity(), Element(1), Element());
    } else {

      // Fill with all 1s
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity(), Element(), Element(1));
    }
  }

  /// Initializes data structures
  void initialize_() {
    tensor_a.resize(cutlass::make_Coord(options.a_rows, options.a_ell_num_columns));
    tensor_b.resize(cutlass::make_Coord(options.a_cols, options.n));
    tensor_c.resize(cutlass::make_Coord(options.a_rows, options.n));
    tensor_d.resize(cutlass::make_Coord(options.a_rows, options.n));

    tensor_a_uncompressed.resize(cutlass::make_Coord(options.a_rows, options.a_cols));
    reference_d.resize(cutlass::make_Coord(options.a_rows, options.n));

    tensor_ell_idx.resize(cutlass::make_Coord(options.a_rows / options.a_ell_blocksize,
                                              options.a_ell_num_columns / options.a_ell_blocksize));

    //
    // Initialize the problems of the workspace
    //

    initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021);
    initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022);
    initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023);

    if (init_ELL == cutlass::Distribution::Uniform) {
      cutlass::reference::host::TensorFillRandomEllIdx(
        tensor_ell_idx.host_view(), seed,
        options.a_rows / options.a_ell_blocksize,
        options.a_ell_num_columns / options.a_ell_blocksize,
        options.a_cols / options.a_ell_blocksize);

    } else {
      for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) {
        for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) {
          tensor_ell_idx.at({i, j}) = j+3;
        }
      }
    }

    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();
    tensor_ell_idx.sync_device();
  }

  /// Verifies the result against a host reference GEMM on the uncompressed A matrix
  bool verify_() {

    bool passed = true;

    tensor_d.sync_host();

    cutlass::uncompress_ell_block_sparse(
      tensor_a_uncompressed.host_ref(),
      tensor_a.host_ref(),
      tensor_ell_idx.host_ref(),
      options.a_rows,
      options.a_cols,
      options.a_ell_num_columns,
      options.a_ell_blocksize
    );

    cutlass::reference::host::Gemm<
        typename Gemm::ElementA, typename Gemm::LayoutA,
        typename Gemm::ElementB, typename Gemm::LayoutB,
        typename Gemm::ElementC, typename Gemm::LayoutC,
        ElementCompute,
        ElementAccumulator, typename Gemm::Operator>
        reference_gemm;

    reference_gemm(
      {options.a_rows, options.n, options.a_cols},
      options.alpha,
      tensor_a_uncompressed.host_ref(),
      tensor_b.host_ref(),
      options.beta,
      reference_d.host_ref(),
      ElementAccumulator(0)
    );

    // Reference check
    passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view());

    if (!passed) {
      std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl;

      std::stringstream fname;

      fname << "error_43_ell_block_sparse_gemm"
            << "_mnk_"
            << options.a_rows << "x"
            << options.n << "x"
            << options.a_cols << "_"
            << options.a_ell_num_columns << "_"
            << options.a_ell_blocksize << ".txt";

      std::cout << fname.str() << std::endl;

      std::ofstream results(fname.str());

      results
        << "alpha: " << ElementCompute(options.alpha) << "\n"
        << "beta: " << ElementCompute(options.beta) << "\n"
        << "block size: " << options.a_ell_blocksize << "\n"
        << "\nA:\n" << tensor_a.host_view() << "\n"
        << "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n"
        << "\nB:\n" << tensor_b.host_view() << "\n"
        << "\nC:\n" << tensor_c.host_view() << "\n"
        << "\nD reference:\n" << reference_d.host_view() << "\n"
        << "\nD computed:\n" << tensor_d.host_view() << "\n";

      return passed;
    }

    return passed;
  }

public:
|
||||
|
||||
/// Returns the number of threadblocks to launch if the kernel can run on the target
|
||||
/// device. Otherwise, returns zero.
|
||||
bool sufficient() const {
|
||||
//
|
||||
// Determine SMEM requirements and waive if not satisfied
|
||||
//
|
||||
|
||||
int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
|
||||
|
||||
cudaDeviceProp properties;
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
throw std::runtime_error("cudaGetDevice() API call failed.");
|
||||
}
|
||||
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
||||
}
|
||||
|
||||
if (properties.sharedMemPerBlockOptin < smem_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Executes a BlockedEll SpMM kernel and measures runtime.
|
||||
Result profile() {
|
||||
|
||||
Result result;
|
||||
|
||||
// Early exit
|
||||
if (!sufficient()) {
|
||||
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS BlockedEll SpMM kernel." << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
result.passed = false;
|
||||
|
||||
// Initialize the problem
|
||||
initialize_();
|
||||
|
||||
// Configure the GEMM arguments
|
||||
typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
|
||||
|
||||
// Configure GEMM arguments
|
||||
typename Gemm::Arguments args(
|
||||
{options.a_rows, options.n, options.a_cols},
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_d.device_ref(),
|
||||
tensor_ell_idx.device_data(),
|
||||
options.a_ell_num_columns,
|
||||
options.a_ell_blocksize,
|
||||
options.a_base,
|
||||
epilogue_op
|
||||
);
|
||||
|
||||
// Initialize the GEMM object
|
||||
Gemm gemm;
|
||||
|
||||
result.status = gemm.initialize(args);
|
||||
|
||||
if (result.status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize CUTLASS BlockedEll SpMM kernel." << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Run the BlockedEll SpMM object
|
||||
result.status = gemm.run();
|
||||
|
||||
if (result.status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for completion
|
||||
result.error = cudaDeviceSynchronize();
|
||||
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error);
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify correctness
|
||||
//
|
||||
result.passed = true;
|
||||
|
||||
if (options.reference_check) {
|
||||
result.passed = verify_();
|
||||
}
|
||||
|
||||
//
|
||||
// Warm-up run
|
||||
//
|
||||
result.status = gemm.run();
|
||||
|
||||
if (result.status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMM operations
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Run profiling loop
|
||||
//
|
||||
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
gemm();
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMM operations have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
//
|
||||
// Cleanup
|
||||
//
|
||||
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
std::cout << "ELL Block Sparse GEMM (CUTLASS):\n"
|
||||
<< "====================================================" << std::endl;
|
||||
|
||||
std::cout << std::endl;
|
||||
std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl;
|
||||
std::cout << " " << " GFLOPs: " << result.gflops << std::endl;
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
//
|
||||
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
|
||||
//
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
|
||||
|
||||
//
|
||||
// This example requires an NVIDIA Ampere-architecture GPU.
|
||||
//
|
||||
|
||||
std::cout
|
||||
<< "CUTLASS's BlockedEll SpMM example requires a GPU of NVIDIA's Ampere Architecture or "
|
||||
<< "later (compute capability 80 or greater).\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Define the BlockedEll type
|
||||
//
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
|
||||
constexpr int32_t kAlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
constexpr int32_t kAlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
constexpr int32_t kStages = 4;
|
||||
using Gemm = typename cutlass::gemm::device::EllGemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementOutput,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
|
||||
kStages, kAlignmentA, kAlignmentB>;
|
||||
|
||||
//
|
||||
// Profile it
|
||||
//
|
||||
|
||||
Testbed<Gemm> testbed(options);
|
||||
|
||||
if (!testbed.sufficient()) {
|
||||
std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
Result result = testbed.profile();
|
||||
if (!result.passed) {
|
||||
std::cout << "Profiling CUTLASS ELL block sparse GEMM has failed.\n";
|
||||
std::cout << "\nFailed\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout << "\nPassed\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
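The testbed above verifies the device result by first expanding the Blocked-Ellpack operand A into a dense matrix with `cutlass::uncompress_ell_block_sparse` and then running a host reference GEMM against it. As a rough mental model of that expansion, here is a minimal host-side sketch; the function name, value type, and the use of a negative index to mark an empty block are illustrative assumptions, not CUTLASS's exact interface:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Expand a Blocked-Ellpack matrix into dense storage (row-major).
// `a_compressed` holds rows x ell_cols values; `ell_idx` holds, for each
// (block_row, block_col) tile of the compressed matrix, the block-column
// index that tile occupies in the dense matrix.
std::vector<float> ell_to_dense(std::vector<float> const &a_compressed,
                                std::vector<int32_t> const &ell_idx,
                                int rows, int cols, int ell_cols, int bs) {
  std::vector<float> dense(size_t(rows) * cols, 0.0f);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < ell_cols; ++c) {
      int dense_block_col = ell_idx[(r / bs) * (ell_cols / bs) + (c / bs)];
      int dense_col = dense_block_col * bs + (c % bs);
      if (dense_col >= 0 && dense_col < cols) {  // a negative index marks an empty block here
        dense[size_t(r) * cols + dense_col] = a_compressed[size_t(r) * ell_cols + c];
      }
    }
  }
  return dense;
}

int main() {
  // 2x4 dense matrix stored as a 2x2 compressed tile with blocksize 2:
  // the single stored block per block-row lives at dense block-column 1.
  std::vector<float> a = {1, 2, 3, 4};  // 2 rows x 2 compressed columns
  std::vector<int32_t> idx = {1};       // 1 block-row x 1 block-column
  auto dense = ell_to_dense(a, idx, 2, 4, 2, 2);
  for (int r = 0; r < 2; ++r, std::printf("\n"))
    for (int c = 0; c < 4; ++c) std::printf("%4.0f", dense[r * 4 + c]);
  return 0;
}
```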
63 examples/44_multi_gemm_ir_and_codegen/README.md Normal file
@ -0,0 +1,63 @@
This example provides utilities for generating back-to-back (B2B) GEMMs using CUTLASS.

## Quick start

A configuration file containing the GEMMs to be fused together is located in [config.json](config.json). Edit
this to change the configuration that you would like to run.

```shell
cd ir_gen

# Set up basic variables
out_dir=directory_to_emit_files
cutlass_dir=$(pwd)/../../..
config_file=$(pwd)/../config.json

# Generate code for the GEMMs described in `config_file`
./generate.sh $config_file $out_dir $cutlass_dir

# Build the generated code
cd $out_dir
mkdir build && cd build
cmake .. -DGPU_ARCHS="75;80"
make -j

# Run the generated code with M=1024, K0=32, and Batch=1
./sample 1024 32 1
```

## Current restrictions

This experimental example has the following restrictions:

1. The N tile should not exceed 256; otherwise, register spilling will occur.
2. Only FP16 is currently supported.
3. Matrix A must be row major, matrix B must be column major, and matrices C and D must be row major.

## Copyright

Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
32 examples/44_multi_gemm_ir_and_codegen/config.json Normal file
@ -0,0 +1,32 @@
{
  "0": {
    "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16",
    "A_format": "Row", "B_format": "Col", "C_format": "Row",
    "mnk": [15000, 256, 32],
    "epilogue": {
      "tp": "LeakyRelu",
      "bias": {"addbias": false, "bias_tp": "mat"},
      "args": [["float", "leaky_alpha", 1.3]]
    }
  },
  "1": {
    "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16",
    "A_format": "Row", "B_format": "Col", "C_format": "Row",
    "mnk": [15000, 128, 256],
    "epilogue": {
      "tp": "LeakyRelu",
      "bias": {"addbias": false, "bias_tp": "mat"},
      "args": [["float", "leaky_alpha", 1.3]]
    }
  },
  "2": {
    "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16",
    "A_format": "Row", "B_format": "Col", "C_format": "Row",
    "mnk": [15000, 64, 128],
    "epilogue": {
      "tp": "LeakyRelu",
      "bias": {"addbias": false, "bias_tp": "mat"},
      "args": [["float", "leaky_alpha", 1.3]]
    }
  }
}
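One property worth noting in the configuration above: the three GEMMs form a chain, so the output of each GEMM becomes the A operand of the next. A quick standalone sanity check of that composability rule (plain C++, independent of the generator itself):

```cpp
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  // mnk triples copied from config.json above
  std::array<std::array<int, 3>, 3> mnk = {{{15000, 256, 32},
                                            {15000, 128, 256},
                                            {15000, 64, 128}}};
  for (int i = 0; i + 1 < static_cast<int>(mnk.size()); ++i) {
    assert(mnk[i][0] == mnk[i + 1][0]);  // M is shared by every GEMM in the chain
    assert(mnk[i][1] == mnk[i + 1][2]);  // N of GEMM i must equal K of GEMM i+1
  }
  std::printf("chain is composable\n");
  return 0;
}
```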
@ -0,0 +1,154 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD 3-Clause license text as reproduced in the README above.)
 **************************************************************************************************/

/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"

#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"

// #include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"

#include "fused_bias_act_epilogue.h"
#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h"
#include "output_tile_thread_map_for_fused_bias.h"
#include "default_thread_map_tensor_op_for_fused_bias.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultFusedBiasActEpilogueTensorOp {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;
  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput
  >;

  // Complex-valued accumulators use the stock complex fragment iterator;
  // all other element types use the fused bias/activation iterator.
  using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
                                    cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        LayoutC>,
                                    cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp<
                                        typename WarpMmaTensorOp::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::Shape,
                                        typename WarpMmaTensorOp::Policy::Operator::ElementC,
                                        typename WarpMmaTensorOp::Policy::Operator::FragmentC,
                                        LayoutC> >::type;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    OutputOp
  >;
};

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
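The `AccumulatorFragmentIterator` above is selected with `std::conditional`: complex accumulators fall back to the stock complex fragment iterator, while everything else gets the fused bias/activation iterator. A stripped-down illustration of that dispatch pattern, with toy types standing in for the CUTLASS iterators and a hand-rolled stand-in for `cutlass::is_complex`:

```cpp
#include <complex>
#include <cstdio>
#include <type_traits>

struct ComplexIterator {};
struct FusedIterator {};

// stand-in for cutlass::is_complex<T>
template <typename T> struct is_complex_like : std::false_type {};
template <typename T> struct is_complex_like<std::complex<T>> : std::true_type {};

template <typename ElementOutput>
using PickIterator = typename std::conditional<
    is_complex_like<ElementOutput>::value,
    ComplexIterator,
    FusedIterator>::type;

int main() {
  std::printf("float          -> fused iterator:   %d\n",
              int(std::is_same<PickIterator<float>, FusedIterator>::value));
  std::printf("complex<float> -> complex iterator: %d\n",
              int(std::is_same<PickIterator<std::complex<float>>, ComplexIterator>::value));
  return 0;
}
```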
@ -0,0 +1,113 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD 3-Clause license text as reproduced in the README above.)
 **************************************************************************************************/

/*! \file
  \brief Defines the optimal thread map for TensorOp accumulator layouts in the fused
    bias/activation epilogue.
*/

#pragma once

#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/pitch_linear.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines the optimal thread map for TensorOp accumulator layouts
template <
  typename ThreadblockShape_,
  typename WarpShape_,
  int PartitionsK,
  typename Element_,
  int ElementsPerAccess
>
struct DefaultThreadMapTensorOpForFusedBias {

  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using Element = Element_;
  static int const kElementsPerAccess = ElementsPerAccess;

  //
  // Definitions
  //

  struct Detail {

    /// Tensor Operations fundamentally perform operations on 8 rows
    static int const kTensorOpRows = 8;
    static int const kWarpSize = 32;

    static_assert(
      !(ThreadblockShape::kM % WarpShape::kM) &&
      !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");

    /// Number of warps
    using WarpCount = gemm::GemmShape<
      ThreadblockShape::kM / WarpShape::kM,
      ThreadblockShape::kN / WarpShape::kN,
      kPartitionsK
    >;

    /// Number of participating threads
    static int const kThreads = WarpCount::kCount * kWarpSize;
  };

  //
  // ThreadMap
  //

  /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
  using Type = OutputTileOptimalThreadMapBiasAct <
    OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
    OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
    Detail::kThreads,
    kElementsPerAccess,
    sizeof_bits<Element>::value
  >;
};

///////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
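For a concrete feel of `Detail::kThreads`, here is the same arithmetic redone with plain integers, using the 128x128 threadblock tile and 64x64 warp tile that appear elsewhere in this commit. This is a minimal sketch only; the variable names are illustrative:

```cpp
#include <cstdio>

int main() {
  constexpr int kTbM = 128, kTbN = 128;  // ThreadblockShape
  constexpr int kWarpM = 64, kWarpN = 64;  // WarpShape
  constexpr int kPartitionsK = 1;
  constexpr int kWarpSize = 32;

  constexpr int kWarpCountM = kTbM / kWarpM;                                      // 2
  constexpr int kWarpCountN = kTbN / kWarpN;                                      // 2
  constexpr int kThreads = kWarpCountM * kWarpCountN * kPartitionsK * kWarpSize;  // 128

  std::printf("warp grid = %d x %d, participating threads = %d\n",
              kWarpCountM, kWarpCountN, kThreads);
  return 0;
}
```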
@ -0,0 +1,222 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD 3-Clause license text as reproduced in the README above.)
 **************************************************************************************************/

/*! \file
  \brief Fused bias/activation epilogue for threadblock scoped GEMMs using Tensor Ops.

  Unlike the standard epilogue, this operator applies the output operator to the accumulator
  tile entirely in registers and stores the result into a second accumulator tile; there is
  no round trip through shared memory.
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator without split-K
template <
  typename Shape_,                       ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,             ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                       ///< Number of partitions of the K dimension
  typename OutputTileIterator_,          ///< Tile iterator reading and writing output tensors
  typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
  typename OutputOp_                     ///< Output operator
>
class FusedBiasActEpilogue {

public:

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using OutputOp = OutputOp_;

  /// Output layout is always row-major
  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

public:

  static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
    "Divisibility");

public:

  /// Constructor
  CUTLASS_DEVICE
  FusedBiasActEpilogue() { }

  /// Applies the output operator, reading the source (bias) tile only when it is needed
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                    ///< Output operator
    AccumulatorTile &accumulators,                ///< Complete warp-level accumulator tile
    AccumulatorTile &fused_bias_act_accumulators, ///< Destination accumulator tile for the fused result
    OutputTileIterator source_iterator) {         ///< Tile iterator reading the source (bias) tile

    bool need_bias = output_op.is_source_needed();

    if (need_bias)
      compute_source_needed_(output_op, accumulators, fused_bias_act_accumulators, source_iterator);
    else
      compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumulators);
  }

  /// Applies the output operator when no source tile is needed
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                      ///< Output operator
    AccumulatorTile &accumulators,                  ///< Complete warp-level accumulator tile
    AccumulatorTile &fused_bias_act_accumulators) { ///< Destination accumulator tile for the fused result

    compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumulators);
  }

  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,                    ///< Output operator
    AccumulatorTile &accumulators,                ///< Complete warp-level accumulator tile
    AccumulatorTile &fused_bias_act_accumulators, ///< Destination accumulator tile for the fused result
    OutputTileIterator source_iterator) {         ///< Tile iterator reading the source (bias) tile

    typename OutputTileIterator::Fragment source_fragment;

    source_fragment.clear();

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
    AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumulators);

    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      source_iterator.load(source_fragment);
      ++source_iterator;

      typename AccumulatorFragmentIterator::Fragment accum_fragment;

      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
      fused_bias_act_fragment = output_op(accum_fragment, source_fragment);

      fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
      ++fused_bias_act_fragment_iterator;
    }
  }

  CUTLASS_DEVICE
  void compute_source_no_needed_(
    OutputOp const &output_op,                      ///< Output operator
    AccumulatorTile &accumulators,                  ///< Complete warp-level accumulator tile
    AccumulatorTile &fused_bias_act_accumulators) { ///< Destination accumulator tile for the fused result

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
    AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumulators);

    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) {

      typename AccumulatorFragmentIterator::Fragment accum_fragment;

      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
      fused_bias_act_fragment = output_op(accum_fragment);

      fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
      ++fused_bias_act_fragment_iterator;
    }
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
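The two `compute_*` methods above are just a fragment-at-a-time map over the accumulator tile. A host-only analogue over flat arrays, with LeakyReLU standing in for the output operator (matching the `leaky_alpha` epilogue argument in config.json); sizes and names here are illustrative only:

```cpp
#include <array>
#include <cstdio>

int main() {
  constexpr int kFragment = 4;    // elements per fragment (stand-in for Fragment::kElements)
  constexpr int kIterations = 4;  // stand-in for AccumulatorFragmentIterator::kIterations
  std::array<float, kFragment * kIterations> accum{}, fused{};
  for (int i = 0; i < kFragment * kIterations; ++i) accum[i] = float(i) - 8.0f;

  const float leaky_alpha = 1.3f;  // same epilogue argument as in config.json
  for (int iter = 0; iter < kIterations; ++iter) {  // outer loop mirrors compute_source_no_needed_
    for (int e = 0; e < kFragment; ++e) {
      float x = accum[iter * kFragment + e];        // accum_fragment_iterator.load()
      fused[iter * kFragment + e] =                 // output_op(accum_fragment), then store()
          x > 0.0f ? x : leaky_alpha * x;
    }
  }
  std::printf("fused[0] = %.1f, fused[15] = %.1f\n", fused[0], fused[15]);
  return 0;
}
```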
@ -0,0 +1,311 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD 3-Clause license text as reproduced in the README above.)
 **************************************************************************************************/

/*! \file
  \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/fast_math.h"

#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// RowArrangement determines how one or more warps cover a region of consecutive rows.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize,
  bool Is2dTile
>
struct RowArrangementBiasAct;

/// RowArrangement in which each warp's access is a 1D tiled arrangement.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
  static int const kWarpSize = 32;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  static int const kIterationsRow = 1;
  static int const kDeltaRow = 1;
  static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
  static int const kDeltaColumn = kWarpSize * kElementsPerAccess;

  static int const kAccessWidth = kWarpSize;
  static int const kAccessRows = 1;
  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = WarpsRemaining;
};

/// RowArrangement in which each warp's access is a 2D tiled arrangement.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {

  static int const kMemoryAccessSize = 4;  // in bytes; reduced from the usual 128
  static int const kWarpSize = 32;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  struct Detail {
    static int const kShapeRow = Shape::kRow / WarpsRemaining;
    static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;

    static int const kTargetMemoryAccessWidth =
      kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);

    static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
  };

  static int const kAccessWidth =
    (Detail::kTargetAccessRows > Detail::kShapeRow ?
      kWarpSize / Detail::kShapeRow
      : const_min(
          Detail::kShapeWidth,
          const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
        ));

  static int const kAccessRows =
    (Detail::kTargetAccessRows > Detail::kShapeRow ?
      Detail::kShapeRow
      : const_min(Shape::kRow, kWarpSize / kAccessWidth));

  static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
  static int const kDeltaRow = kAccessRows;

  static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
  static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;

  static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
  static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
  static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );

  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = 1;
};

} // namespace detail

////////////////////////////////////////////////////////////////////////////////

/// Template metaprogram for partitioning a 4D space across warps to achieve several performance
/// objectives:
///
///   - coalesced memory accesses in units of 16 Byte lines
///   - minimal address arithmetic
///   - minimal predicate calculations
///
template <
  typename Shape_,
  typename Count_,
  int Threads,
  int ElementsPerAccess,
  int ElementSize
>
struct OutputTileOptimalThreadMapBiasAct {

  using Shape = Shape_;
  using Count = Count_;

  static int const kWarpSize = 32;
  static int const kThreads = Threads;
  static int const kWarpCount = kThreads / kWarpSize;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  //
  // Metaprogram computation
  //

  struct Detail {

    // Clusters
    static int const kIterationsCluster =
      ((Shape::kCluster > kWarpCount) ?
        Shape::kCluster / kWarpCount
        : 1);

    static int const kDeltaCluster =
      ((Shape::kCluster > kWarpCount) ?
        Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
        : 1);

    static int const kCompactedDeltaCluster =
      ((Shape::kCluster > kWarpCount) ?
        Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
        : 1);

    static int const kWarpPartitionsCluster =
      ((Shape::kCluster > kWarpCount) ?
        kWarpCount
        : kWarpCount / Shape::kCluster);

    static int const kWarpsRemainingForGroups =
      ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);

    // Groups
    static int const kIterationsGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        Shape::kGroup / kWarpsRemainingForGroups
        : 1);

    static int const kDeltaGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
        : 1);

    static int const kCompactedDeltaGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        Shape::kRow * Shape::kGroup / kIterationsGroup
        : 1);

    static int const kWarpPartitionsGroup =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        1
        : kWarpsRemainingForGroups / Shape::kGroup);

    static int const kWarpsRemainingForRows =
      ((Shape::kGroup > kWarpsRemainingForGroups) ?
        1
        : kWarpsRemainingForGroups / Shape::kGroup);

    // Rows
    using RowArrangement = detail::RowArrangementBiasAct<
      Shape,
      kWarpsRemainingForRows,
      kElementsPerAccess,
      kElementSize,
      (Shape::kRow > kWarpsRemainingForRows)
    >;

    // Warp partitions
    using WarpPartitions = OutputTileShape<
      RowArrangement::kWarpPartitionsColumn,
      RowArrangement::kWarpPartitionsRow,
      kWarpPartitionsGroup,
      kWarpPartitionsCluster,
      1>;

    static int const kAccessWidth = RowArrangement::kAccessWidth;
    static int const kAccessRows = RowArrangement::kAccessRows;
  };

  //
  // Output
  //

  using Iterations = OutputTileShape<
    Detail::RowArrangement::kIterationsColumn,
    Detail::RowArrangement::kIterationsRow,
    Detail::kIterationsGroup,
    Detail::kIterationsCluster,
    1>;

  using Delta = OutputTileShape<
    Detail::RowArrangement::kDeltaColumn,
    Detail::RowArrangement::kDeltaRow,
    Detail::kDeltaGroup,
    Detail::kDeltaCluster,
    1>;

  /// Initial offset function
  CUTLASS_HOST_DEVICE
  static MatrixCoord initial_offset(int thread_idx) {

    int warp_idx = thread_idx / kWarpSize;
    int lane_idx = thread_idx % kWarpSize;

    // Compute warp location
    int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
    int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;

    int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
    int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;

    int row_idx = residual_group / Detail::WarpPartitions::kRow;
    int col_idx = residual_group % Detail::WarpPartitions::kRow;

    // Compute per-lane offset
    int lane_row_offset = lane_idx / Detail::kAccessWidth;
    int lane_col_offset = lane_idx % Detail::kAccessWidth;

    // Compute coordinate in output space
    int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
    int group_offset = group_idx * Shape::kRow * Count::kRow;
    int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
    int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;

    return MatrixCoord(
      cluster_offset + group_offset + row_offset + lane_row_offset,
      (column_offset + lane_col_offset) * kElementsPerAccess
    );
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
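`initial_offset()` decomposes a thread index into warp and lane coordinates, then scales them into the output tile. The per-lane half of that decomposition in isolation, with illustrative constants (the `kAccessWidth` and `kElementsPerAccess` values below are assumptions for demonstration, not derived from a real tile shape):

```cpp
#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  constexpr int kAccessWidth = 8;        // assumed Detail::kAccessWidth
  constexpr int kElementsPerAccess = 8;  // assumed vector width per access

  for (int lane_idx = 0; lane_idx < kWarpSize; lane_idx += 7) {
    int lane_row_offset = lane_idx / kAccessWidth;  // which row of the access footprint
    int lane_col_offset = lane_idx % kAccessWidth;  // which vector within that row
    std::printf("lane %2d -> row +%d, column +%d\n",
                lane_idx, lane_row_offset, lane_col_offset * kElementsPerAccess);
  }
  return 0;
}
```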
@ -0,0 +1,189 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD 3-Clause license text as reproduced in the README above.)
 **************************************************************************************************/

/*! \file
  \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile
    that participate in one warp-level store operation.

    Typically, the accumulator tile is the largest single block of register-backed storage
    within the kernel. Storing it to memory is best accomplished by partitioning it into
    smaller tiles and storing these sequentially.

    Round trips through shared memory during the Epilogue phase require partitioning, as
    shared memory capacity is typically insufficient for a threadblock's total accumulator
    size.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"

#include "cutlass/epilogue/warp/tensor_op_policy.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

/// Fragment iterator for the fused bias/activation epilogue
template <
  typename WarpShape,         ///< shape of warp-level GEMM (concept: MatrixShape)
  typename OperatorShape,     ///< matrix multiply operation shape (concept: gemm::GemmShape)
  typename OperatorElementC,  ///< matrix multiply operation data type (concept: data type)
  typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array)
  typename Layout             ///< target shared memory layout
>
class FusedBiasActFragmentIteratorTensorOp;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for row-major shared memory
template <
  typename WarpShape_,         ///< shape of the warp-level GEMM tile
  typename OperatorShape_,     ///< matrix multiply operation shape (concept: gemm::GemmShape)
  typename OperatorElementC_,  ///< matrix multiply operation data type (concept: data type)
  typename OperatorFragmentC_  ///< matrix multiply operation fragment (concept: Array)
>
class FusedBiasActFragmentIteratorTensorOp<WarpShape_, OperatorShape_, OperatorElementC_, OperatorFragmentC_, layout::RowMajor> {
public:

  using WarpShape = WarpShape_;
  using OperatorShape = OperatorShape_;
  using OperatorElementC = OperatorElementC_;
  using OperatorFragmentC = OperatorFragmentC_;
  using Layout = layout::RowMajor;

  using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;

  /// This is the fragment size produced by one access of the iterator.
  using Fragment = Array<
    OperatorElementC,
    Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;

  /// This is the complete warp-level accumulator tile.
  using AccumulatorTile = Array<
    OperatorElementC,
    OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>;

  using OutputAccumulatorTile = AccumulatorTile;

  /// Number of times this iterator can be incremented
  static int const kIterations = Policy::kIterations;

private:

  /// Internal access type
  using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;

private:

  //
  // Data members
  //

  /// Accumulator tile
  AccessType *accumulators_;

  /// Internal index
  int index_;

public:

  /// Constructs an iterator
  CUTLASS_HOST_DEVICE
  FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum):
    accumulators_(reinterpret_cast<AccessType *>(&accum)),
    index_(0) {
  }

  /// Increments
  CUTLASS_HOST_DEVICE
  FusedBiasActFragmentIteratorTensorOp &operator++() {
    ++index_;
    return *this;
  }

  /// Decrements
  CUTLASS_HOST_DEVICE
  FusedBiasActFragmentIteratorTensorOp &operator--() {
    --index_;
    return *this;
  }

  /// Loads a fragment from the referenced part of the accumulator tile
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag, int index_offset = 0) const {

    int index = index_ + index_offset;

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {

      int accumulator_access_offset =
        index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;

      frag_ptr[n] = accumulators_[accumulator_access_offset];
    }
  }

  /// Stores a fragment to the referenced part of the accumulator tile
  CUTLASS_HOST_DEVICE
  void store(Fragment &frag, int index_offset = 0) const {

    int index = index_ + index_offset;

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {

      int accumulator_access_offset =
        index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;

      accumulators_[accumulator_access_offset] = frag_ptr[n];
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
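The offset arithmetic in `load()`/`store()` above advances through the accumulator array in units of `AccessType`, striding by one column of MMA tiles per step of `n`. A host-side sketch of the same indexing with small illustrative constants (every value below is assumed for demonstration):

```cpp
#include <cstdio>

int main() {
  constexpr int kElementsPerAccess = 2;        // elements per AccessType
  constexpr int kOperatorCountColumn = 4;      // stand-in for Policy::OperatorCount::kColumn
  constexpr int kAccumulatorColumnStride = 8;  // elements between adjacent MMA tile columns

  int index = 1;  // iterator position after one increment
  for (int n = 0; n < kOperatorCountColumn; ++n) {
    // offset is in units of AccessType, exactly as in load()/store() above
    int access_offset = index + n * kAccumulatorColumnStride / kElementsPerAccess;
    std::printf("fragment slot %d <-> accumulator elements [%d..%d]\n",
                n, access_offset * kElementsPerAccess,
                access_offset * kElementsPerAccess + kElementsPerAccess - 1);
  }
  return 0;
}
```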
@ -0,0 +1,427 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_conversion.h"

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Size of the accumulation tile shape (concept: MatrixShape)
  typename AccumulatorShape_,
  /// KBlocks columns to compute residual
  int KBlocksColumn_,
  /// Accumulator Element type
  typename ElementAccumulator_,
  /// Element type
  typename Element_,
  /// Layout of operand in memory
  typename Layout_,
  /// Shape of one matrix product operation (concept: MatrixShape)
  typename InstructionShape_,
  /// Whether beta is zero
  bool IsBetaZero_ >
class MmaTensorOpPureFragmentIterator;
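
// Fragment iterator that walks a warp-scope accumulator tile and hands it out
// one Shape-sized fragment per access; the KBlocksColumn parameter lets a
// residual (partial) K block be visited first. (Summary inferred from the
// specializations below.)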

// Partial specialization for a col-major accumulator tile,
// where the Element type matches the accumulator element type

template <
  /// Shape of warp tile to load (concept: MatrixShape)
  typename Shape_,
  /// Shape of the warp accumulation tile (concept: MatrixShape)
  typename AccumulatorShape_,
  /// KBlocks columns to compute residual
  int KBlocksColumn_,
  /// Element type
  typename Element_,
  /// Shape of one matrix product operation (concept: MatrixShape)
  typename InstructionShape_>
class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
                                      cutlass::layout::ColumnMajor,
                                      InstructionShape_, true> {
public:

  /// Shape of warp tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Shape of the warp accumulation tile (concept: MatrixShape)
  using AccumulatorShape = AccumulatorShape_;

  /// KBlocks columns to compute residual
  static int const kKBlockColumn = KBlocksColumn_;

  /// Element type
  using Element = Element_;

  /// Layout of source tile
  using Layout = cutlass::layout::ColumnMajor;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;

  /// Whether beta is zero
  static bool const IsBetaZero = true;

  /// Number of participating threads
  static int const kThreads = 32;

  /// Internal structure of iterator - made public to enable introspection
  struct Policy {
    static_assert(
        !(Shape::kRow % InstructionShape::kM) &&
            !(Shape::kColumn % InstructionShape::kN),
        "Shape of warp-level Mma must be divisible by operator shape.");
    static_assert(
        !(AccumulatorShape::kRow % Shape::kRow) &&
            !(AccumulatorShape::kColumn % Shape::kColumn),
        "Shape of Warp Accumulator must be divisible by warp shape.");
    static_assert(
        !(kKBlockColumn % Shape::kColumn),
        "KBlock size must be divisible by warp shape.");

    /// Number of times this iterator can be incremented
    static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
  };

private:

  static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;

  /// Number of mma operations performed by a warp
  using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
                                    Shape::kColumn / InstructionShape::kN>;
  /// Number of mma operations performed by the entire accumulator
  using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
                                            AccumulatorShape::kColumn / InstructionShape::kN>;

  /// Number of K iterations
  static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
  static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
  static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
                                             * (AccumulatorShape::kRow / Shape::kRow);
  static int const kResidualIndex = kResidualColumn / Shape::kColumn
                                    * (AccumulatorShape::kRow / Shape::kRow);
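
  // Worked example with illustrative sizes: AccumulatorShape 64x128, Shape
  // 64x64 and kKBlockColumn 64 give kKBlockIterations = 2, kResidualColumn =
  // 64, kKBlockColumnIterations = 1 and kResidualIndex = 1, so add_offset()
  // rebases the index once the residual K block has been consumed.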

public:

  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  /// This is the fragment size produced by one access of the iterator.
  using Fragment = Array<Element, Shape::kCount / kThreads>;

  /// Accumulator Fragment object
  using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;

private:

  /// Internal access type
  using AccessType = Array<Element, kElementsPerAccess>;

private:
  //
  // Data members
  //

  /// Accumulator tile
  AccessType const *accumulators_;

  /// Internal index
  int index_;

  /// Used to access residual tile first
  bool is_residual_tile_;

public:
  /// Constructs an iterator
  CUTLASS_HOST_DEVICE
  MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
      : accumulators_(reinterpret_cast<AccessType const *>(&accum)),
        index_(0), is_residual_tile_(true) {}

  /// Add offset
  CUTLASS_HOST_DEVICE
  void add_offset(int index_offset) {
    index_ += index_offset;
    if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
      index_ = index_ - kKBlockColumnIterations + kResidualIndex;
      is_residual_tile_ = false;
    }
  }

  /// Increments
  CUTLASS_HOST_DEVICE
  MmaTensorOpPureFragmentIterator &operator++() {
    add_offset(1);
    return *this;
  }

  /// Decrements
  CUTLASS_HOST_DEVICE
  MmaTensorOpPureFragmentIterator &operator--() {
    add_offset(-1);
    return *this;
  }

  /// Loads a fragment from the referenced part of the accumulator tile
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {

    AccessType src_fragment;
    src_fragment.clear();

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
    int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
                  * MmaIterations::kColumn;
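
    // index_m/index_n walk the accumulator column-major: e.g. with
    // MmaIterations 2x2 over AccumulatorIterations 4x4 (illustrative sizes),
    // index_ = 0,1 covers the two row blocks of the first warp-tile column
    // before index_n advances to the next column.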

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < MmaIterations::kColumn; n++) {
      for (int m = 0; m < MmaIterations::kRow; m++) {
        int accumulator_access_offset =
            (n + index_n) * AccumulatorIterations::kRow + m + index_m;

        frag_ptr[n * MmaIterations::kRow + m].clear();
        if(!(is_residual_tile_ && index_ >= kResidualIndex))
          frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset];
        // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment);
      }
    }
  }

};

// Partial specialization for a row-major accumulator tile

template <
  /// Shape of warp tile to load (concept: MatrixShape)
  typename Shape_,
  /// Shape of the warp accumulation tile (concept: MatrixShape)
  typename AccumulatorShape_,
  /// KBlocks columns to compute residual
  int KBlocksColumn_,
  /// Accumulator Element type
  typename ElementAccumulator_,
  /// Element type
  typename Element_,
  /// Shape of one matrix product operation (concept: MatrixShape)
  typename InstructionShape_>
class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
                                      cutlass::layout::RowMajor,
                                      InstructionShape_, true> {
public:

  /// Shape of warp tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Shape of the warp accumulation tile (concept: MatrixShape)
  using AccumulatorShape = AccumulatorShape_;

  /// KBlocks columns to compute residual
  static int const kKBlockColumn = KBlocksColumn_;

  /// Accumulator Element type
  using ElementAccumulator = ElementAccumulator_;

  /// Element type
  using Element = Element_;

  /// Layout of source tile
  using Layout = cutlass::layout::RowMajor;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;

  /// Whether beta is zero
  static bool const IsBetaZero = true;

  /// Number of participating threads
  static int const kThreads = 32;

  /// Internal structure of iterator - made public to enable introspection
  struct Policy {
    static_assert(
        !(Shape::kRow % InstructionShape::kM) &&
            !(Shape::kColumn % InstructionShape::kN),
        "Shape of warp-level Mma must be divisible by operator shape.");
    static_assert(
        !(AccumulatorShape::kRow % Shape::kRow) &&
            !(AccumulatorShape::kColumn % Shape::kColumn),
        "Shape of Warp Accumulator must be divisible by warp shape.");
    static_assert(
        !(kKBlockColumn % Shape::kColumn),
        "KBlock size must be divisible by warp shape.");

    /// Number of times this iterator can be incremented
    static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
  };

private:

  static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;

  /// Number of mma operations performed by a warp
  using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
                                    Shape::kColumn / InstructionShape::kN>;
  /// Number of mma operations performed by the entire accumulator
  using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
                                            AccumulatorShape::kColumn / InstructionShape::kN>;

  /// Number of K iterations
  static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
  static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
  static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
                                             * (AccumulatorShape::kRow / Shape::kRow);
  static int const kResidualIndex = kResidualColumn / Shape::kColumn
                                    * (AccumulatorShape::kRow / Shape::kRow);

public:

  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  /// This is the fragment size produced by one access of the iterator.
  using Fragment = Array<Element, Shape::kCount / kThreads>;

  /// Accumulator Fragment object
  using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;

private:

  /// Internal access type
  using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentAccessType = Array<Element, kElementsPerAccess>;

private:
  //
  // Data members
  //

  /// Accumulator tile
  AccessType const *accumulators_;

  /// Internal index
  int index_;

  /// Used to access residual tile first
  bool is_residual_tile_;

public:
  /// Constructs an iterator
  CUTLASS_HOST_DEVICE
  MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
      : accumulators_(reinterpret_cast<AccessType const *>(&accum)),
        index_(0), is_residual_tile_(true) {}

  /// Add offset
  CUTLASS_HOST_DEVICE
  void add_offset(int index_offset) {
    index_ += index_offset;
    if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
      index_ = index_ - kKBlockColumnIterations + kResidualIndex;
      is_residual_tile_ = false;
    }
  }

  /// Increments
  CUTLASS_HOST_DEVICE
  MmaTensorOpPureFragmentIterator &operator++() {
    add_offset(1);
    return *this;
  }

  /// Decrements
  CUTLASS_HOST_DEVICE
  MmaTensorOpPureFragmentIterator &operator--() {
    add_offset(-1);
    return *this;
  }

  /// Loads a fragment from the referenced part of the accumulator tile
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {

    FragmentAccessType src_fragment;
    src_fragment.clear();

    FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);

    int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
    int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
                  * MmaIterations::kColumn;

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < MmaIterations::kRow; m++) {
      for (int n = 0; n < MmaIterations::kColumn; n++) {
        int accumulator_access_offset =
            (m + index_m) * AccumulatorIterations::kColumn + n + index_n;

        frag_ptr[m * MmaIterations::kColumn + n].clear();
        if(!(is_residual_tile_ && index_ >= kResidualIndex))
          frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]);
      }
    }
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
129
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py
Normal file
@ -0,0 +1,129 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import gen_turing_and_volta as api_generator
import gen_sample as sample_creater
import gen_cmake as cmake_creater
import gen_verify as verify_creater
import gen_device as b2b_fused_generator
import replace_fix_impl_header

import argparse
import os
import json


parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels")
parser.add_argument("--config-file", default="config.json", help="JSON file containing the configuration to generate")
parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specifies the output name")
parser.add_argument("--output-dir", default="", help="Specifies the output directory")
parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo directory")
parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include directory, if needed.")
args = parser.parse_args()
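
# Example invocation (paths are illustrative; config.json must describe the
# GEMMs to fuse, as loaded below):
#   python gen_all_code.py --config-file config.json \
#       --output-dir ./generated --cutlass-dir /path/to/cutlass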

gen_name = args.gen_name

cutlass_deps_dir = args.cutlass_dir

output_dir = args.output_dir
output_dir += "/"

cutlass_deps_root = args.gen_include_cutlass_dir
if cutlass_deps_root == '':
    cutlass_deps_root = cutlass_deps_dir + "/include/"
cutlass_deps_root += '/'


if not os.path.exists(output_dir):
    os.makedirs(output_dir)

if not os.path.exists(output_dir + "/" + "auto_gen"):
    os.mkdir(output_dir + "/" + "auto_gen")

if not os.path.exists(output_dir + "/" + "fixed_impl"):
    os.mkdir(output_dir + "/" + "fixed_impl")

if not os.path.exists(output_dir + "/" + "sample"):
    os.mkdir(output_dir + "/" + "sample")

if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"):
    os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device")
if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"):
    os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel")
if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"):
    os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock")

with open(args.config_file, 'r') as infile:
    gemm_info_dict = json.load(infile)

keys = sorted(gemm_info_dict.keys())
fuse_gemm_info = [gemm_info_dict[k] for k in keys]


for_cutlass_gen_user_include_header_file = [
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
]

for_fused_wrapper = [
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
    cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
    "auto_gen/device/" + gen_name + ".h",
    cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h",
    cutlass_deps_root + "cutlass/cutlass.h",
]

# Copy the fixed implementation to the output directory
fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir + "/fixed_impl/", cutlass_deps_root)
fix_impl.gen_code()

auto_gen_output_dir = output_dir + "/auto_gen/"
project_root = ""
turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir)
turing_plus.gen_code(75, 'hmma1688', False)

api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
api.gen_code()

# Generate the C++ sample
os.system("cp ../leaky_bias.h " + output_dir + "/sample/")
os.system("cp ../utils.h " + output_dir + "/sample/")

sample_dir = output_dir + "/sample/"
sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir)
sample.gen_cpp_sample()

cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir)
cmake_gen.gen_code()

verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
verify.gen_code()
131
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py
Normal file
@ -0,0 +1,131 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

class gen_build_sys:
    def __init__(self, cutlass_deps_dir, output_dir = "../"):
        self.output_dir = output_dir
        self.cutlass_deps_dir = cutlass_deps_dir

    def gen_top(self):
        code = ""
        code += '''\
# Auto Generated code - Do not edit.

cmake_minimum_required(VERSION 3.8)
project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA)
find_package(CUDAToolkit)
set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}})
set(CUTLASS_PATH \"{cutlass_deps_dir}/include\")
set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\")
list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}})
'''.format(cutlass_deps_dir=self.cutlass_deps_dir)

        code += '''\
set(GPU_ARCHS \"\" CACHE STRING
\"List of GPU architectures (semicolon-separated) to be compiled for.\")

if(\"${GPU_ARCHS}\" STREQUAL \"\")
set(GPU_ARCHS \"70\")
endif()

foreach(arch ${GPU_ARCHS})
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\")
if(SM STREQUAL 70 OR SM STREQUAL 75)
set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\")
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\")
endif()
endforeach()

set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\")
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\")

set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\")
set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\")
set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\")

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if(CMAKE_CXX_STANDARD STREQUAL \"11\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\")
endif()

set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\")

set(COMMON_HEADER_DIRS
${PROJECT_SOURCE_DIR}
${CUDAToolkit_INCLUDE_DIRS}
)

set(COMMON_LIB_DIRS
${CUDAToolkit_LIBRARY_DIR}
)
list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH})
list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH})
'''
        code += '''\
include_directories(
${COMMON_HEADER_DIRS}
)

link_directories(
${COMMON_LIB_DIRS}
)

add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-DGOOGLE_CUDA=1)

add_executable(sample
sample/sample.cu
one_api.cu
)
target_link_libraries(sample PRIVATE
-lcudart
-lnvToolsExt
${CMAKE_THREAD_LIBS_INIT}
)

if(NOT DEFINED LIB_INSTALL_PATH)
set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR})
endif()
'''
        return code

    def gen_code(self):
        top_code = self.gen_top()
        with open(self.output_dir + "CMakeLists.txt", "w") as f:
            f.write(top_code)
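
# Minimal usage sketch (illustrative paths):
#   gen = gen_build_sys("/path/to/cutlass", "./generated/")
#   gen.gen_code()   # writes ./generated/CMakeLists.txt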

@ -0,0 +1,120 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import ast

fuse_gemm_info = [
    {
        'epilogue': {
            'tp': 'LeakyRelu',  # 'CustomizedLeaky_RELU'
            'bias': {'addbias': False, 'bias_tp': 'mat'},
            'args': [('float', 'leaky_alpha', 1.3), ],
            'func': '''
y = max(leaky_alpha * x, x)
y = y * x
'''
        }
    },

]
class AnalysisNodeVisitor(ast.NodeVisitor):
    def visit_Import(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_ImportFrom(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Assign(self, node):
        print('Node type: Assign and fields: ', node._fields)
        # print('Node type: Assign and targets value: ', node.targets, node.value)

        ast.NodeVisitor.generic_visit(self, node)

    def visit_BinOp(self, node):
        print('Node type: BinOp and fields: ', node._fields)
        print('node op: ', type(node.op).__name__)
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Expr(self, node):
        print('Node type: Expr and fields: ', node._fields)
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Num(self, node):
        print('Node type: Num and fields: ', node._fields)
        print('Node type: Num: ', node.n)

    def visit_Name(self, node):
        print('Node type: Name and fields: ', node._fields)
        print('Node type: Name and fields: ', type(node.ctx).__name__, node.id)

        ast.NodeVisitor.generic_visit(self, node)

    def visit_Str(self, node):
        print('Node type: Str and fields: ', node._fields)

class CodeVisitor(ast.NodeVisitor):
    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add):
            node.op = ast.Sub()
        self.generic_visit(node)

    def visit_Assign(self, node):
        print('Assign %s' % node.value)
        self.generic_visit(node)

    def visit_Name(self, node):
        print("Name:", node.id)
        self.generic_visit(node)


    def visit_FunctionDef(self, node):
        print('Function Name: %s' % node.name)
        self.generic_visit(node)
        func_log_stmt = ast.Print(
            dest = None,
            values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)],
            nl = True,
            lineno = 0,
            col_offset = 0,
        )
        node.body.insert(0, func_log_stmt)
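
        # NOTE: ast.Print and ast.Str are Python 2 AST nodes; under Python 3
        # the injected statement would instead be an ast.Expr wrapping an
        # ast.Call of the built-in print. This scratch script keeps the
        # legacy form.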

visitor = AnalysisNodeVisitor()

code = \
'''

a=max(leaky_alpha * x, x +1)

'''

visitor.visit(ast.parse(code))
477
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py
Normal file
@ -0,0 +1,477 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from typing import *

import helper
import gen_ir

import gen_kernel as gen_ker


class gen_device:
    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"):
        self.fuse_gemm_info = fuse_gemm_info
        self.raw_gemm_info = fuse_gemm_info
        self.b2b_num = len(fuse_gemm_info)
        self.user_header_file = user_header_file
        self.args = {}
        # device arg struct member
        self.arg_member = []
        self.gen_class_name = gen_class_name
        self.gen_kernel_name = gen_class_name + "Kernel"
        self.tempalte_args = []
        self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int}

        self.file_name = output_dir + "/device/" + gen_class_name + ".h"
        self.sample_dir = output_dir

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root
        self.this_file_root = output_dir + "/device/"

        self.first_use_1stage = False

        ## gen kernel
        self.gen_kernel = gen_ker.gen_kernel(self.tempalte_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root)


    def __check_arg_type(self, temp_arg):
        if temp_arg in self.__tempalate_arg_list.keys():
            return self.__tempalate_arg_list[temp_arg]

        find_sub = False
        for candidate_arg in self.__tempalate_arg_list.keys():
            if (temp_arg.find(candidate_arg) != -1):
                return self.__tempalate_arg_list[candidate_arg]

        return 'typename'

    # def gen_B2b2bGemm_class():
    def set_arch(self, sm_cap, mma_tp):
        if sm_cap == 75 or sm_cap == 80 or sm_cap == 86:
            self.arch = "cutlass::arch::Sm" + str(sm_cap)

        if mma_tp == 'hmma1688':
            self.mma_shape = [16, 8, 8]
            self.mma_tp = 'hmma'
        elif mma_tp == 'imma8816':
            self.mma_tp = 'imma'
            self.mma_shape = [8, 8, 16]
        else:
            return 0

    def gen_include_header(self):
        code = '''\
/* Auto Generated code - Do not edit.*/

#pragma once

#include \"{cutlass_root}cutlass/cutlass.h\"
#include \"{cutlass_root}cutlass/numeric_types.h\"
#include \"{cutlass_root}cutlass/arch/arch.h\"
#include \"{cutlass_root}cutlass/device_kernel.h\"

#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\"

#include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\"
#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\"
#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\"

#include \"{project_root}../kernel/b2b_gemm.h\"
#include \"{project_root}../kernel/default_b2b_gemm.h\"
'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root)
        include_user_header = ""
        for header in self.user_header_file:
            include_user_header += "#include \"" + header + "\"\n"
        return code + include_user_header

    def gen_code(self, sm_cap, mma_tp, ifprint = True):
        self.set_arch(sm_cap, mma_tp)

        self.update_b2b_args()
        print(self.fuse_gemm_info)
        self.update_b2b_class_template_args()

        func_code = self.gen_all_func()
        member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n"

        gen_code = gen_ir.gen_template_class(self.gen_class_name, self.tempalte_args, func_code + member_var_code)
        code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code)))

        if ifprint:
            print(code)

        print("[INFO]: generated device code is written to", self.file_name)
        with open(self.file_name, 'w+') as f:
            f.write(code)


        gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage)
        print(gen_kernel)

    def update_b2b_class_template_args(self):
        for arg in self.args.keys():
            self.tempalte_args.append([self.__check_arg_type(arg), arg, self.args[arg]])

    def update_b2b_args(self):

        self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp'])
        self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format'])

        cnt = 0

        warp_M_tile = 32

        # Determine the maximum N tile
        Max_Ntile = 0
        for layer in self.fuse_gemm_info:
            n_tile = layer['mnk'][1]
            if n_tile > Max_Ntile:
                Max_Ntile = n_tile
        if Max_Ntile >= 256:
            warp_M_tile = 16

        stages_temp = []

        for layer in self.fuse_gemm_info:
            cnt_str = str(cnt)
            B_tp_str = 'ElementB' + cnt_str
            B_format_str = 'LayoutB' + cnt_str
            C_tp_str = 'ElementC' + cnt_str
            C_format_str = 'LayoutC' + cnt_str
            Acc_str = 'ElementAccumulator' + cnt_str

            self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp'])
            self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format'])
            self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp'])
            self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format'])
            self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp'])


            mnk = layer['mnk'][:]

            tile_mnk = mnk[:]

            tile_mnk[2] = 32  # force the K tile to 32

            # N tile gen
            if mnk[1] > 1024:
                assert(0)
            elif mnk[1] > 512:
                tile_mnk[1] = 1024
            elif mnk[1] > 256:
                tile_mnk[1] = 512
            elif mnk[1] > 128:
                tile_mnk[1] = 256
            elif mnk[1] > 64:
                tile_mnk[1] = 128
            elif mnk[1] > 32:
                tile_mnk[1] = 64
            else:
                tile_mnk[1] = 32
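
            # e.g. an N of 96 lands in the "> 64" bucket and gets tile_mnk[1]
            # = 128: N is rounded up to the next power of two (32..1024) so a
            # single threadblock tile covers the layer's full N.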

            if tile_mnk[1] == 512:
                stages_temp.append(1)
            else:
                stages_temp.append(2)

            tile_mnk[0] = 4 * warp_M_tile


            epilogue_setted_type = helper.get_epilogue_tp(layer)
            cutlass_epilogue_name = "LinearCombinationRelu"
            if epilogue_setted_type.lower() == 'leakyrelu':
                cutlass_epilogue_name = "LinearCombinationLeakyRelu"
            elif epilogue_setted_type.lower() == 'identity':
                cutlass_epilogue_name = "LinearCombination"

            epilogue_str = 'EpilogueOutputOp' + cnt_str
            if cnt != len(self.fuse_gemm_info) - 1:
                n = layer['mnk'][1]
                Fragments = tile_mnk[1] // 8 * 2
                self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<ElementC0_, " + str(Fragments) + ", ElementAccumulator0_, ElementAccumulator0_>"
            else:
                n = layer['mnk'][1]
                n_mod_8 = n % 8
                N_align_elements = 1
                if n_mod_8 == 0:
                    N_align_elements = 8
                elif n_mod_8 == 4:
                    N_align_elements = 4
                elif n_mod_8 == 2 or n_mod_8 == 6:
                    N_align_elements = 2
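
                # Vectorization example: n = 132 gives n % 8 == 4, so the
                # final epilogue stores 4 elements per access; an odd n falls
                # back to scalar (1-element) stores.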

                self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<ElementC0_, " + str(N_align_elements) + ", ElementAccumulator0_, ElementAccumulator0_>"


            ThreadBlockShape_str = 'ThreadblockShape' + cnt_str

            self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)

            WarpShape_str = 'WarpShape' + cnt_str
            tile_mnk[0] = warp_M_tile
            self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
            cnt += 1


        self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp'])
        self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format'])

        self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape)
        self.args['OperatorClass'] = 'arch::OpClassTensorOp'
        self.args['ArchTag'] = self.arch
        self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle'


        for i in range(self.b2b_num):
            self.args[helper.var_idx('Stages', i)] = "2"

        self.args['AlignmentA'] = str(8)
        self.args['AlignmentB'] = str(8)
        self.args['SplitKSerial'] = 'false'
        self.args['Operator'] = 'typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB0_, ElementC0_, ElementAccumulator0_>::Operator'
        self.args['IsBetaZero'] = 'false'


    def gen_using_kernel(self):
        code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n"
        code += " " + "ElementA,\n"
        code += " " + "LayoutA,\n"

        for i in range(self.b2b_num):
            code += " " + helper.var_idx("ElementB", i) + ",\n"
            code += " " + helper.var_idx("LayoutB", i) + ",\n"
            code += " " + helper.var_idx("ElementC", i) + ",\n"
            code += " " + helper.var_idx("LayoutC", i) + ",\n"
            code += " " + helper.var_idx("ElementAccumulator", i) + ",\n"
            code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n"
            code += " " + helper.var_idx("ThreadblockShape", i) + ",\n"
            code += " " + helper.var_idx("WarpShape", i) + ",\n"

        code += " " + "ElementD,\n"
        code += " " + "LayoutD,\n"
        code += " " + "InstructionShape,\n"
        code += " " + "OperatorClass,\n"
        code += " " + "ArchTag,\n"
        code += " " + "ThreadblockSwizzle,\n"

        for i in range(self.b2b_num):
            code += " " + helper.var_idx("Stages", i) + ",\n"


        code += " " + "AlignmentA,\n"
        code += " " + "AlignmentB,\n"
        code += " " + "SplitKSerial,\n"
        code += " " + "Operator,\n"
        code += " " + "IsBetaZero_\n"

        code += ">::B2bGemmKernel;\n\n"

        return code

    def gen_args(self):

        def gen_arg_member(b2b_num):
            data_members = []

            for i in range(b2b_num):
                member_type = "GemmCoord"
                member_name = "problem_size_" + str(i)
                data_members.append((member_type, member_name))

            member_type = "TensorRef<ElementA const, LayoutA>"
            member_name = "ref_A0"
            data_members.append((member_type, member_name))

            for i in range(b2b_num):
                member_type = "TensorRef<ElementB" + str(i) + " const, LayoutB" + str(i) + ">"
                member_name = "ref_B" + str(i)
                data_members.append((member_type, member_name))
                member_type = "TensorRef<ElementC" + str(i) + " const, LayoutC" + str(i) + ">"
                member_name = "ref_C" + str(i)
                data_members.append((member_type, member_name))

            member_type = "TensorRef<ElementD, LayoutD>"
            member_name = helper.var_idx("ref_D", b2b_num - 1)
            data_members.append((member_type, member_name))

            for i in range(b2b_num):
                member_type = "typename EpilogueOutputOp" + str(i) + "::Params"
                member_name = "epilogue" + str(i)
                data_members.append((member_type, member_name))

            data_members.append(('int', 'batch_count'))

            return data_members

        def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value):
            constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
                              gen_ir.indentation + struct_name + " (): "
            for i in range(inital_param_num):
                final_param = ','
                if i == inital_param_num - 1:
                    final_param = '{ }'
                constructs_code += data_members[i][1] + inital_value + final_param

            constructs_code += "\n"
            return constructs_code

        def gen_arg_struct_ctor(struct_name, data_members):
            constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
                              gen_ir.indentation + struct_name + " (\n"
            cnt = 0
            param_num = len(data_members)
            for param in data_members:
                final = ',\n'
                if cnt == param_num - 1:
                    final = '\n):\n'
                constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final
                cnt += 1

            cnt = 0
            for param in data_members:
                final = '),\n'
                if cnt == param_num - 1:
                    final = ") { }\n"
                constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final
                cnt += 1

            constructs_code += "\n"
            return constructs_code
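
        # For b2b_num == 2 the generated Arguments ctor takes, in order:
        # problem_size_0/1, ref_A0, ref_B0, ref_C0, ref_B1, ref_C1, ref_D1,
        # epilogue0, epilogue1, batch_count, each forwarded to the matching
        # member initializer.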

        # (variable type, variable name)
        struct_member = gen_arg_member(self.b2b_num)
        self.arg_member = struct_member

        codeBody = ""
        for each_member in struct_member:
            codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n"

        codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n"
        codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n"
        struct_code = gen_ir.gen_struct("Arguments", codeBody)
        return struct_code

    def gen_func_constructs(self):
        code = self.gen_class_name + "() {}"
        return code

    def gen_func_initialize(self):
        code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \
               "// Determine grid shape\n" + \
               "ThreadblockSwizzle threadblock_swizzle;\n" + \
               "cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \
               " args.problem_size_0, \n" + \
               " { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \
               " args.batch_count);\n" + \
               "// Initialize the Params structure\n" + \
               "params_ = typename B2bGemmKernel::Params{\n"
        for i in range(self.b2b_num):
            code += helper.var_idx(" args.problem_size_", i) + ",\n"
        code += " grid_shape,\n" + \
                " args.ref_A0.non_const_ref(),\n"
        for i in range(self.b2b_num):
            code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n"
            code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n"

        code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n"
        for i in range(self.b2b_num):
            code += helper.var_idx(" args.epilogue", i) + ",\n"

        code += " args.batch_count\n"
        code += "};\n" + \
                "return Status::kSuccess;\n" + \
                "}\n"
        return code

    def gen_func_run(self):
        code = "Status run(cudaStream_t stream = nullptr) {\n" + \
               "\n" + \
               " ThreadblockSwizzle threadblock_swizzle;\n" + \
               "\n" + \
               " dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \
               " dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \
               "\n" + \
               " cudaError_t result;\n" + \
               "\n" + \
               " int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \
               " if (smem_size >= (48 << 10)) {\n" + \
               " result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \
               "\n" + \
               " if (result != cudaSuccess) {\n" + \
               " return Status::kErrorInternal;\n" + \
               " }\n" + \
               "\n" + \
               " result = cudaFuncSetAttribute(\n" + \
               " Kernel<B2bGemmKernel>,\n" + \
               " cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n" + \
               "\n" + \
               " if (result != cudaSuccess) {\n" + \
               " return Status::kErrorInternal;\n" + \
               " }\n" + \
               " }\n" + \
               " cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);\n" + \
               " result = cudaGetLastError();\n" + \
               " return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;\n" + \
               " }\n"

        return code
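
    # The generated run() opts the kernel into >48 KB of dynamic shared
    # memory via cudaFuncAttributeMaxDynamicSharedMemorySize before launch;
    # if the attribute call fails, the wrapper returns Status::kErrorInternal
    # instead of launching.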
    def gen_func_operator(self):
        operator_with_arg_code = "Status operator()(\n" + \
                                 " Arguments const &args,\n" + \
                                 " void *workspace = nullptr,\n" + \
                                 " cudaStream_t stream = nullptr) {\n" + \
                                 " Status status = initialize(args, workspace);\n" + \
                                 " \n" + \
                                 " if (status == Status::kSuccess) {\n" + \
                                 " status = run(stream);\n" + \
                                 " }\n" + \
                                 " return status;\n" + \
                                 "}\n"
        operator_code = "Status operator()(\n" + \
                        " cudaStream_t stream = nullptr) {\n" + \
                        " Status status = run(stream);\n" + \
                        " return status;\n" + \
                        "}\n"
        return operator_with_arg_code + "\n" + operator_code

    def gen_all_func(self):
        return self.gen_using_kernel() + "\n" + \
               self.gen_args() + "\n" + \
               self.gen_func_constructs() + "\n" + \
               self.gen_func_initialize() + "\n" + \
               self.gen_func_run() + "\n" + \
               self.gen_func_operator()
249
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py
Normal file
@ -0,0 +1,249 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import helper


indentation = " "


def append_word(word):
    code = ""
    code += word
    code += " "
    return code


def gen_namespace(namespace, codeBody):
    code_gen = "namespace " + namespace + " {\n"
    code_gen += codeBody
    code_gen += "} // namespace " + namespace + "\n"
    return code_gen
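
# e.g. gen_namespace("cutlass", "struct Foo;\n") produces:
#   namespace cutlass {
#   struct Foo;
#   } // namespace cutlass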
|
||||
|
||||
|
||||
def gen_expression(type, lval, rval = None):
|
||||
code_gen = ""
|
||||
code_gen += append_word(type)
|
||||
code_gen += append_word(lval)
|
||||
if rval is not None:
|
||||
code_gen += append_word("=")
|
||||
code_gen += append_word(rval)
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_class(name, codeBody, inheritance_code = None):
|
||||
code_gen = ""
|
||||
if inheritance_code is None:
|
||||
code_gen = "class " + name + "{\n"
|
||||
else:
|
||||
code_gen = "class " + name + " : "+ inheritance_code + "{\n"
|
||||
code_gen += codeBody
|
||||
code_gen += "}; // class " + name + "\n"
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_struct(name, codeBody, specialized = None):
|
||||
specialized_code = ""
|
||||
if specialized is not None:
|
||||
specialized_code = "<" + specialized + ">"
|
||||
code_gen = "struct " + name + specialized_code + "{\n"
|
||||
code_gen += codeBody
|
||||
code_gen += "}; // struct " + name + "\n"
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_template_arg(arg_type, arg_name, default_val = None):
|
||||
rval = None
|
||||
if default_val is not None:
|
||||
rval = str(default_val)
|
||||
|
||||
arg_typename = ""
|
||||
if arg_type is int:
|
||||
arg_typename = "int"
|
||||
elif arg_type is bool:
|
||||
arg_typename = "bool"
|
||||
else:
|
||||
arg_typename = "typename"
|
||||
|
||||
internal_arg_name = arg_name + "_"
|
||||
|
||||
code_gen = indentation
|
||||
code_gen += gen_expression(arg_typename, internal_arg_name, rval)
|
||||
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_template_args(args, set_default = True):
|
||||
arg_len = len(args)
|
||||
cnt = 1
|
||||
code_gen = ""
|
||||
for arg_tuple in args:
|
||||
arg_type = arg_tuple[0]
|
||||
arg_name = arg_tuple[1]
|
||||
arg_default_val = None
|
||||
if len(arg_tuple) == 3 and set_default:
|
||||
arg_default_val = arg_tuple[2]
|
||||
|
||||
code_gen += gen_template_arg(arg_type, arg_name, arg_default_val)
|
||||
if cnt != arg_len:
|
||||
code_gen += ",\n"
|
||||
cnt += 1
|
||||
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_template_head(args, set_default = True):
|
||||
code_gen = "template <\n"
|
||||
code_gen += gen_template_args(args, set_default)
|
||||
code_gen += ">\n"
|
||||
return code_gen
|
||||
|
||||
|
||||
def export_template_args(args):
|
||||
code_gen = "public:\n"
|
||||
for arg_tuple in args:
|
||||
code_gen += indentation
|
||||
arg_type = arg_tuple[0]
|
||||
arg_name = arg_tuple[1]
|
||||
internal_arg_name = arg_name + "_"
|
||||
|
||||
typename = ""
|
||||
if arg_type is int:
|
||||
typename = "static int const"
|
||||
elif arg_type is bool:
|
||||
typename = "static bool const"
|
||||
else:
|
||||
typename = "using"
|
||||
|
||||
code_gen += gen_expression(typename, arg_name, internal_arg_name)
|
||||
code_gen += ";\n"
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None):
|
||||
code_gen = ""
|
||||
|
||||
code_gen += gen_template_head(args, set_default)
|
||||
code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code)
|
||||
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True):
|
||||
code_gen = ""
|
||||
code_gen += gen_template_head(args, set_default)
|
||||
code = export_template_args(args) + codeBody
|
||||
if export_args is False:
|
||||
code = codeBody
|
||||
code_gen += gen_struct(struct_name, code , speicalized)
|
||||
|
||||
return code_gen
|
||||
|
||||
|
||||
def gen_declare_template_struct(name, *params):
|
||||
code = name + "<"
|
||||
cnt = 0
|
||||
param_num = len(params)
|
||||
for param in params:
|
||||
final = ", "
|
||||
if cnt == param_num - 1:
|
||||
final = ""
|
||||
code += param + final
|
||||
cnt += 1
|
||||
code += ">;\n"
|
||||
return code
|
||||
|
||||
|
||||
def filtered_param(params, name_and_value_pair, keep_ = False):
|
||||
rtn_template_args = []
|
||||
speicalized_template_args = []
|
||||
|
||||
for param in params:
|
||||
param_name = ""
|
||||
if len(param) >= 1:
|
||||
param_name = param[1]
|
||||
else:
|
||||
param_name = param[0]
|
||||
|
||||
hit_flag = False
|
||||
set_value = ""
|
||||
for n_v_pair in name_and_value_pair:
|
||||
|
||||
filter_name = n_v_pair[0]
|
||||
set_value = n_v_pair[1]
|
||||
|
||||
if param_name == (filter_name + "_") or param_name == filter_name :
|
||||
hit_flag = True
|
||||
break
|
||||
|
||||
|
||||
if hit_flag is False:
|
||||
rtn_template_args.append(param)
|
||||
|
||||
if hit_flag is True:
|
||||
speicalized_template_args.append(set_value)
|
||||
else:
|
||||
if keep_ is True:
|
||||
speicalized_template_args.append(param_name + "_")
|
||||
else:
|
||||
speicalized_template_args.append(param_name)
|
||||
|
||||
|
||||
specialized_template_arg_str = helper.list_2_string(speicalized_template_args)
|
||||
|
||||
return rtn_template_args, specialized_template_arg_str
|
||||
|
||||
|
||||
def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True):
    code = "void " + func_name + "(\n"
    for arg in arg_lists:
        arg_tp = arg[0]
        arg_nm = arg[1]
        code += " " + arg_tp + " " + arg_nm + ",\n"
    code += "cudaStream_t stream)"
    if only_declare:
        return code
    code += "{\n"
    code += code_body + "\n"
    code += "}\n"
    return code

def indent_level(code, level = 0):
    rtn_code = ""
    for i in range(level):
        rtn_code += " "

    rtn_code += code

    return rtn_code

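# A minimal sketch of how the helpers in this file compose (names are
# illustrative; gen_namespace and gen_struct are defined earlier in this file):
def _demo_gen_ir():
    args = [("typename", "ElementA"), (int, "Stages", 2)]
    struct_text = gen_template_struct("MyOp", args, " // body\n")
    decl_text = gen_declare_template_struct("MyOp", "float", "2")  # "MyOp<float, 2>;\n"
    return gen_namespace("cutlass", struct_text + decl_text)
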
476
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py
Normal file
@ -0,0 +1,476 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import gen_ir
import helper
import gen_threadblock as gen_tb


class gen_default_Gemm:
    def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
        self.gen_class_name = "B2bGemm"
        self.template_param = template_param
        self.b2b_num = b2b_num

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

    def gen_B2bMma(self, specialized_template_args):
        code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n"
        code += specialized_template_args
        code += ">::ThreadblockB2bMma;\n"

        return code

    def gen_epilogue(self):
        epilogue_code = ""
        epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n"

        epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n"
        epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n"
        epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n"
        epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n"
        epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n"
        epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n"
        epilogue_code += ">::Epilogue;\n"

        epilogue_code += "using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;\n\n"

        return epilogue_code

    def gen_include_header(self):
        code = '''
/* Auto Generated code - Do not edit.*/

#pragma once
#include \"{cutlass_dir}cutlass/cutlass.h\"

#include \"{cutlass_dir}cutlass/layout/matrix.h\"
#include \"{cutlass_dir}cutlass/numeric_types.h\"

#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\"
#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\"

#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\"

#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\"

#include \"../kernel/b2b_gemm.h\"
#include \"../threadblock/default_b2b_mma.h\"
'''.format(cutlass_dir=self.cutlass_deps_root)
        return code

    def gen_code(self):
        # Generate the unspecialized default template struct
        gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param, "", specialized = None, set_default=False)

        filter_list = []
        filter_list.append(('Stages', 2))
        filter_list.append(("OperatorClass", "arch::OpClassTensorOp"))
        filter_list.append(("ArchTag", "arch::Sm75"))

        for i in range(self.b2b_num):
            filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor"))

        rtn_template_args, specialized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True)

        B2bMma_code = self.gen_B2bMma(specialized_template_args)
        epilogue_and_rest_code = self.gen_epilogue()

        gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, specialized = specialized_template_args, set_default=False)

        code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code)))

        return self.gen_include_header() + code

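# Hedged usage sketch of gen_default_Gemm (this template_param layout is
# illustrative; real drivers derive it from the fused-GEMM description):
def _demo_default_gemm():
    template_param = [("typename", "ElementA"), ("typename", "LayoutA")]
    gen = gen_default_Gemm(template_param, "B2bGemm", 2, "../../../", "./")
    return gen.gen_code()  # text of default_b2b_gemm.h
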
class gen_Kernel:
    def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
        self.gen_class_name = "B2bGemm"
        self.template_param = template_param
        self.b2bnum = b2b_num

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

    def gen_include_header(self):
        code = '''
#pragma once

#include \"{cutlass_dir}cutlass/cutlass.h\"
#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root)
        return code

    def gen_Params(self):
        gen_param = ""
        for i in range(self.b2bnum):
            gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n"
        gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n"
        gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n"
        gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n"

        for i in range(self.b2bnum):
            gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n"
            gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n"
            if i == self.b2bnum - 1:
                gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n"
                gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n"
            else:
                gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n"
                gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n"

        gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n"
        gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n"

        for i in range(self.b2bnum):
            gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n"

        gen_param += " " + 'int batch_count' + ";\n"
        gen_param += " " + 'int gemm_k_iterations_0' + ";\n"

        return gen_param

    def gen_Memberfunc(self):
        code_default = "\nCUTLASS_HOST_DEVICE\n"
        code_default += "Params()"
        code_default += " { } \n\n"

        code_construct = "\nCUTLASS_HOST_DEVICE\n"
        code_construct += "Params(\n"

        for i in range(self.b2bnum):
            code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n"

        code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n"

        code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n"

        for i in range(self.b2bnum):
            code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n"
            if i == self.b2bnum - 1:
                code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n"
            else:
                code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n"

        code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n"
        for i in range(self.b2bnum):
            code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n"

        code_construct += " " + "int batch_count = 1\n"

        code_construct += "):\n"

        for i in range(self.b2bnum):
            code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n"

        code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n"
        code_construct += " " + "params_A0(ref_A0.layout()),\n"
        code_construct += " " + "ref_A0(ref_A0),\n"

        for i in range(self.b2bnum):
            code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n"
            code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n"
            code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n"
            code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n"

        code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n"
        code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n"

        for i in range(self.b2bnum):
            code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n"

        code_construct += " " + "batch_count(batch_count) {\n"
        code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n"

        code_construct += "}\n"

        return code_default + code_construct

    def gen_using(self):
        code_using = ""

        for i in range(self.b2bnum - 1):
            code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n"

        code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n"

        for i in range(self.b2bnum - 1):
            code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) + ";\n"

        code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n"
        code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n"

        code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc())

        code_using += "union SharedStorage {\n"
        code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n"
        code_using += " " + "typename Epilogue::SharedStorage epilogue;\n"
        code_using += "};\n"

        return code_using

    def gen_can_implement(self):
        gen_code = ""
        return gen_code

    def gen_operator_and_constr(self):
        ctr_code = "CUTLASS_HOST_DEVICE\n"
        ctr_code += self.gen_class_name + "() { } \n\n"
        operator_code = "CUTLASS_DEVICE\n"
        operator_code += "void operator()(Params const &params, SharedStorage &shared_storage) {\n"
        operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n"
        operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
        operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n"
        operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n"
        operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n"
        operator_code += " " + " " + "return;\n"
        operator_code += " " + "}\n"

        operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n"
        operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n"
        operator_code += " " + " " + "0\n"
        operator_code += " " + "};\n"

        for i in range(self.b2bnum):
            operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n"
            operator_code += " " + " " + "0,\n"
            operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n"
            operator_code += " " + "};\n"

        operator_code += " " + "int thread_idx = threadIdx.x;\n\n"

        operator_code += " " + "MatrixCoord threadblock_offset(\n"
        operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n"
        operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n"
        operator_code += " " + ");\n"

        operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n"
        operator_code += " " + " " + "params.params_A0,\n"
        operator_code += " " + " " + "params.ref_A0.data(),\n"
        operator_code += " " + " " + "params.problem_size_0.mk(),\n"
        operator_code += " " + " " + "thread_idx,\n"
        operator_code += " " + " " + "tb_offset_A0);\n"

        operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n"

        for i in range(self.b2bnum):
            operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx(" iterator_B", i) + "(\n"
            operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n"
            operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n"
            operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n"
            operator_code += " " + " " + "thread_idx,\n"
            operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n"
            operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n"

        for i in range(self.b2bnum - 1):
            operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n"
            operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n"
            operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n"
            operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".mn(),\n"
            operator_code += " " + " " + "thread_idx,\n"
            operator_code += " " + " " + "threadblock_offset" + ");\n"
            operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n"
            operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", i) + ".m()));\n\n"

        for i in range(self.b2bnum - 1):
            operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i) + helper.var_idx(" epilogue_", i) + ";\n"

        # broadcast lane 0's warp id to every lane of the warp
        operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
        operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"

        for i in range(self.b2bnum - 1):
            operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n"

        operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n"

        operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n"
        operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1) + " accumulators;\n"

        operator_code += " " + "src_accum.clear();\n"
        operator_code += " " + "accumulators.clear();\n"
        operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, "

        for i in range(self.b2bnum):
            operator_code += helper.var_idx("iterator_B", i) + ", "

        operator_code += "src_accum"
        if self.b2bnum != 1:
            operator_code += ", "
        for i in range(self.b2bnum - 1):
            operator_code += helper.var_idx("output_op_", i) + ", "

        for i in range(self.b2bnum - 1):
            operator_code += helper.var_idx("epilogue_", i) + ", "

        for i in range(self.b2bnum - 1):
            final = ", "
            if i == self.b2bnum - 2:
                final = ""
            operator_code += helper.var_idx("iterator_C", i) + final
        operator_code += ");\n"

        operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n"
        operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"

        operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n"
        operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n"
        operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n"
        operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
        operator_code += " " + " " + "thread_idx,\n"
        operator_code += " " + " " + "threadblock_offset\n"
        operator_code += " " + ");\n"
        operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n"

        operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n"

        operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n"
        operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n"
        operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n"
        operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
        operator_code += " " + " " + "thread_idx,\n"
        operator_code += " " + " " + "threadblock_offset\n"
        operator_code += " " + ");\n"
        operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n"

        operator_code += " " + "Epilogue epilogue(\n"
        operator_code += " " + " " + "shared_storage.epilogue,\n"
        operator_code += " " + " " + "thread_idx,\n"
        operator_code += " " + " " + "warp_idx,\n"
        operator_code += " " + " " + "lane_idx\n"
        operator_code += " " + ");\n"

        operator_code += " " + "epilogue("
        operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", "
        operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", "
        operator_code += "accumulators, "
        operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n"
        operator_code += "}\n"

        return ctr_code + operator_code

    def gen_include_header(self):
        code = '''
#pragma once

#include \"{cutlass_dir}cutlass/cutlass.h\"

#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/matrix_coord.h\"
#include \"{cutlass_dir}cutlass/semaphore.h\"
'''.format(cutlass_dir=self.cutlass_deps_root)
        return code

    def gen_code(self):
        template_param = []
        template_param.append(("typename", "B2bMma"))
        template_param.append(("typename", "Epilogue"))
        template_param.append(("typename", "ThreadblockSwizzle"))
        template_param.append((bool, "SplitKSerial"))

        code_body = ""
        code_body += self.gen_using()
        code_body += self.gen_operator_and_constr()

        struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body)
        code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code)))

        return self.gen_include_header() + code

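# Smoke-test sketch of the kernel emitter above (the constructor's
# template_param is unused by gen_code, so an empty list suffices here):
def _demo_gen_Kernel():
    k = gen_Kernel([], "B2bGemm", 2, "../../../", "./")
    return k.gen_code()  # text of b2b_gemm.h
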
class gen_kernel:
    def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root):
        self.template_param = template_param

        self.gen_class_name = "B2bGemm"
        self.gen_kernel_name = gen_class_name + "Kernel"
        self.template_args = []

        self.cutlass_deps_root = cutlass_deps_root
        self.project_root = project_root

        self.gen_default_b2b_gemm = gen_default_Gemm(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
        self.gen_Kernel = gen_Kernel(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)

        # Include gen_threadblock
        self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root)

        self.file_dir = output_dir + "/kernel/"

    def gen_code(self, first_use_1stage):
        default_b2b_gemm = self.gen_default_b2b_gemm.gen_code()

        print("[INFO]: Gen kernel code [default_b2b_gemm.h], output dir:", self.file_dir)

        with open(self.file_dir + "default_b2b_gemm.h", "w+") as f:
            f.write(default_b2b_gemm)

        kernel = self.gen_Kernel.gen_code()
        print("[INFO]: Gen kernel code [b2b_gemm.h], output dir:", self.file_dir)

        with open(self.file_dir + "b2b_gemm.h", "w+") as f:
            f.write(kernel)

        # Generate the threadblock-level code as well
        self.gen_threadBlock.gen_code(first_use_1stage)

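# End-to-end sketch (paths and b2b_num are illustrative assumptions):
def _demo_gen_kernel(template_param):
    kgen = gen_kernel(template_param, "B2bGemm", 2, "./out", "../../../", "./")
    kgen.gen_code(first_use_1stage = False)  # writes out/kernel/*.h plus threadblock code
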
232
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py
Normal file
@ -0,0 +1,232 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import helper
import gen_ir as ir


class gen_test:
    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name
        self.user_header_file = user_header_file
        self.sample_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

    def gen_cpp_sample(self):
        code = "/* Auto Generated code - Do not edit.*/\n"
        code += "#include <stdio.h> \n"

        code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
        code += "#include \"cutlass/cutlass.h\" \n"

        code += "#include \"../cutlass_irrelevant.h\" \n"
        code += "#include \"../cutlass_verify.h\" \n"

        code += "#include \"leaky_bias.h\" \n"

        code += "#include \"utils.h\" \n"

        code += "int main(int args, char * argv[]) {\n"
        code += " " + "int M = atoi(argv[1]);\n"
        # Default K0 is the K dimension of the first GEMM; argv[2] / argv[3]
        # optionally override K0 and the batch count.
        code += " " + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][2]) + ";\n"
        code += " " + "if (args >= 3)\n"
        code += " " + " " + "K0 = atoi(argv[2]);\n"
        code += " " + "int B = 1;\n"
        code += " " + "if (args >= 4)\n"
        code += " " + " " + "B = atoi(argv[3]);\n"

        code += " " + "srand(1234UL);\n"
        code += " " + "int device_id = 0;\n"
        code += " " + "cudaGetDevice(&device_id);\n"
        code += " " + "cudaDeviceProp prop;\n"
        code += " " + "cudaGetDeviceProperties(&prop, device_id);\n"
        code += " " + "int sm = prop.major * 10 + prop.minor;\n"
        code += " " + "using ElementCompute = cutlass::half_t;\n"

        for i in range(self.b2b_num):
            code += " " + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n"
            addbias = helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i])
            if addbias:
                code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n"
            else:
                code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n"

        code += " " + "size_t flops = 0;\n"

        for i in range(self.b2b_num):
            m = self.fuse_gemm_info[i]['mnk'][0]
            n = self.fuse_gemm_info[i]['mnk'][1]
            k = self.fuse_gemm_info[i]['mnk'][2]

            bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])

            this_k = "K0"
            if (i > 0):
                this_k = str(k)

            code += " " + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n"

            code += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n"

            code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n"
            code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n"
            code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n"
            code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n"

            code += " " + helper.var_idx("Mat_A", i) + ".init();\n"
            code += " " + helper.var_idx("Mat_B", i) + ".init();\n"
            code += " " + helper.var_idx("Mat_C", i) + ".init();\n"

        code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_D", self.b2b_num - 1) + helper.var_idx("(B * problem_size_", self.b2b_num - 1) + helper.var_idx(".m() * problem_size_", self.b2b_num - 1) + ".n());\n"

        params = []
        params.append("M")
        params.append("B")

        params.append("Mat_A0.device_ptr")
        for i in range(self.b2b_num):
            params.append(helper.var_idx("Mat_B", i) + ".device_ptr")
            params.append(helper.var_idx("Mat_C", i) + ".device_ptr")
            if i != self.b2b_num - 1:
                params.append(helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr")
        params.append(helper.var_idx("Mat_D", self.b2b_num - 1) + ".device_ptr")

code += " " + "Param arguments = {\n"
|
||||
code += " " + " " + "M,\n"
|
||||
code += " " + " " + "K0,\n"
|
||||
code += " " + " " + "B,\n"
|
||||
|
||||
code += " " + " " + "reinterpret_cast<const void*>(Mat_A0.device_ptr),\n"
|
||||
cnt = 1
|
||||
for i in range(self.b2b_num):
|
||||
bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i])
|
||||
code += " " + " " + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n"
|
||||
cnt += 1
|
||||
if bias_flag:
|
||||
code += " " + " " + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n"
|
||||
cnt += 1
|
||||
else:
|
||||
code += " " + " " + "reinterpret_cast<const void*>(NULL),\n"
|
||||
|
||||
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
|
||||
acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
|
||||
for arg in epilogue_args:
|
||||
arg_value = str(arg[2])
|
||||
|
||||
code += " " + " " + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n"
|
||||
|
||||
if i != self.b2b_num - 1:
|
||||
code += " " + " " + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n"
|
||||
else:
|
||||
code += " " + " " + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n"
|
||||
|
||||
|
||||
|
||||
|
||||
code += " " + "TI(FUSED_CUTLASS);\n"
|
||||
code += " " + "for(int i = 0; i < 100; i++){\n"
|
||||
code += " " + " " + "one_api(arguments, sm, NULL);\n"
|
||||
|
||||
code += " " + "}\n"
|
||||
code += " " + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n"
|
||||
|
||||
code += "\n"
|
||||
|
||||
        for i in range(self.b2b_num):
            code_this = ""

            N_str = str(self.fuse_gemm_info[i]['mnk'][1])

            code_this += " " + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
            code_this += " " + " " + helper.var_idx("problem_size_", i) + ",\n"
            ldmA = str(self.fuse_gemm_info[i]['mnk'][2])
            if i == 0:
                ldmA = "K0"
            ldmB = str(self.fuse_gemm_info[i]['mnk'][2])
            if i == 0:
                ldmB = "K0"
            ldmC = str(self.fuse_gemm_info[i]['mnk'][1])

            ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))

            # compare layout strings by value ('is' would test object identity)
            if self.fuse_gemm_info[i]['A_format'] == 'Col':
                ldmA = "M"
            if self.fuse_gemm_info[i]['B_format'] == 'Row':
                ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
            if self.fuse_gemm_info[i]['C_format'] == 'Col':
                ldmC = "M"

            if i == 0:
                code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"
            else:
                code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"

            code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"

            M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0])

            code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
            code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n"
            code_this += " " + " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                arg_value = str(epilogue_arg[2])
                code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_value) + ")"
            code_this += " " + " },\n"
            code_this += " " + " " + "B};\n"

            code += code_this

code += " " + "TI(UNFUSED_CUTLASS);\n"
|
||||
code += " " + "for(int i = 0; i < 100; i++){\n"
|
||||
code += " " + " " + self.gen_class_name + "_verify(\n"
|
||||
for i in range(self.b2b_num):
|
||||
code += " " + " " + " " + helper.var_idx("arguments_", i) + ",\n"
|
||||
code += " " + " " + " " + "NULL);\n"
|
||||
|
||||
code += " " + "}\n"
|
||||
code += " " + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n"
|
||||
|
||||
code += " " + helper.var_idx("Mat_D_cutlass_ref", self.b2b_num - 1) + ".d2h();\n"
|
||||
code += " " + helper.var_idx("Mat_D", self.b2b_num - 1) + ".d2h();\n"
|
||||
code += " " + helper.var_idx("check_result(Mat_D_cutlass_ref", self.b2b_num - 1) + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) \
|
||||
+ helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) + ".elements);\n"
|
||||
|
||||
code += "\n\n}\n"
|
||||
|
||||
with open(self.sample_dir + "sample.cu", "w+") as f:
|
||||
f.write(code)
|
||||
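# Hedged usage sketch; fuse_gemm_info entries are dicts carrying at least the
# 'mnk', layout-format and epilogue keys read above:
def _demo_gen_test(fuse_gemm_info):
    tester = gen_test(fuse_gemm_info, "one_api", ["api.h"], output_dir = "./out/")
    tester.gen_cpp_sample()  # writes out/sample.cu with fused vs. unfused timing
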
1013
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py
Normal file
File diff suppressed because it is too large
456
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py
Normal file
@ -0,0 +1,456 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import helper
import gen_ir as ir


class gen_turing_impl:
    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        self.fuse_gemm_info = fuse_gemm_info
        self.class_name = gen_class_name
        self.gen_class_name = gen_class_name + "_turing_impl"
        self.user_header_file = ""
        for header in user_header_file:
            self.user_header_file += "#include \"" + header + "\"\n"
        self.output_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

        self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)

    def gen_using(self):
        code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + "<cutlass::half_t>;"

        return code_using + "\n"

    def gen_initialize(self):
        code = ""
        for i in range(self.b2b_num):
            code_this = ""

            code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n"
            beta = "(1)"

            if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False:
                beta = "(0)"
            code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n"
            k_str = str(self.fuse_gemm_info[i]['mnk'][2])
            if i == 0:
                k_str = "K0"
            code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n"
            code += code_this
        code += "typename b2b_gemm::Arguments arguments{\n"

        for i in range(self.b2b_num):
            code += " " + helper.var_idx("problem_size_", i) + ",\n"

        # A0 belongs to the first GEMM, so its element type comes from layer 0
        code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n"

        for i in range(self.b2b_num):
            ldmB = str(self.fuse_gemm_info[i]['mnk'][2])
            if i == 0:
                ldmB = "K0"

            if self.fuse_gemm_info[i]['B_format'] == 'Row':
                ldmB = str(self.fuse_gemm_info[i]['mnk'][1])

            ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))

            code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n"
            code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n"
        code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num - 1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n"

        for i in range(self.b2b_num):
            code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
                code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")"
            code += "},\n"
        code += " " + "Batch};\n\n"

        code += " " + "b2b_gemm gemm_op;\n"
        code += " " + "gemm_op.initialize(arguments);\n"
        return code + "\n"

    def gen_run(self):
        code = " " + "gemm_op(stream);\n"

        return code

    def gen_wrapper(self):
        code_body = ""

        arg_lists = []
        arg_lists.append(["int", "M"])
        arg_lists.append(["int", "K0"])
        arg_lists.append(["int", "Batch"])
        arg_lists.append(["void*", helper.var_idx("A", 0)])
        for i in range(self.b2b_num):
            arg_lists.append(["void*", helper.var_idx("B", i)])
            arg_lists.append(["void*", helper.var_idx("C", i)])
            arg_lists.append(["void*", helper.var_idx("D", i)])
            epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
            acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
            for arg in epilogue_args:
                arg_tp = arg[0]
                arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
                arg_lists.append([arg_tp, arg_name])

        if self.b2b_num == 1:
            code_body += self.gen_turing_unfused.gen_using(False)  # False -> Turing, True -> Volta
            code_body += self.gen_turing_unfused.gen_initialize()
            code_body += self.gen_turing_unfused.gen_run()
        else:
            code_body += self.gen_using()
            code_body += self.gen_initialize()
            code_body += self.gen_run()

        code = ir.gen_func(self.gen_class_name, arg_lists, code_body)

        return code

    def gen_code(self):
        code = self.gen_wrapper()
        helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code)


class gen_volta_turing_fuse_act_impl:
    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name + "_volta_impl"
        self.user_header_file = ""
        for header in user_header_file:
            self.user_header_file += "#include \"" + header + "\"\n"
        self.output_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

    def perf_tiling(self, layer_mnk):
        mnk = layer_mnk[:]
        block_tile = mnk[:]
        block_tile[2] = 32  # force the K tile to be 32

        # M tile gen
        block_tile[0] = 32

        # N tile gen
        if mnk[1] > 128:
            block_tile[1] = 256
        elif mnk[1] > 64:
            block_tile[1] = 128
        elif mnk[1] > 32:
            block_tile[1] = 64
        else:
            block_tile[1] = 32

        warp_tile = block_tile[:]
        if block_tile[1] == 256:
            warp_tile[1] = 64
        elif block_tile[1] == 128:
            warp_tile[1] = 32
        elif block_tile[1] == 64:
            warp_tile[1] = 32
        else:
            warp_tile[1] = 32

        warp_tile[0] = 32

        return block_tile, warp_tile

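    # For example, perf_tiling([1024, 256, 64]) -> ([32, 256, 32], [32, 64, 32]):
    # N = 256 selects a 256-wide block tile and a 64-wide warp tile, while the
    # M and K tiles are pinned to 32.
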
    def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp):
        epilogue_setted_type = epilogue_tp
        cutlass_epilogue_name = "LinearCombinationRelu"
        if epilogue_setted_type.lower() == 'leakyrelu':
            cutlass_epilogue_name = "LinearCombinationLeakyRelu"
        elif epilogue_setted_type.lower() == 'identity':
            cutlass_epilogue_name = "LinearCombination"

        # epilogue vector width, chosen from the alignment of N modulo 8
        n_mod_8 = n % 8
        N_align_elements = 1
        if n_mod_8 == 0:
            N_align_elements = 8
        elif n_mod_8 == 4:
            N_align_elements = 4
        elif n_mod_8 == 2 or n_mod_8 == 6:
            N_align_elements = 2

        epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">"

        return epilogue_str

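    # e.g. process_epilogue('leakyrelu', 64, "cutlass::half_t", "float") yields
    # "cutlass::epilogue::thread::LinearCombinationLeakyRelu<cutlass::half_t, 8, float, float>"
    # since 64 % 8 == 0 gives an 8-element vector width.
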
    def gen_using(self, volta = True):
        code_using = ""
        volta_arch = "cutlass::arch::Sm70"
        volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>"

        turing_arch = "cutlass::arch::Sm75"
        turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>"

        if volta:
            arch = volta_arch
            tc = volta_tc
        else:
            arch = turing_arch
            tc = turing_tc

        for i in range(self.b2b_num):
            k = self.fuse_gemm_info[i]['mnk'][2]

            # A/B alignment from K modulo 8, mirroring the epilogue alignment rule
            k_mod_8 = k % 8
            ab_ldm = 1
            if k_mod_8 == 0:
                ab_ldm = 8
            elif k_mod_8 == 4:
                ab_ldm = 4
            elif k_mod_8 == 2 or k_mod_8 == 6:
                ab_ldm = 2

            block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk'])

            this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n"
            this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n"
            this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n"
            this_gemm_config += " " + arch + ",\n"
            this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n"
            this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n"
            this_gemm_config += " " + tc + ",\n"
            this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n"
            this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n"
            this_gemm_config += " " + "2,\n"
            this_gemm_config += " " + str(ab_ldm) + ",\n"
            this_gemm_config += " " + str(ab_ldm) + ">;\n"

            code_using += this_gemm_config + "\n"

        return code_using + "\n"

    def gen_initialize(self):
        code = ""
        for i in range(self.b2b_num):
            code_this = ""

            N_str = str(self.fuse_gemm_info[i]['mnk'][1])

            code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n"
            beta = "(1)"
            if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False:
                beta = "(0)"
            code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n"

            k_str = str(self.fuse_gemm_info[i]['mnk'][2])
            if i == 0:
                k_str = "K0"
            code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n"
            code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
            code_this += " " + helper.var_idx("problem_size_", i) + ",\n"
            ldmA = k_str
            ldmB = k_str
            ldmC = str(self.fuse_gemm_info[i]['mnk'][1])

            ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))

            # compare layout strings by value ('is' would test object identity)
            if self.fuse_gemm_info[i]['A_format'] == 'Col':
                ldmA = "M"
            if self.fuse_gemm_info[i]['B_format'] == 'Row':
                ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
            if self.fuse_gemm_info[i]['C_format'] == 'Col':
                ldmC = "M"

            if i == 0:
                code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"
            else:
                code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"

            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"

            M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0])

            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
            code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n"
            code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
            for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
                arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
                code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")"
            code_this += " },\n"
            code_this += " " + "Batch};\n"

            code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
            code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n"

            code += code_this + "\n"
        return code + "\n"

    def gen_run(self):
        code = ""
        for i in range(self.b2b_num):
            code += " " + helper.var_idx("gemm_op_", i) + "(stream);\n"
        return code

    def gen_wrapper(self):
        arg_lists = []
        arg_lists.append(["int", "M"])
        arg_lists.append(["int", "K0"])
        arg_lists.append(["int", "Batch"])
        arg_lists.append(["void*", helper.var_idx("A", 0)])
        for i in range(self.b2b_num):
            arg_lists.append(["void*", helper.var_idx("B", i)])
            arg_lists.append(["void*", helper.var_idx("C", i)])
            arg_lists.append(["void*", helper.var_idx("D", i)])
            epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
            acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
            for arg in epilogue_args:
                arg_tp = arg[0]
                arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
                arg_lists.append([arg_tp, arg_name])

        code_body = ""
        code_body += self.gen_using()
        code_body += self.gen_initialize()
        code_body += self.gen_run()

        code = ir.gen_func(self.gen_class_name, arg_lists, code_body)

        return code

    def gen_code(self):
        code = self.gen_wrapper()
        helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code)

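# Hedged smoke test of the wrapper generator above (an empty fuse_gemm_info
# keeps it dependency-free; a real run passes per-layer dicts):
def _demo_volta_wrapper():
    g = gen_volta_turing_fuse_act_impl([], "one_api", [])
    return g.gen_wrapper()  # "void one_api_volta_impl( ... cudaStream_t stream){...}"

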
class gen_one_API:
    def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
        self.fuse_gemm_info = fuse_gemm_info
        self.gen_class_name = gen_class_name
        self.user_header_file = ""
        for header in user_header_file:
            self.user_header_file += "#include \"" + header + "\"\n"
        self.output_dir = output_dir
        self.b2b_num = len(fuse_gemm_info)

        self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
        self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)

    def gen_CUTLASS_irrelevant_API(self):
        code = ""
        code += "#include <cuda_runtime.h>\n"
        code += "#include <assert.h>\n"

        param_name = "Fused" + str(self.b2b_num) + "xGemm_"
        for i in range(self.b2b_num):
            param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_"
        param_name += "Params"
        params = ""
        params += " " + "int M;\n"
        params += " " + "int K0;\n"
        params += " " + "int Batch;\n"
        params += " " + "const void* A0;\n"
        for i in range(self.b2b_num):
            params += " " + "const void* " + helper.var_idx("B", i) + ";\n"
            params += " " + "const void* " + helper.var_idx("C", i) + ";\n"
            epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
            acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
            for arg in epilogue_args:
                arg_tp = arg[0]
                arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
                params += " " + arg_tp + " " + arg_name + ";\n"
            params += " " + "void* " + helper.var_idx("D", i) + ";\n"
        code += ir.gen_struct(param_name, params)
        code += "using Param = " + param_name + ";\n"
        code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n"

        return code

    def gen_one_api(self):
        code = ""
        code += "/* Auto Generated code - Do not edit.*/\n"
        code += "#include \"cutlass_irrelevant.h\"\n"
        code += "#include \"api.h\"\n"
        code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n"

        code += " " + "if (sm == 70) \n"
        code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
        for i in range(self.b2b_num):
            code += helper.var_idx("const_cast<void*>(param.B", i) + "), "
            code += helper.var_idx("const_cast<void*>(param.C", i) + "), "
            code += helper.var_idx("param.D", i) + ", "
            epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
            for arg in epilogue_args:
                arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
                code += "param." + arg_name + ", "
        code += "stream);\n"
        code += " " + "else if (sm >= 75) \n"
        code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
        for i in range(self.b2b_num):
            code += helper.var_idx("const_cast<void*>(param.B", i) + "), "
            code += helper.var_idx("const_cast<void*>(param.C", i) + "), "
            code += helper.var_idx("param.D", i) + ", "
            epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
            for arg in epilogue_args:
                arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
                code += "param." + arg_name + ", "
        code += "stream);\n"
        code += " " + "else assert(0);\n"
        code += "}\n"
        return code

def gen_code(self):
|
||||
|
||||
turing_code = self.gen_turing.gen_wrapper()
|
||||
volta_code = self.gen_volta.gen_wrapper()
|
||||
cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API()
|
||||
|
||||
one_api_code = self.gen_one_api()
|
||||
with open(self.output_dir + "one_api.cu", "w+") as f:
|
||||
f.write(one_api_code)
|
||||
|
||||
helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code)
|
||||
|
||||
helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code)
|
||||
examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py (new file, 92 lines)
@ -0,0 +1,92 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import helper
|
||||
import gen_ir as ir
|
||||
|
||||
import gen_turing_and_volta as gen_basic
|
||||
|
||||
|
||||
class gen_verify:
|
||||
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
|
||||
self.fuse_gemm_info = fuse_gemm_info
|
||||
self.name = gen_class_name + "_verify"
|
||||
self.b2b_num = len(fuse_gemm_info)
|
||||
self.params = []
|
||||
self.user_header_file = ""
|
||||
for header in user_header_file:
|
||||
self.user_header_file += "#include \"" + header + "\"\n"
|
||||
self.separate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
|
||||
self.gen_params()
|
||||
self.output_dir = output_dir
|
||||
|
||||
|
||||
def gen_code(self):
|
||||
code = ""
|
||||
code += self.user_header_file
|
||||
code += self.separate_cutlass.gen_using(False) # False -> Turing, True -> Volta
|
||||
|
||||
code_body = ""
|
||||
for i in range(self.b2b_num):
|
||||
code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
|
||||
code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n"
|
||||
|
||||
code_body += self.separate_cutlass.gen_run()
|
||||
|
||||
code += ir.gen_func(self.name, self.params, code_body)
|
||||
helper.write_2_headfile("cutlass_verify.h", self.output_dir, code)
|
||||
|
||||
|
||||
def gen_params(self):
|
||||
for i in range(self.b2b_num):
|
||||
self.params.append(
|
||||
(
|
||||
helper.var_idx("typename Gemm", i)+ "::Arguments",
|
||||
helper.var_idx("Arguments_", i)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_params(self, declaration = True):
|
||||
code = ""
|
||||
if declaration:
|
||||
for param in self.params:
|
||||
code += param[0] + " " + param[1] + ";\n"
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def gen_initialize(self):
|
||||
code = ""
|
||||
initialize_code = self.separate_cutlass.gen_initialize()
|
||||
|
||||
code = ir.gen_func("initialize", [[]])
|
||||
examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh (new executable file, 52 lines)
@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
NUM_ARGS=3
|
||||
if [ $# -ne $NUM_ARGS ]; then
|
||||
echo "Usage: $0 <config_file> <output_directory> <cutlass_directory>"
|
||||
echo " config_file: JSON file containing configuration to run"
|
||||
echo " output_directory: directory to store results"
|
||||
echo " cutlass_directory: directory containing cutlass source"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
config_file=$1
|
||||
output_dir=$2
|
||||
cutlass_dir=$3
|
||||
|
||||
python3 gen_all_code.py \
|
||||
--config-file "$config_file" \
|
||||
--gen-name FusedMultiGemmForward \
|
||||
--output-dir "$output_dir" \
|
||||
--cutlass-dir "$cutlass_dir"
|
||||
examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py (new file, 135 lines)
@ -0,0 +1,135 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
def type_2_cutlass_type(input_type = "fp16"):
|
||||
# floating-point types
|
||||
if input_type == "fp32":
|
||||
return "float"
|
||||
if input_type == "bf16":
|
||||
return "cutlass::bfloat16_t"
|
||||
if input_type == "fp16":
|
||||
return "cutlass::half_t"
|
||||
|
||||
# integer type
|
||||
if(input_type == "int32"):
|
||||
return "int32_t"
|
||||
if(input_type == "int8"):
|
||||
return "int8_t"
|
||||
|
||||
if input_type == 'Row':
|
||||
return 'cutlass::layout::RowMajor'
|
||||
if input_type == 'Col':
|
||||
return 'cutlass::layout::ColumnMajor'
|
||||
|
||||
def cvt_2_cutlass_shape(gemm_shape):
|
||||
# gemm shape
|
||||
if len(gemm_shape) == 3:
|
||||
val = "cutlass::gemm::GemmShape<" \
|
||||
+ str(gemm_shape[0]) + ", " \
|
||||
+ str(gemm_shape[1]) + ", " \
|
||||
+ str(gemm_shape[2]) + ">"
|
||||
return val
|
||||
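As a concrete illustration, cvt_2_cutlass_shape([128, 64, 8]) yields the string "cutlass::gemm::GemmShape<128, 64, 8>", which emitted code would typically bind to a type alias (the alias name here is illustrative):

// How the generated string is consumed in emitted C++.
#include "cutlass/gemm/gemm.h"
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>;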
|
||||
|
||||
def write_2_headfile(filename, file_dir, string):
|
||||
with open(file_dir + filename, 'w') as f:
|
||||
f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
|
||||
|
||||
def var_idx(variable, index):
|
||||
return variable + str(index)
|
||||
|
||||
|
||||
def list_2_string(input_list):
|
||||
rtn_string = ""
|
||||
|
||||
cnt = 0
|
||||
|
||||
for element in input_list:
|
||||
final = ", \n"
|
||||
if cnt == len(input_list) - 1:
|
||||
final = "\n"
|
||||
cnt += 1
|
||||
rtn_string += str(element) + final
|
||||
|
||||
return rtn_string
|
||||
|
||||
|
||||
def get_epilogue_info(layer_info):
|
||||
return layer_info['epilogue']
|
||||
|
||||
def get_epilogue_tp(layer_info):
|
||||
epilogue_info = get_epilogue_info(layer_info)
|
||||
return epilogue_info['tp']
|
||||
|
||||
def get_epilogue_add_bias_or_not(layer_info):
|
||||
epilogue_info = get_epilogue_info(layer_info)
|
||||
return epilogue_info['bias']['addbias']
|
||||
|
||||
def get_epilogue_add_bias_tp(layer_info):
|
||||
epilogue_info = get_epilogue_info(layer_info)
|
||||
return epilogue_info['bias']['bias_tp']
|
||||
|
||||
def get_epilogue_args(layer_info):
|
||||
epilogue_info = get_epilogue_info(layer_info)
|
||||
return epilogue_info['args']
|
||||
|
||||
def get_epilogue_bias_shape(layer_info):
|
||||
bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
|
||||
mn_shape = layer_info['mnk'][:-1]
|
||||
|
||||
if bias_tp == 'mat':
|
||||
mn_shape[0] = 'M'
|
||||
return mn_shape
|
||||
elif bias_tp == 'vec':
|
||||
mn_shape[0] = 1
|
||||
return mn_shape
|
||||
else:
|
||||
assert(0)
|
||||
|
||||
def get_epilogue_bias_ldm(layer_info):
|
||||
bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
|
||||
mn_shape = layer_info['mnk'][:-1]
|
||||
|
||||
c_layout = layer_info['C_format'].lower()
|
||||
|
||||
if c_layout != 'row':
|
||||
assert(0)
|
||||
|
||||
if bias_tp == 'mat':
|
||||
return mn_shape[1]
|
||||
elif bias_tp == 'vec':
|
||||
return 0
|
||||
else:
|
||||
assert(0)
|
||||
|
||||
def get_epilogue_compute_tp(layer_info):
|
||||
return layer_info['Acc_tp']
|
||||
@ -0,0 +1,67 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import os
|
||||
|
||||
class replace_fix_impl:
|
||||
def __init__(self, src_dir, dst_dir, cutlass_deps_root):
|
||||
self.src_dir = src_dir
|
||||
self.dst_dir = dst_dir
|
||||
self.cutlass_deps_root = cutlass_deps_root
|
||||
|
||||
|
||||
|
||||
def gen_code(self):
|
||||
for sub_dir in os.walk(self.src_dir):
|
||||
files_in_sub_dir = sub_dir[2]
|
||||
|
||||
src_dirs = sub_dir[0]
|
||||
output_dirs = self.dst_dir + sub_dir[0][len(self.src_dir):]
|
||||
|
||||
if not os.path.exists(output_dirs):
|
||||
os.mkdir(output_dirs)
|
||||
|
||||
for f in files_in_sub_dir:
|
||||
with open(src_dirs +"/" + f, 'r') as current_file:
|
||||
output_lines = []
|
||||
lines = current_file.readlines()
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("#include \"cutlass"):
|
||||
new_line = "#include \"" + self.cutlass_deps_root + line[len("#include \""):]
|
||||
# print(new_line)
|
||||
output_lines.append(new_line)
|
||||
else:
|
||||
output_lines.append(line)
|
||||
|
||||
with open(output_dirs + "/" + f, "w+") as dest_file:
|
||||
dest_file.writelines(output_lines)
|
||||
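The pass is a plain prefix substitution on #include "cutlass..." directives; for a hypothetical cutlass_deps_root of "../cutlass/include/", a line would be rewritten as follows:

// Before the rewrite:
#include "cutlass/gemm/device/gemm.h"
// After the rewrite (cutlass_deps_root is prepended inside the quotes):
#include "../cutlass/include/cutlass/gemm/device/gemm.h"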
examples/44_multi_gemm_ir_and_codegen/leaky_bias.h (new file, 292 lines)
@ -0,0 +1,292 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
template <typename T>
|
||||
__device__
|
||||
T add(T const & a, T const &b){
|
||||
return (a + b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__
|
||||
half2 add(half2 const & a, half2 const &b){
|
||||
return (__hadd2(a,b));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct RELU{
|
||||
__device__
|
||||
T operator()(T const & a){
|
||||
return a > T(0) ? a : T(0);
|
||||
}
|
||||
__device__
|
||||
half2 operator()(half2 const & a){
|
||||
float2 a_fp32x2 = __half22float2(a);
|
||||
a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f;
|
||||
a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f;
|
||||
return __float22half2_rn(a_fp32x2);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LEAKY_RELU{
|
||||
__device__
|
||||
T operator()(T const & a, T const & scale = half(1)){
|
||||
return a > T(0) ? a : scale * a;
|
||||
}
|
||||
__device__
|
||||
half2 operator()(half2 const & a, half const & scale = half(1)){
|
||||
half2 zero = __half2half2(half(0));
|
||||
half2 gt_zero = __hge2(a, zero);
|
||||
half2 le_zero = __hle2(a, zero);
|
||||
|
||||
|
||||
half2 scale_f16x2 = __half2half2(scale);
|
||||
half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero);
|
||||
return __hmul2(a, mask_scale_f16x2);
|
||||
}
|
||||
};
|
||||
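A scalar model of the branchless half2 path above may help: __hge2/__hle2 return 1.0 in lanes where the predicate holds and 0.0 elsewhere, so the fused multiply-add builds a per-lane mask that is `scale` for non-positive inputs and 1 for positive ones (sketch for exposition only):

// Scalar equivalent of the half2 leaky ReLU above.
__device__ float leaky_relu_scalar(float a, float scale) {
  float gt_zero = (a >= 0.f) ? 1.f : 0.f;  // __hge2 lane
  float le_zero = (a <= 0.f) ? 1.f : 0.f;  // __hle2 lane
  return a * (le_zero * scale + gt_zero);  // __hfma2 then __hmul2
}

At a == 0 both flags are 1 and the mask is scale + 1, but multiplying by a zero input still yields 0, so the result matches leaky ReLU in every lane.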
|
||||
template <int N, int BLOCKDIM>
|
||||
__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){
|
||||
|
||||
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
||||
|
||||
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
||||
|
||||
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
||||
|
||||
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
||||
|
||||
LEAKY_RELU<half> Act;
|
||||
Access_tp src_v[iter];
|
||||
Access_tp bias_v[iter];
|
||||
|
||||
int batch_id = blockIdx.y;
|
||||
int batch_offset = batch_id * gridDim.x * N;
|
||||
|
||||
for(int i = 0; i < iter; i++){
|
||||
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
||||
if (idx < N){
|
||||
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
||||
if (mat_bias)
|
||||
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
|
||||
else
|
||||
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
|
||||
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
__global__ void leaky_and_activation(half* inout, half scale){
|
||||
|
||||
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
||||
|
||||
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
||||
|
||||
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
||||
|
||||
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
||||
|
||||
int batch_id = blockIdx.y;
|
||||
int batch_offset = batch_id * gridDim.x * N;
|
||||
|
||||
LEAKY_RELU<half> Act;
|
||||
Access_tp src_v[iter];
|
||||
|
||||
for(int i = 0; i < iter; i++){
|
||||
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
||||
if (idx < N){
|
||||
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
||||
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){
|
||||
|
||||
dim3 grid(m, b);
|
||||
if (bias == nullptr)
|
||||
leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, scale);
|
||||
else
|
||||
leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, scale, mat_bias);
|
||||
}
|
||||
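A minimal host-side call might look like this (d_out and d_bias are assumed to be valid device pointers, and N = 64 must match the bias row length):

// Hypothetical launch: leaky ReLU with scale 0.1 plus a per-row bias vector,
// applied in place to an m x 64 half matrix with batch size b.
leaky_and_activation<64, 128>(d_out, d_bias, m, b, half(0.1f), /*mat_bias=*/false);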
|
||||
template <int N, int BLOCKDIM>
|
||||
__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){
|
||||
|
||||
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
||||
|
||||
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
||||
|
||||
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
||||
|
||||
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
||||
|
||||
RELU<half> Act;
|
||||
Access_tp src_v[iter];
|
||||
Access_tp bias_v[iter];
|
||||
|
||||
int batch_id = blockIdx.y;
|
||||
int batch_offset = batch_id * gridDim.x * N;
|
||||
|
||||
for(int i = 0; i < iter; i++){
|
||||
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
||||
if (idx < N){
|
||||
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
||||
if (mat_bias)
|
||||
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
|
||||
else
|
||||
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
|
||||
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
__global__ void relu_and_activation(half* inout){
|
||||
|
||||
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
||||
|
||||
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
||||
|
||||
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
||||
|
||||
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
||||
|
||||
int batch_id = blockIdx.y;
|
||||
int batch_offset = batch_id * gridDim.x * N;
|
||||
|
||||
RELU<half> Act;
|
||||
Access_tp src_v[iter];
|
||||
|
||||
for(int i = 0; i < iter; i++){
|
||||
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
||||
if (idx < N){
|
||||
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
||||
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
|
||||
dim3 grid(m, b);
|
||||
if (bias == nullptr)
|
||||
relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
|
||||
else
|
||||
relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
|
||||
}
|
||||
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){
|
||||
|
||||
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
||||
|
||||
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
||||
|
||||
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
||||
|
||||
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
||||
|
||||
int batch_id = blockIdx.y;
|
||||
int batch_offset = batch_id * gridDim.x * N;
|
||||
|
||||
Access_tp src_v[iter];
|
||||
Access_tp bias_v[iter];
|
||||
|
||||
for(int i = 0; i < iter; i++){
|
||||
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
||||
if (idx < N){
|
||||
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
||||
if (mat_bias)
|
||||
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
|
||||
else
|
||||
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
|
||||
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
__global__ void identity_and_activation(half* inout){
|
||||
|
||||
constexpr bool N_MOD_2 = N & 1 ? false : true;
|
||||
|
||||
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
|
||||
|
||||
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
|
||||
|
||||
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
|
||||
|
||||
int batch_id = blockIdx.y;
|
||||
int batch_offset = batch_id * gridDim.x * N;
|
||||
Access_tp src_v[iter];
|
||||
|
||||
for(int i = 0; i < iter; i++){
|
||||
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
|
||||
if (idx < N){
|
||||
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
|
||||
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, int BLOCKDIM>
|
||||
void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
|
||||
dim3 grid(m, b);
|
||||
if (bias == nullptr)
|
||||
identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
|
||||
else
|
||||
identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
|
||||
}
|
||||
examples/44_multi_gemm_ir_and_codegen/utils.h (new file, 94 lines)
@ -0,0 +1,94 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <cmath>
|
||||
#define TI(tag) \
|
||||
cudaEvent_t _event_start_ ##tag; \
|
||||
cudaEvent_t _event_end_ ##tag; \
|
||||
float _event_time_ ##tag; \
|
||||
cudaEventCreate(& _event_start_ ##tag); \
|
||||
cudaEventCreate(& _event_end_ ##tag); \
|
||||
cudaEventRecord(_event_start_ ##tag);
|
||||
|
||||
#define TO(tag, str, times) \
|
||||
cudaEventRecord(_event_end_ ##tag); \
|
||||
cudaEventSynchronize(_event_end_ ##tag); \
|
||||
cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \
|
||||
float _event_time_once_ ##tag = _event_time_ ##tag / times; \
|
||||
printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \
|
||||
cudaDeviceSynchronize(); \
|
||||
printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError()));
|
||||
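Typical use of the timing macros: the tag must be a valid identifier, and `times` should equal the number of launches between TI and TO so the printed value is a per-launch average (my_kernel, grid, block, and iters are stand-ins here):

TI(fused_gemm);
for (int i = 0; i < iters; ++i) {
  my_kernel<<<grid, block>>>(/* args */);  // hypothetical kernel
}
TO(fused_gemm, "fused gemm", iters);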
|
||||
template<typename T>
|
||||
struct memory_unit{
|
||||
T* host_ptr;
|
||||
T* device_ptr;
|
||||
int size_bytes;
|
||||
int elements;
|
||||
void h2d(){
|
||||
cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice);
|
||||
}
|
||||
void d2h(){
|
||||
cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost);
|
||||
}
|
||||
void free_all(){
|
||||
free(host_ptr);
|
||||
cudaFree(device_ptr);
|
||||
}
|
||||
memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){
|
||||
host_ptr = (T*) malloc(elements_ * sizeof(T));
|
||||
cudaMalloc((void**)&device_ptr, elements_ * sizeof(T));
|
||||
}
|
||||
void init(int abs_range = 1){
|
||||
for(int i = 0; i < elements; i++){
|
||||
host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range);
|
||||
}
|
||||
h2d();
|
||||
}
|
||||
};
|
||||
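A minimal round trip with memory_unit:

// Allocate 1024 halves, randomize them in [-1, 1] on the host (init() also
// copies to the device), run work against device_ptr, then copy back and free.
memory_unit<half> x(1024);
x.init();
// ... kernels read/write x.device_ptr ...
x.d2h();
x.free_all();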
|
||||
template<typename T>
|
||||
int check_result(T * a, T * b, int N){
|
||||
int cnt = 0;
|
||||
for(int i = 0; i < N; i ++){
|
||||
float ref = float(a[i]);
|
||||
float test = float(b[i]);
|
||||
|
||||
if (fabsf(ref - test) / fabsf(ref) > 1e-2f)
|
||||
{
|
||||
// printf("my: %f , std: %f\n", my, std);
|
||||
cnt++;
|
||||
}
|
||||
|
||||
}
|
||||
printf("total err: %d / %d\n", cnt, N);
|
||||
return cnt;
|
||||
}
|
||||
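check_result counts elements whose relative error exceeds 1e-2; a typical verification flow pairs it with two buffers (the kernels named in the comments are placeholders). Note that the check divides by the reference value, so test data should avoid exact zeros in the reference output:

int M = 64, N = 64;
memory_unit<half> ref_out(M * N), fused_out(M * N);
// ... run a reference kernel into ref_out.device_ptr and the fused kernel
//     into fused_out.device_ptr ...
ref_out.d2h();
fused_out.d2h();
int errors = check_result(ref_out.host_ptr, fused_out.host_ptr, M * N);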
@ -30,7 +30,7 @@
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
43_dual_gemm
|
||||
45_dual_gemm
|
||||
dual_gemm.cu
|
||||
)
|
||||
|
||||
examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt (new file, 36 lines)
@ -0,0 +1,36 @@
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
46_depthwise_simt_conv2dfprop
|
||||
depthwise_simt_conv2dfprop.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,672 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
This example shows how to run the depthwise 2d convolution kernels provided by CUTLASS, using
SIMT instructions.

There are 3 implementations of depthwise 2d convolution:
1. kAnalytic
   Implicit GEMM 2d convolution algorithm.
2. kOptimized
   An optimized algorithm that supports arbitrary stride and dilation.
3. kFixedStrideDilation
   An optimized algorithm that bakes stride and dilation into template parameters to reduce
   runtime computation and enable further optimizations.

In general, kFixedStrideDilation performs better than kOptimized. However, if the filter size,
stride, or dilation is large, it can encounter register spilling and hurt performance; in that
case, use kOptimized instead.

For kOptimized and kFixedStrideDilation, enable split-K when the output tensor is large in order
to fully utilize GPU hardware resources and achieve better performance.

This example demonstrates how to construct and run a kFixedStrideDilation depthwise 2d
convolution kernel.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
#include "cutlass/conv/device/direct_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = cutlass::half_t; // Data type of accumulator
|
||||
using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassSimt;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm60;
|
||||
|
||||
// This code section describes the groups a thread block will compute
|
||||
constexpr int groups_per_cta = 64;
|
||||
|
||||
// This code section describes the output tile <N, O, P, Q> a thread block will compute
|
||||
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;
|
||||
|
||||
// This code section describes the filter shape <R, S>
|
||||
using FilterShape = cutlass::MatrixShape<3, 3>;
|
||||
|
||||
// Threadblock tile shape
|
||||
using ThreadblockShape =
|
||||
cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;
|
||||
|
||||
// This code section describes the tile size a warp will compute
|
||||
// WarpShape::kM = P * Q the warps would process
|
||||
// WarpShape::kN = groups_per_cta that the warps would process
|
||||
// WarpShape::kK = filter_size that the warps would process
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock =
|
||||
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
|
||||
1,
|
||||
ThreadBlockOutputShape::kN,
|
||||
ThreadBlockOutputShape::kH,
|
||||
ThreadBlockOutputShape::kW>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 4;
|
||||
|
||||
// This code section describes the selected iterator algorithm: kFixedStrideDilation
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
|
||||
cutlass::conv::IteratorAlgorithm::kFixedStrideDilation;
|
||||
using StrideShape = cutlass::MatrixShape<1, 1>;
|
||||
using DilationShape = cutlass::MatrixShape<1, 1>;
|
||||
|
||||
constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
kEpilogueElementsPerAccess, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue, // Data type for alpha/beta in linear combination
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation.
|
||||
|
||||
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
ThreadBlockOutputShape,
|
||||
FilterShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm,
|
||||
cutlass::conv::StrideSupport::kFixed,
|
||||
StrideShape,
|
||||
DilationShape>::Kernel;
|
||||
|
||||
using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
int groups;
|
||||
int splitk;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
std::string tag;
|
||||
|
||||
Options()
|
||||
: help(false),
|
||||
input_size(1, 128, 128, 32),
|
||||
filter_size(32, 3, 3, 1),
|
||||
groups(32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(false),
|
||||
measure_performance(true),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
splitk(1) {}
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
||||
//
|
||||
int const kAlignment = 8;
|
||||
|
||||
if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) {
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// depthwise conv
|
||||
if (groups != input_size.c()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (filter_size.n() != groups) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) {
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
|
||||
cmd.get_cmd_line_argument("g", groups);
|
||||
|
||||
filter_size.c() = 1;
|
||||
filter_size.n() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("splitk", splitk);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
int32_t padding_h = filter_size.h() / 2;
|
||||
int32_t padding_w = filter_size.w() / 2;
|
||||
padding = {padding_h, padding_h, padding_w, padding_w};
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream &print_usage(std::ostream &out) const {
|
||||
out << "41_depthwise_gemm_fprop example\n\n"
|
||||
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n=<int> Input tensor extent N\n"
|
||||
<< " --h=<int> Input tensor extent H\n"
|
||||
<< " --w=<int> Input tensor extent W\n"
|
||||
<< " --c=<int> Input tensor extent C\n"
|
||||
<< " --k=<int> Filter extent K\n"
|
||||
<< " --r=<int> Filter extent R\n"
|
||||
<< " --s=<int> Filter extent S\n\n"
|
||||
<< " --g=<int> Groups\n\n"
|
||||
<< " --alpha=<float> Epilogue scalar alpha\n"
|
||||
<< " --beta=<float> Epilogue scalar beta\n\n"
|
||||
<< " --splitk=<int> Enable splitK\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag=<string> String to replicate across the first column in the results "
|
||||
"table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 "
|
||||
"--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n"
|
||||
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 "
|
||||
"--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
|
||||
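With the default problem size above (H = W = 128, R = S = 3, pad = 1, stride = 1) this evaluates to:

// P = (H + pad_top + pad_bottom - R) / stride_h + 1
int P = (128 + 1 + 1 - 3) / 1 + 1;  // = 128: a 3x3, pad-1, stride-1 depthwise
                                    // convolution preserves the spatial extent.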
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas =
|
||||
output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result()
|
||||
: runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) {}
|
||||
|
||||
static std::ostream &print_header(std::ostream &out, Options const &options) {
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream &print(std::ostream &out, int idx, Options const &options) {
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
cutlass::Tensor4DCoord output_size = options.output_size();
|
||||
out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << ","
|
||||
<< options.input_size.w() << "," << options.input_size.c() << ","
|
||||
|
||||
<< options.filter_size.n() << "," << options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
|
||||
<< options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column()
|
||||
<< ","
|
||||
|
||||
<< options.dilation.row() << "," << options.dilation.column() << ","
|
||||
|
||||
<< options.splitk << ","
|
||||
|
||||
<< runtime_ms << "," << gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one testcase
|
||||
Result profile_convolution(Options const &options) {
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b_transpose(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0);
|
||||
|
||||
// Fill tensor C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0);
|
||||
|
||||
// Fill tensor D on host with zeros
|
||||
cutlass::reference::host::TensorFill(tensor_d.host_view());
|
||||
|
||||
// Fill tensor D for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(tensor_ref_d.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_b_transpose.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split P*Q into multiple CTA
|
||||
int split_k_slices = options.splitk;
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices,
|
||||
options.groups);
|
||||
|
||||
// Construct Direct2dConv::Arguments structure with conv2d
|
||||
// problem size, data pointers, and epilogue values
|
||||
typename Direct2dConv::Arguments arguments{problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_d.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
tensor_b_transpose.device_ref()};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
Direct2dConv implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on host...\n";
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::host::Conv2dFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue> >(problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref(),
|
||||
options.alpha,
|
||||
options.beta);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_d.sync_host();
|
||||
|
||||
bool passed =
|
||||
cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
} else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
} else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
|
||||
<< "x" << options.input_size.w() << "x" << options.input_size.c() << "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x"
|
||||
<< options.filter_size.w() << "x" << options.filter_size.c() << ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace << "Input = \n"
|
||||
<< tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n"
|
||||
<< tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto &event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error)
|
||||
<< std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
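The GFLOPs figure above divides the convolution's floating-point operation count by the measured runtime. Below is a minimal sketch of such a helper, assuming Tensor4DCoord extents in (n, h, w, c) order as used throughout this example; the function name and signature are illustrative, not the example's actual Options::gflops().

#include "cutlass/tensor_coord.h"

// Hedged sketch: GFLOP/s for grouped conv2d fprop. Every output element
// accumulates filter_h * filter_w * (input_channels / groups) multiply-adds,
// and each multiply-add counts as two floating-point operations.
inline double conv2d_fprop_gflops(cutlass::Tensor4DCoord const &input,
                                  cutlass::Tensor4DCoord const &filter,
                                  cutlass::Tensor4DCoord const &output,
                                  int groups,
                                  double runtime_s) {
  double flops = 2.0 * double(output.n()) * double(output.h()) *
                 double(output.w()) * double(output.c()) *
                 double(filter.h()) * double(filter.w()) *
                 double(input.c()) / double(groups);
  return flops / runtime_s / 1.0e9;
}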
int main(int argc, char const **args) {
  bool notSupported = false;

  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

  if (!(props.major >= 6)) {
    std::cerr << "Run on a machine with compute capability at least 60." << std::endl;
    notSupported = true;
  }

  if (notSupported) {
    return 0;
  }

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  // Execute one problem size
  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

  Result result = profile_convolution(options);

  Result::print_header(std::cout, options) << std::endl;
  result.print(std::cout, 1, options) << std::endl;

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -119,9 +119,11 @@ foreach(EXAMPLE
  37_gemm_layernorm_gemm_fusion
  38_syr2k_grouped
  39_gemm_permute
  41_multi_head_attention
  42_fused_multi_head_attention
  43_dual_gemm
  41_fused_multi_head_attention
  42_ampere_tensorop_group_conv
  43_ell_block_sparse_gemm
  45_dual_gemm
  46_depthwise_simt_conv2dfprop
)

  add_subdirectory(${EXAMPLE})
@ -80,9 +80,9 @@ public:
  typedef value_type *pointer;
  typedef value_type const * const_pointer;

  using ArrayType = Array<T, N>;
  using reference = typename ArrayType::reference;
  using const_reference = typename ArrayType::const_reference;
  using Array = Array<T, N>;
  using reference = typename Array::reference;
  using const_reference = typename Array::const_reference;

public:
@ -85,6 +85,10 @@ struct Sm86 {
  static int const kMinComputeCapability = 86;
};

struct Sm90 {
  static int const kMinComputeCapability = 90;
};

/// Triggers a breakpoint on the device
CUTLASS_DEVICE
void device_breakpoint() {
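For context, Sm90 joins the existing architecture tags as an empty type whose only member is the minimum compute capability it represents. A hedged sketch of the compile-time dispatch such tags enable; the DefaultTileM trait below is invented for illustration and is not a CUTLASS API:

#include "cutlass/arch/arch.h"

// Illustrative trait keyed on an architecture tag (hypothetical, not CUTLASS).
template <typename ArchTag>
struct DefaultTileM;

template <> struct DefaultTileM<cutlass::arch::Sm80> { static int const kM = 128; };
template <> struct DefaultTileM<cutlass::arch::Sm90> { static int const kM = 256; };

// The tag itself carries the minimum compute capability it targets.
static_assert(cutlass::arch::Sm90::kMinComputeCapability == 90,
              "Sm90 targets compute capability 9.0");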
@ -451,7 +451,7 @@ template <>
CUTLASS_DEVICE
void shared_store<16>(uint32_t ptr, void const *src) {
  uint4 const *dst_u128 = reinterpret_cast<uint4 const *>(src);
  asm volatile("ld.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
  asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
      : :
        "r"(ptr),
        "r"(dst_u128->x),
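The hunk above fixes shared_store<16> to issue a store (st.shared) instead of a load. A hedged usage sketch follows; the header names and the cutlass_get_smem_pointer helper reflect my understanding of where CUTLASS keeps these utilities and should be treated as assumptions:

#include "cutlass/arch/memory.h"       // assumed home of shared_store
#include "cutlass/arch/memory_sm75.h"  // assumed home of cutlass_get_smem_pointer

// Launch with 32 threads: each thread copies one 16-byte vector into shared memory.
__global__ void smem_store_sketch(uint4 const *gmem) {
  __shared__ uint4 buffer[32];

  // st.shared takes a 32-bit shared-memory address, so the generic pointer
  // is converted first.
  unsigned smem_addr =
      cutlass::arch::cutlass_get_smem_pointer(&buffer[threadIdx.x]);

  uint4 value = gmem[threadIdx.x];
  cutlass::arch::shared_store<16>(smem_addr, &value);  // 16-byte store to smem

  __syncthreads();
}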
@ -223,4 +223,6 @@ struct SparseMma;

#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sparse_sm80.h"
#include "cutlass/arch/mma_sm90.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1065,7 +1065,7 @@ struct Mma<
  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
  asm volatile("_mma.m8n8k32.row.col.u4.s4.sat {%0,%1}, %2, %3, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
@ -1247,7 +1247,8 @@ struct Mma<
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
#if defined(CUTLASS_ARCH_WMMA_ENABLED)

#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED))
    using WmmaFragmentA = nvcuda::wmma::fragment<
        nvcuda::wmma::matrix_a,
        Shape::kM,
@ -1279,6 +1280,7 @@ struct Mma<

    nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
                            nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);

#else

    CUTLASS_UNUSED(a);
@ -1289,14 +1291,7 @@ struct Mma<

#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif

  }
};
@ -2156,6 +2156,7 @@ struct Mma<

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, "
131  include/cutlass/arch/mma_sm90.h  (new file)
@ -0,0 +1,131 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Matrix multiply
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

////////////////////////////////////////////////////////////////////////////////

#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define CUTLASS_ARCH_MMA_SM90_ENABLED
#endif
#endif

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

////////////////////////////////////////////////////////////////////////////////
/// Matrix Multiply-Add 16x8x4 fp64
////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: F64 = F64 * F64 + F64
template <>
struct Mma<
  gemm::GemmShape<16,8,4>,
  32,
  double,
  layout::RowMajor,
  double,
  layout::ColumnMajor,
  double,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<16,8,4>;

  using ElementA = double;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<double, 2>;

  using ElementB = double;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<double, 1>;

  using ElementC = double;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<double, 4>;

  using Operator = OpMultiplyAdd;

  using ArchTag = arch::Sm90;

  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED)

    double const *A = reinterpret_cast<double const *>(&a);
    double const *B = reinterpret_cast<double const *>(&b);

    double const *C = reinterpret_cast<double const *>(&c);
    double *D = reinterpret_cast<double *>(&d);

    asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
                 : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
                 : "d"(A[0]), "d"(A[1]),
                   "d"(B[0]),
                   "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));

#else

    CUTLASS_UNUSED(d);
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
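A hedged sketch of invoking this fp64 tensor core atom from a warp. Fragment sizes follow the struct above; populating the fragments with the exact per-thread data layout the PTX instruction expects is omitted, so this illustrates only the call shape, not a complete GEMM:

#include "cutlass/arch/mma_sm90.h"
#include "cutlass/gemm/gemm.h"

// Each of the 32 threads in the warp holds its slice of the 16x8x4 tile.
__global__ void mma_16x8x4_f64_sketch() {
  using MmaOp = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<16, 8, 4>, 32,
      double, cutlass::layout::RowMajor,
      double, cutlass::layout::ColumnMajor,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>;

  MmaOp::FragmentA frag_A;  // 2 doubles per thread
  MmaOp::FragmentB frag_B;  // 1 double per thread
  MmaOp::FragmentC frag_C;  // 4 accumulator doubles per thread
  frag_A.clear();
  frag_B.clear();
  frag_C.clear();

  MmaOp mma;
  mma(frag_C, frag_A, frag_B, frag_C);  // d = a * b + c
}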
File diff suppressed because it is too large.

201  include/cutlass/barrier.h  (new file)
@ -0,0 +1,201 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implementation of a CTA-wide barrier for inter-CTA synchronization.
*/

#pragma once

#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// CTA-wide semaphore for inter-CTA synchronization.
struct Barrier
{

public:

  /// Flag type
  using T = int;

  /// Initial flag value
  static const T INIT = 0;


protected:

  /// Load flag, as a strong operation (int specialization)
  CUTLASS_DEVICE
  static int ld_strong(int *ptr)
  {
    int state = 0;

#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers
    asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#else
    asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif // (__CUDA_ARCH__ >= 700)

    return state;
  }

  /// Store flag, as a strong operation (int specialization)
  CUTLASS_DEVICE
  static void st_strong(int *ptr, int val)
  {
#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers
    asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
    asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#endif // (__CUDA_ARCH__ >= 700)
  }


  /// Reduce into flag, with release pattern (int specialization)
  CUTLASS_DEVICE
  static void red_release(int *ptr, int val)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers
    asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
    __threadfence();
    atomicAdd(ptr, val);
#endif // (__CUDA_ARCH__ >= 700)
#endif
  }


public:

  /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter
  CUTLASS_DEVICE
  static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_strong(flag_ptr) < count) {}
    }

    __syncthreads();
#endif
  }

  /// Uses thread[0] to wait until the given flag counter equals the specified value
  CUTLASS_DEVICE
  static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_strong(flag_ptr) != val) {}
    }

    __syncthreads();
#endif
  }

  /// Uses thread[0] to wait until the given flag counter equals the specified value, then resets it to zero
  CUTLASS_DEVICE
  static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(atomicCAS(flag_ptr, val, 0) != val) {}
    }

    __syncthreads();
#endif
  }

  /// Increment the arrival count for a flag
  CUTLASS_DEVICE
  static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    __syncthreads();

    if (thread_idx == 0) {
      red_release(flag_ptr, 1);
    }
#endif
  }


  /// Increment the arrival counts for a range of flags
  CUTLASS_DEVICE
  static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    int flag_idx = first_flag_idx + thread_idx;
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    // Barrier to make sure all other threads in block have written their data
    __syncthreads();

    // Selected threads increment their flags
    if (thread_idx < count) {
      red_release(flag_ptr, 1);
    }
#endif
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
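A hedged sketch of how a split-K style kernel might use Barrier: peer CTAs deposit partial results and signal arrival; one consumer CTA waits for all peers before reducing. The grid layout, flag buffer, and reduction below are illustrative assumptions, not CUTLASS's actual Stream-K code; `flags` must be a zero-initialized device buffer with one counter per output tile.

#include "cutlass/barrier.h"

// gridDim.x indexes output tiles; gridDim.y indexes peer CTAs per tile.
// `partials` is laid out [peer][tile][elem].
__global__ void split_k_reduce_sketch(int *flags, float *partials, float *out,
                                      int tile_elems) {
  int tile_idx = blockIdx.x;
  int peer_idx = blockIdx.y;
  float *my_partial = partials + (peer_idx * gridDim.x + tile_idx) * tile_elems;

  // Producer phase: every peer writes its partial sums for the tile.
  for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
    my_partial[i] = 1.0f;  // stand-in for real partial results
  }

  // arrive_inc() issues __syncthreads() and then a release-increment from
  // thread 0, making the writes above visible to whoever waits on the flag.
  cutlass::Barrier::arrive_inc(flags, threadIdx.x, tile_idx);

  if (peer_idx == 0) {
    // Consumer phase: wait until all peers have signaled; the flag is reset
    // to zero on success so it can be reused.
    cutlass::Barrier::wait_eq_reset(flags, threadIdx.x, tile_idx,
                                    (int)gridDim.y);

    for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
      float acc = 0.f;
      for (int p = 0; p < (int)gridDim.y; ++p) {
        acc += partials[(p * gridDim.x + tile_idx) * tile_elems + i];
      }
      out[tile_idx * tile_elems + i] = acc;
    }
  }
}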
Some files were not shown because too many files have changed in this diff.