release 2.11 (#703)

This commit is contained in:
Aditya Atluri
2022-11-19 06:02:15 -08:00
committed by GitHub
parent 3c90f6aea6
commit c975e2ccbb
329 changed files with 47332 additions and 10607 deletions

18
.github/labeler.yml vendored
View File

@ -1,18 +0,0 @@
# https://github.com/actions/labeler#common-examples
examples:
- examples/**
source:
- cmake/**
- include/cutlass/**
documentation:
- docs/**
- media/**
testing:
- test/**
tooling:
- tools/**

View File

@ -5,8 +5,7 @@ on:
jobs:
triage:
runs-on: ubuntu-latest
permissions: read-all|write-all
steps:
- uses: actions/labeler@master
- uses: actions/labeler@main
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"

View File

@ -1,5 +1,27 @@
# NVIDIA CUTLASS Changelog
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
* Stream-K, a new general way to do split-K. It can improve performance and significantly reduce the number of tile sizes that need to be profiled to find the best one (a brief scheduling sketch follows this list).
* [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for fixed sequence lengths, and the other uses grouped GEMM for variable sequence lengths. Both variants run in a single kernel.
* [Dual GEMM](/examples/45_dual_gemm), which fuses A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
* Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported starting with CUDA 11.8.
* [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
* [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. The B and output matrices are still dense. The block size can be arbitrary.
* Optimized [Group Conv](/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the number of output channels per group is a multiple of the threadblock tile N.
* [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added:
* [kOptimized](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - uses direct convolution to compute instead of implicit GEMM.
* The restrictions are: 1) the input channel, output channel, and group count must be multiples of (128 / sizeof(input element)), e.g. 64 for half-precision inputs; 2) the input filter size must match the template parameter configuration.
* [kFixedStrideDilation](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts stride and dilation into template parameters to further improve performance. In this mode, the kernel keeps some inputs persistent in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended.
* The restrictions are: 1) the input channel, output channel, and group count must be multiples of (128 / sizeof(input element)); 2) the input filter size, stride, and dilation must match the template parameter configuration.
* [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. The implementation is discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
* [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
* Maxwell and Pascal GPU architectures
* Ubuntu 16.04
* CUDA 10.2
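The Stream-K item at the top of this list is easiest to picture with a little arithmetic. The sketch below is a conceptual illustration only, not CUTLASS's implementation; all names in it are made up. It shows how a stream-K style schedule hands out contiguous ranges of the flattened (output tile, k-iteration) space to a fixed pool of workers, instead of giving every tile the same fixed number of K splits.

```python
# Conceptual sketch only (not CUTLASS's implementation; names are illustrative).
# A stream-K style schedule assigns each worker a contiguous range of the flattened
# (output tile, k-iteration) space, so work stays balanced even when the number of
# output tiles does not divide evenly by the number of workers.

def stream_k_ranges(num_tiles, iters_per_tile, num_workers):
    total = num_tiles * iters_per_tile
    ranges = []
    for w in range(num_workers):
        begin = w * total // num_workers
        end = (w + 1) * total // num_workers
        # A worker may start or stop mid-tile; partial accumulations for a shared
        # tile are combined in a separate fix-up step (omitted here).
        ranges.append((begin, end))
    return ranges

if __name__ == "__main__":
    # 7 output tiles of 12 k-iterations each, 4 workers: 84 iterations split 21/21/21/21,
    # whereas a one-tile-per-CTA schedule would leave three of four workers idle at the end.
    for w, (b, e) in enumerate(stream_k_ranges(7, 12, 4)):
        print(f"worker {w}: iterations [{b}, {e}), tiles {b // 12}..{(e - 1) // 12}")
```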
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. The threadblock scheduling is improved, and some computation can be moved to the host side when applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
@ -16,11 +38,6 @@
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
* Maxwell and Pascal GPU architectures
* Ubuntu 16.04
* CUDA 10.2
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment

View File

@ -73,10 +73,10 @@ abstract: >-
keywords:
- 'cutlass, tensor cores, cuda'
license: BSD-3-Clause
license-url: https://github.com/NVIDIA/cutlass/blob/v2.10.0/LICENSE.txt
version: '2.10.0'
date-released: '2022-09-15'
license-url: https://github.com/NVIDIA/cutlass/blob/v2.11.0/LICENSE.txt
version: '2.11.0'
date-released: '2022-11-19'
identifiers:
- type: url
value: "https://github.com/NVIDIA/cutlass/tree/v2.10.0"
description: The GitHub release URL of tag 2.10.0
value: "https://github.com/NVIDIA/cutlass/tree/v2.11.0"
description: The GitHub release URL of tag 2.11.0

View File

@ -38,7 +38,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
project(CUTLASS VERSION 2.10.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.11.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 10.2)
@ -87,6 +87,7 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY})
@ -122,6 +123,9 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
@ -569,6 +573,9 @@ install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest)
################################################################################
set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "cuBLAS usage for tests")
set(CUTLASS_ENABLE_CUDNN OFF CACHE BOOL "cuDNN usage for tests")
include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake)
if (CUTLASS_ENABLE_CUBLAS)
@ -732,7 +739,7 @@ if (CUTLASS_ENABLE_TOOLS)
add_subdirectory(tools)
if (CUTLASS_ENABLE_PROFILER)
add_dependencies(test_all test_profiler)
endif()
endif()
endif()
if (CUTLASS_ENABLE_EXAMPLES)
add_subdirectory(examples)

View File

@ -7,10 +7,10 @@
This is the official list of CUTLASS developers and contributors.
## DEVELOPERS
Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Cris Cecka
Vijay Thakkar
@ -20,52 +20,50 @@ Ethan Yan
Zhaodong Chen
Jack Kosaian
Yujia Zhai
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner
Manikandan Ananth
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner
Manikandan Ananth
## CUTLASS Product Manager
Matthew Nicely
## CONTRIBUTORS
Timothy Costa
Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Duane Merrill
Kevin Siu
Markus Tavenrath
John Tran
Vicki Wang
Junkai Wu
Fung Xie
Albert Xu
Jack Yang
Xiuxia Zhang
Nick Zhao
Timothy Costa
Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Duane Merrill
Kevin Siu
Markus Tavenrath
John Tran
Vicki Wang
Junkai Wu
Fung Xie
Albert Xu
Jack Yang
Xiuxia Zhang
Nick Zhao
## ACKNOWLEDGEMENTS
Girish Bharambe
Luke Durant
Olivier Giroux
Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Joel McCormack
Kyrylo Perelygin
Girish Bharambe
Luke Durant
Olivier Giroux
Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Joel McCormack
Kyrylo Perelygin

100
README.md
View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 2.10
# CUTLASS 2.11
_CUTLASS 2.10 - August 2022_
_CUTLASS 2.11 - November 2022_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) and related computations at all levels
@ -36,21 +36,21 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# What's New in CUTLASS 2.10
# What's New in CUTLASS 2.11
CUTLASS 2.10 is an update to CUTLASS adding:
- [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, Convolution and Grouped GEMM for different data types as well as different epilogue flavors.
- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. It can move some scheduling into the host side if applicable.
- Optimizations for [GEMM+Softmax](examples/35_gemm_softmax).
- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) is a general MHA that does not require equal sequence length in every GEMM.
- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) can fuse the layernorm into GEMMs before and after.
- [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can permute the GEMM output before storing.
- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one.
- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now.
- Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
- [Back-to-back GEMM](examples/13_two_tensor_op_fusion) enhancements.
- Updates and bugfixes from the community (thanks!)
- **Deprecation announcement:** CUTLASS plans to deprecate the following:
CUTLASS 2.11 is an update to CUTLASS adding:
- Stream-K, a new general way to do split-K. It can improve performance and significantly reduce the number of tile sizes that need to be profiled to find the best one.
- [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths.
- [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel (see the short sketch after this list).
- Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported starting with CUDA 11.8.
- [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
- [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm).
- [Optimized Group Conv](/examples/42_ampere_tensorop_group_conv).
- [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop).
- [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs.
- [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
- Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
- **Deprecation announcement:** CUTLASS plans to deprecate the following in the next major release:
- Maxwell and Pascal GPU architectures
- Ubuntu 16.04
- CUDA 10.2
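For the Dual GEMM item above, the numpy sketch below shows only the computation being fused, namely two products that share the same left operand so that A can be read once; it is illustrative only, says nothing about the kernel itself, and the shapes are placeholders.

```python
# Reference semantics only: two GEMMs that share the same left operand A, which the
# fused kernel can read from global memory once. Shapes are placeholders.
import numpy as np

m, n, k = 128, 64, 256
rng = np.random.default_rng(0)
A = rng.standard_normal((m, k), dtype=np.float32)
B0 = rng.standard_normal((k, n), dtype=np.float32)
B1 = rng.standard_normal((k, n), dtype=np.float32)

D0 = A @ B0   # first GEMM
D1 = A @ B1   # second GEMM, reusing A
print(D0.shape, D1.shape)
```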
@ -80,10 +80,11 @@ as shown in the above figure. Tensor Core operations are still implemented usin
# Compatibility
CUTLASS requires a C++11 host compiler and
performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5.
CUTLASS requires a C++11 host compiler and performs best when built with the [**CUDA 11.8 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.x.
## Operating Systems
We have tested the following environments.
|**Operating System** | **Compiler** |
@ -93,11 +94,12 @@ We have tested the following environments.
| | Microsoft Visual Studio 2019|
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 21.04 | GCC 11.2.0 |
| Ubuntu 22.04 | GCC 11.2.0 |
Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.
## Hardware
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
@ -110,9 +112,7 @@ any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA A10 |8.6|11.1|11.1|
|NVIDIA GeForce 3090|8.6|11.1|11.1|
For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit)
for best performance.
|NVIDIA H100 PCIe|9.0|11.8|11.8|
# Documentation
@ -133,9 +133,16 @@ CUTLASS is described in the following documents and the accompanying
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilitate rapid development
# Resources
We have also described the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
# Building CUTLASS
CUTLASS is a header-only template library and does not need to be built to be used by other
@ -199,6 +206,8 @@ include/ # client applications should target this directory
conv/ # code specialized for convolution
epilogue/ # code specialized for the epilogue of gemm/convolution
gemm/ # code specialized for general matrix product computations
layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory
@ -206,6 +215,8 @@ include/ # client applications should target this directory
platform/ # CUDA-capable Standard Library components
reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model
thread/ # simt code that can be performed within a CUDA thread
transform/ # code specialized for layout, type, and domain transformations
@ -216,49 +227,6 @@ include/ # client applications should target this directory
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
```
examples/
00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs
01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors
02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents
03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS
04_tile_iterator/ # example demonstrating an iterator over tiles in memory
05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation
06_splitK_gemm/ # example demonstrating CUTLASS's Split-K parallel reduction kernel
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores
08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores
09_turing_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores
10_planar_complex/ # example demonstrating planar complex GEMM kernels
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel
22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
31_basic_syrk # example demonstrating Symmetric Rank-K update
32_basic_trmm # example demonstrating Triangular Matrix-Matrix multiplication
33_ampere_3xtf32_tensorop_symm # example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation
35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
40_cutlass_py # example demonstrating CUTLASS with CUDA Python
```
### Tools
```

View File

@ -29,7 +29,6 @@
set(TEST_COMMAND_00 RowMajor --extent=16,16)
set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)
cutlass_example_add_executable(
03_visualize_layout
@ -37,6 +36,5 @@ cutlass_example_add_executable(
register_layout.cu
TEST_COMMAND_OPTIONS
TEST_COMMAND_00
TEST_COMMAND_01
)

View File

@ -64,15 +64,15 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
// All Ampere/Turing H/Integer matrix multiply tensor core kernels use the same swizzling
// layout implementation with different templates.
//
// BMMA 8x8x128 Interleaved-256
// BMMA 16x8x256 Interleaved-256
// mma.sync.aligned.m8n8k128.s32.b1.b1.s32 Interleaved-256
// mma.sync.aligned.m16n8k256.s32.b1.b1.s32 Interleaved-256
{"TensorOpMultiplicand<1,256>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 256>>},
// BMMA 8x8x128 TN kblock512
// BMMA 16x8x256 TN kblock512
// mma.sync.aligned.m8n8k128.s32.b1.b1.s32 TN kblock512
// mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock512
{"TensorOpMultiplicand<1,512>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 512>>},
// BMMA 16x8x256 TN kblock1024
// mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock1024
{"TensorOpMultiplicand<1,1024>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 1024>>},
// Integer matrix multiply.int4 8x8x32 Interleaved-64

View File

@ -81,7 +81,7 @@ matrix A can be seen as
---------------------------------------
batch 0 | batch 1
, where batch size is 2, M is 6 and K is 2
The stride (batch_stride_B) between the first element of two batches is lda * k
The stride (batch_stride_A) between the first element of two batches is lda * k
matrix B can be seen as
-----------------------------
@ -94,7 +94,7 @@ matrix B can be seen as
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------
, where the batch size is 2, N is 3 and K is 2
The stride (batch_stride_C) between the first element of two batches is k
The stride (batch_stride_B) between the first element of two batches is k
*/
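As an editorial aside, the stride arithmetic described in the comment above can be checked with a few lines of Python. This is a sketch under the stated assumptions about the packed layouts (column-major A with batches placed side by side along columns, and B packed so that each column holds the k-slices of all batches), not code from this example.

```python
# Sketch under the assumptions above (not code from this example): flat offsets for
# the packed batched layouts the comment describes. A is column-major with leading
# dimension lda and batches placed side by side along columns, so batch_stride_A is
# lda * k. B packs each column's k-slices of all batches contiguously, so ldb is
# k * batch_count and batch_stride_B is k.
m, n, k, batch_count = 6, 3, 2, 2
lda = m
ldb = k * batch_count
batch_stride_A = lda * k
batch_stride_B = k

def offset_A(batch, row, col):          # element (row, col) of A in the given batch
    return batch * batch_stride_A + col * lda + row

def offset_B(batch, row, col):          # element (row, col) of B, row indexes k
    return batch * batch_stride_B + col * ldb + row

assert offset_A(1, 0, 0) == lda * k     # first element of A's batch 1
assert offset_B(1, 0, 0) == k           # first element of B's batch 1
print(offset_A(1, 0, 0), offset_B(1, 0, 0))
```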

View File

@ -426,7 +426,7 @@ Result profile(Options const &options) {
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{
typename Gemm::Arguments arguments(
mode,
options.problem_size,
batch_count,
@ -445,8 +445,7 @@ Result profile(Options const &options) {
tensor_b.layout().stride(0),
tensor_c.layout().stride(0),
tensor_d.layout().stride(0),
tensor_reduction.layout().stride(0)
};
tensor_reduction.layout().stride(0));
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
@ -515,15 +514,14 @@ Result profile(Options const &options) {
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_nullptr_tensorref(nullptr, splitk_vector_layout);
typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments{
typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments(
cutlass::MatrixCoord(1, reduce_vector_length),
batch_count,
size_t(reduce_vector_length),
workspace_vector_tensorref,
tensor_reduction_tensorref,
tensor_nullptr_tensorref,
{1.0f, 0.0f}
};
{1.0f, 0.0f});
ReduceVectorSplitK reduce_vector_splitk_op;

View File

@ -531,17 +531,17 @@ Result profile_convolution(Options const &options) {
// Reduction input
{
reinterpret_cast<ElementAccumulator*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
// Destination
{
tensor_d.device_data(),
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
// Source
{
tensor_c.device_data(),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
{options.alpha, options.beta}
);

View File

@ -367,12 +367,6 @@ public:
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM

View File

@ -309,12 +309,6 @@ public:
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

View File

@ -690,7 +690,7 @@ public:
// Initialize the GEMM object
GemmBatched gemm;
result.status = gemm.initialize(arguments);
result.status = gemm.initialize(arguments, nullptr);
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl;
@ -854,7 +854,7 @@ public:
// Initialize the GEMM object
GemmPermute gemm_normal;
result.status = gemm_normal.initialize(arguments);
result.status = gemm_normal.initialize(arguments, nullptr);
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl;

View File

@ -1,230 +1,23 @@
# CUTLASS Python Interface Example
# CUTLASS Python Interface Examples
This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples:
* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations
* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel
## Using Docker
You can run PyCUTLASS in an NGC PyTorch container.
## Setting up the Python interface
Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API.
## Running examples
Each of the basic examples can be run as follows:
```shell
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3
```
PyCUTLASS requires the Boost C++ library as an additional dependency, which can be installed with
```bash
apt-get update
apt-get -y install libboost-all-dev
# Run the GEMM example
python gemm.py
# Run the Conv2d example
python conv2d.py
# Run the grouped GEMM example
python gemm_grouped.py
```
## Install the Python Interface
The source code for the Python interface is located at `tools/library/scripts/pycutlass`. It requires two environment variables:
* `CUTLASS_PATH`: the root directory of CUTLASS
* `CUDA_INSTALL_PATH`: the directory where the CUDA Toolkit is installed
After setting these two environment variables, PyCUTLASS can be installed with
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```
***
## Troubleshooting
### Issue 1: permission denied
Building PyCUTLASS requires installing dependencies into your Python environment, so conda could be an option if you don't have the required permissions.
### Issue 2: rmm: module not found
PyCUTLASS manages device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pulls [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from GitHub and builds it from source. RMM is located at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm` and requires `cmake > 3.20.1`. If the build fails, it can be fixed manually with the following steps:
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python
python setup.py build_ext --inplace
python setup.py install
```
To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues.
***
For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message.
## GEMM Examples
The GEMM examples use numpy to create input tensors and verify the results.
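As a rough picture of what that verification looks like, the sketch below performs the same kind of numpy reference check (D = alpha * A @ B + beta * C compared against the kernel output with a tolerance); it is illustrative only and is not the examples' actual verification code.

```python
# Illustrative only: the kind of numpy reference check these GEMM examples perform.
import numpy as np

m, n, k = 512, 256, 128                 # matches the -p 512 256 128 commands below
alpha, beta = 1.0, 0.5
rng = np.random.default_rng(0)
A = rng.standard_normal((m, k))
B = rng.standard_normal((k, n))
C = rng.standard_normal((m, n))

D_ref = alpha * (A @ B) + beta * C      # host reference
D_device = D_ref.copy()                 # stand-in for the kernel's output
assert np.allclose(D_device, D_ref)
print("reference check passed")
```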
### GEMM F64 Example
Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
### GEMM F32 Example
Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
```
Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
```
### GEMM F16 Example
Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
```
### GEMM BF16 Example
Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
```
### GEMM Int8 Example
Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```
### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.
Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
The Conv2d examples use PyTorch to create input tensors and verify the results. PyTorch can be installed by following the [official website](https://pytorch.org/#:~:text=Aid%20to%20Ukraine.-,INSTALL%20PYTORCH,-Select%20your%20preferences).
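As a rough picture of that verification, the sketch below computes an fprop reference with `torch.nn.functional.conv2d`; it is illustrative only, the sizes are placeholders, and the permutes reflect the assumption that CUTLASS tensors are NHWC while torch's conv2d expects NCHW.

```python
# Illustrative only: an fprop reference of the kind these examples verify against.
# CUTLASS tensors are NHWC while torch's conv2d expects NCHW, hence the permutes.
import torch
import torch.nn.functional as F

n, h, w, c = 1, 13, 17, 8               # input (NHWC), placeholder sizes
k, r, s = 24, 3, 3                      # filter (KRSC)
stride, padding, dilation = (2, 2), (0, 0), (1, 1)

A = torch.randn(n, h, w, c)             # activation
W = torch.randn(k, r, s, c)             # filter

D_ref = F.conv2d(A.permute(0, 3, 1, 2),            # NHWC -> NCHW
                 W.permute(0, 3, 1, 2),            # KRSC -> KCRS
                 stride=stride, padding=padding, dilation=dilation)
D_ref = D_ref.permute(0, 2, 3, 1)                  # back to NHWC
print(D_ref.shape)                      # torch.Size([1, 6, 8, 24])
```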
### Conv2d F32 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
```
Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
```
### Conv2d F32 Wgrad
Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F32 Dgrad
Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F16 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
## Epilogue
### Bias
To replace C with a bias vector, add the `-bias` flag.
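Roughly, this changes the epilogue source operand from a full matrix to a vector. The numpy sketch below illustrates the assumed convention (a length-N bias broadcast across the rows of the output); it is an illustration, not the interface's implementation.

```python
# Illustrative only, assuming the bias is a length-N vector broadcast across rows:
# D = alpha * (A @ B) + beta * bias  (an activation may then be applied, see below).
import numpy as np

m, n, k = 512, 256, 128
alpha, beta = 1.0, 0.5
rng = np.random.default_rng(0)
A = rng.standard_normal((m, k))
B = rng.standard_normal((k, n))
bias = rng.standard_normal(n)           # one value per output column

D = alpha * (A @ B) + beta * bias       # numpy broadcasts bias over the m rows
print(D.shape)
```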
### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```
To run the customizable examples, refer to the README in the [customizable](customizable) directory.

View File

@ -29,290 +29,133 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
import torch.nn.functional as F
"""
Basic example of using the CUTLASS Python interface to run a 2d convolution
"""
import argparse
import torch
import numpy as np
import sys
# parse the arguments
parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from python")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'],
help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[
4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorC32RSK32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'],
help='Data type of computation in the epilogue')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8",
"HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4",
"StridedDgradHorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# conv related
parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'],
help="The type of convolution: forward propagation (fprop), \
gradient of activation (dgrad), gradient of weight (wgrad)")
parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"],
)
parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str,
choices=["analytic", "optimized", "fixed_channels", "few_channels"],
help="This option describes iterator algorithm")
# arguments
parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"],
help="Split K Mode. Serial is used for non-splitK or serial-splitK.\
Parallel is used for parallel splitK.")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)")
parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)")
parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)")
parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)")
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
import cutlass
import pycutlass
from pycutlass import *
import util
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
parser = argparse.ArgumentParser(
description=("Launch a 2d convolution kernel from Python. "
"See https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#convo-intro for notation."))
parser.add_argument("--n", default=1, type=int, help="N dimension of the convolution")
parser.add_argument("--c", default=64, type=int, help="C dimension of the convolution")
parser.add_argument("--h", default=32, type=int, help="H dimension of the convolution")
parser.add_argument("--w", default=32, type=int, help="W dimension of the convolution")
parser.add_argument("--k", default=32, type=int, help="N dimension of the convolution")
parser.add_argument("--r", default=3, type=int, help="R dimension of the convolution")
parser.add_argument("--s", default=3, type=int, help="S dimension of the convolution")
parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
# Check that the device is of a sufficient compute capability
cc = util.get_device_cc()
assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70."
alignment = 1
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
# Allocate a pool of device memory to be used by the kernel
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
# Set the compiler to use to NVCC
pycutlass.compiler.nvcc()
# Set up A, B, C and accumulator
A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
[16, 8, 8], # Shape of the Tensor Core instruction
A.element, B.element, element_acc,
cutlass.OpClass.TensorOp,
MathOperation.multiply_add
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
[128, 128, 32], # Threadblock shape
2, # Number of stages
[2, 2, 1], # Number of warps within each dimension of the threadblock shape
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, stride_support=stride_support,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
conv_kind=cutlass.conv.Operator.fprop,
iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
arch=cc, tile_description=tile_description,
A=A, B=B, C=C, stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor
)
if args.print_cuda:
print(operation.rt_module.emit())
operations = [operation,]
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
operations = [operation, ]
# Compile the operation
pycutlass.compiler.add_module(operations)
# Randomly initialize tensors
problem_size = cutlass.conv.Conv2dProblemSize(
cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
cutlass.MatrixCoord(args.stride[0], args.stride[1]),
cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
cutlass.Tensor4DCoord(args.n, args.h, args.w, args.c),
cutlass.Tensor4DCoord(args.k, args.r, args.s, args.c),
cutlass.Tensor4DCoord(0, 0, 0, 0), # Padding
cutlass.MatrixCoord(1, 1), # Strides
cutlass.MatrixCoord(1, 1), # Dilation
cutlass.conv.Mode.cross_correlation,
args.split_k_slices, 1
1, # Split k slices
1 # Groups
)
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size)
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size)
# User-provided inputs
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
conv_kind, problem_size
)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
conv_kind, problem_size
)
if args.bias:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
).at(3)
else:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5))
tensor_D = torch.ones(size=(tensor_C_size,), dtype=torch.float32, device="cuda")
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
if args.element_a != "int8":
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2)
if args.element_b != "int8":
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2)
if args.element_c != "int8":
tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)
tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")
alpha = 1.
beta = 0.
arguments = Conv2dArguments(
operation=operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
split_k_slices=problem_size.split_k_slices
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=operation.epilogue_type(alpha, beta)
)
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
reduction_arguments = ReductionArguments(
reduction_operation,
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
partitions=problem_size.split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
# Run the operation
operation.run(arguments)
arguments.sync()
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
reduction_operation.run(reduction_arguments)
reduction_arguments.sync()
else:
arguments.sync()
reference_model = Conv2dReferenceModule(A, B, C, conv_kind)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
if (args.activation_function != "identity"):
tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))
# Run the host reference module and compare to the CUTLASS result
reference = Conv2dReferenceModule(A, B, C, operation.conv_kind)
tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta)
try:
assert torch.equal(tensor_D, tensor_D_ref)
except:
assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
print("Passed.")

View File

@ -0,0 +1,192 @@
# Customizable Python Interface Examples
This directory contains examples of using the CUTLASS Python interface with a variety of kernel configurations.
For every example, add `--print_cuda` to print the underlying CUDA kernel, and use `-h` or `--help` to display the help message.
## GEMM Examples
The GEMM examples use numpy to create input tensors and verify the results.
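Each script builds the operation described by the command-line flags, compiles it with pycutlass, runs it, and checks the device output against a host reference. As a rough sketch of the numpy side of that check (the sizes and the `D` name here are made up for illustration; the real scripts launch the CUTLASS kernel to produce `D`):
```python
import numpy as np

# Hypothetical problem size and epilogue scalars, chosen only for illustration
M, N, K = 512, 256, 128
alpha, beta = 1.0, 0.5

# Integral-valued random tensors: ceil() keeps the values whole numbers,
# the same trick the example scripts use so the comparison can be exact
A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(M, K))).astype(np.float32)
B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(K, N))).astype(np.float32)
C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(M, N))).astype(np.float32)

# Host reference for D = alpha * A @ B + beta * C
D_ref = alpha * (A @ B) + beta * C

# The scripts compare the device result (call it D) against D_ref, e.g.
# assert np.allclose(D, D_ref, atol=1e-5)
```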
### GEMM F64 Example
Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
### GEMM F32 Example
Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
```
Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
```
### GEMM F16 Example
Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
```
### GEMM BF16 Example
Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
```
### GEMM Int8 Example
Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```
### Batched & Array GEMM
Example 1: Batched GEMM
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 2: Array GEMM
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use numpy to create input tensors and verify the results.
Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
The Conv2d examples use PyTorch to create input tensors and verify the results. PyTorch can be installed by following the instructions on the [official website](https://pytorch.org/get-started/locally/).
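The scripts create flat CUDA tensors with integral values and compare the device output to a host reference, first exactly and then with a tolerance. A minimal sketch of that pattern (assuming a CUDA-capable GPU; the size and dtype below are hard-coded for illustration, whereas `conv2d.py` takes them from the command line):
```python
import torch

# Hypothetical flat tensor size, for illustration only
tensor_A_size = 1 * 13 * 17 * 8

# Integral-valued random input on the GPU, in the same style as conv2d.py
tensor_A = torch.ceil(
    torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))

# Verification pattern used by the scripts: exact equality first, tolerance as a fallback
def check(result: torch.Tensor, reference: torch.Tensor) -> None:
    try:
        assert torch.equal(result, reference)
    except AssertionError:
        assert torch.allclose(result, reference, rtol=1e-2)
```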
### Conv2d F32 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
```
Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
```
### Conv2d F32 Wgrad
Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F32 Dgrad
Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F16 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
## Epilogue
### Bias
To replace C with a bias vector, add the `-bias` flag.
### Activation function
Example 1: ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu
```
Example 2: leaky ReLU
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2
```
Example 3: tanh (alpha=0 to avoid saturation)
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh
```
Example 4: sigmoid
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid
```
Example 5: SiLU
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu
```
Example 6: HardSwish
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish
```
Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```

View File

@ -0,0 +1,320 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
import sys
import torch
import torch.nn.functional as F
import argparse
# parse the arguments
parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from Python")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'],
help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[
4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorC32RSK32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'],
help='Data type of computation in the epilogue')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8",
"HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4",
"StridedDgradHorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# conv related
parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'],
help="The type of convolution: forward propagation (fprop), \
gradient of activation (dgrad), gradient of weight (wgrad)")
parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"],
)
parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str,
choices=["analytic", "optimized", "fixed_channels", "few_channels"],
help="This option describes iterator algorithm")
# arguments
parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"],
help="Split K Mode. Serial is used for non-splitK or serial-splitK.\
Parallel is used for parallel splitK.")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)")
parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)")
parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)")
parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)")
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
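# Assemble the Conv2d operation from the tile description, tensor descriptions, and epilogue chosen above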
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, stride_support=stride_support,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
if args.print_cuda:
print(operation.rt_module.emit())
operations = [operation,]
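# Parallel split-K needs an additional reduction kernel to combine the partial results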
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
pycutlass.compiler.add_module(operations)
problem_size = cutlass.conv.Conv2dProblemSize(
cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
cutlass.MatrixCoord(args.stride[0], args.stride[1]),
cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
cutlass.conv.Mode.cross_correlation,
args.split_k_slices, 1
)
# User-provide inputs
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
conv_kind, problem_size
)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
conv_kind, problem_size
)
if args.bias:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
).at(3)
else:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
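# Randomly initialize the tensors on the GPU; ceil() keeps non-int8 values integral so the check below can be exact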
if args.element_a != "int8":
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2)
if args.element_b != "int8":
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2)
if args.element_c != "int8":
tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)
tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda")
arguments = Conv2dArguments(
operation=operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
split_k_slices=problem_size.split_k_slices
)
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
reduction_arguments = ReductionArguments(
reduction_operation,
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
partitions=problem_size.split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
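# Launch the Conv2d kernel; for parallel split-K, the reduction kernel then combines the partial results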
operation.run(arguments)
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
reduction_operation.run(reduction_arguments)
reduction_arguments.sync()
else:
arguments.sync()
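# Run the host reference model; the activation function, if any, is applied to the reference output before comparison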
reference_model = Conv2dReferenceModule(A, B, C, conv_kind)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias)
if (args.activation_function != "identity"):
tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args))
try:
assert torch.equal(tensor_D, tensor_D_ref)
except:
assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2)
print("Passed.")

View File

@ -0,0 +1,445 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import cutlass
from bfloat16 import bfloat16
import sys
import argparse
# parse the arguments
parser = argparse.ArgumentParser(description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'],
help="This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
parser.add_argument("-epv", "--epilogue_visitor", default=None,
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# Argument
parser.add_argument("-p", "--problem_size",
default=[128, 128, 128], nargs=3, type=int,
help="GEMM problem size M, N, K")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
help="Scaling factor of A * B")
parser.add_argument("-beta", "--beta", default=0.0, type=float,
help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
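# Epilogue visitors define the epilogue as a small Python function over the accumulator, C, and auxiliary tensors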
visitor = args.epilogue_visitor is not None
if args.epilogue_visitor == "ColumnReduction":
class ColumnReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + beta * c
reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
return D, reduction
epilogue_functor = ColumnReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowBroadcast":
class RowBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'row', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = alpha * T
Z = relu.numpy(scale_T + beta * c)
return Z, T
epilogue_functor = RowBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
class ColumnBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'column', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = leaky_relu.numpy(alpha * T, 0.2)
Z = scale_T + beta * c
return Z, T
epilogue_functor = ColumnBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
else:
epilogue_functor = epilogue_functor
operation = GemmOperationUniversal(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
visitor=visitor
)
if args.print_cuda:
print(operation.rt_module.emit())
operations = [operation, ]
if args.gemm_mode == "GemmSplitKParallel":
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
pycutlass.compiler.add_module(operations)
# User-provide inputs
problem_size = cutlass.gemm.GemmCoord(
args.problem_size[0], args.problem_size[1], args.problem_size[2])
tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(bfloat16)
else:
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(
low=-2, high=2,size=(tensor_a_size,)
).astype(getattr(np, args.element_a))
tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(bfloat16)
else:
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(
low=-2, high=2, size=(tensor_b_size,)
).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.bias:
if args.layout_c == "RowMajor":
tensor_c_size = args.batch * problem_size.n()
elif args.layout_c == "ColumnMajor":
tensor_c_size = args.batch * problem_size.m()
else:
raise ValueError(args.layout_c)
else:
tensor_c_size = args.batch * problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(
low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.zeros(
shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
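# Each epilogue visitor takes its own auxiliary tensors: a partial-reduction buffer for the reduction visitors, a broadcast vector and T output for the broadcast visitors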
if args.epilogue_visitor == "RowReduction":
cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnReduction":
cta_m = args.threadblock_shape[0]
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "RowBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
else:
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=output_op,
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
split_k_slices=args.split_k_slices, batch=args.batch
)
if args.gemm_mode == "GemmSplitKParallel":
reduction_arguments = ReductionArguments(
operation=reduction_operation,
problem_size=[problem_size.m(), problem_size.n()],
partitions=args.split_k_slices, workspace=arguments.ptr_D,
destination=tensor_D, source=tensor_C,
output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
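# Launch the GEMM; for GemmSplitKParallel, the reduction kernel then combines the split-K partial results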
operation.run(arguments)
if args.gemm_mode == "GemmSplitKParallel":
reduction_operation.run(reduction_arguments)
reduction_arguments.sync()
else:
arguments.sync()
# run the host reference module
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, reduction_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
args.alpha, args.beta
)
tensor_D_ref = tensor_D_ref.flatten()
reduction_ref = reduction_ref.flatten()
assert np.allclose(reduction_ref, reduction, atol=1e-2)
elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, tensor_T_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
vector, args.alpha, args.beta)
tensor_D_ref = tensor_D_ref.flatten()
tensor_T_ref = tensor_T_ref.flatten()
assert np.array_equal(tensor_t, tensor_T_ref)
try:
assert np.array_equal(tensor_D, tensor_D_ref)
except:
assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)
print("Passed.")

View File

@ -0,0 +1,287 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import csv
import sys
import argparse
# parse the arguments
parser = argparse.ArgumentParser(
description="Launch CUTLASS GEMM Grouped kernels from Python")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[
4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU. \
NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
This parameter is passed in at present to match the APIs of other kernels. The parameter \
is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
default="Device", type=str, choices=["Host", "Device"],
help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
help="path to the csv file contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
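# Allocate a pool of device memory to be used by the kernel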
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
if args.activation_function == "identity":
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
operation = GemmOperationGrouped(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
precompute_mode=precompute_mode
)
if args.print_cuda:
print(operation.rt_module.emit())
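# Compile the operation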
pycutlass.compiler.add_module([operation, ])
reference_module = ReferenceModule(A, B, C)
# get problems
problem_sizes = []
with open(args.problem_size_dir) as csv_file:
reader = csv.reader(csv_file)
for row in reader:
problem_sizes.append(
cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
)
problem_count = len(problem_sizes)
tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
problem_sizes_coord = []
tensor_D_refs = []
for problem_size in problem_sizes:
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(bfloat16)
else:
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.k(),)).astype(getattr(np, args.element_a))
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(bfloat16)
else:
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
* problem_size.n(),)).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.bias:
if args.layout_c == "RowMajor":
c_size = problem_size.n()
elif args.layout_c == "ColumnMajor":
c_size = problem_size.m()
else:
raise ValueError(args.layout_c)
else:
c_size = problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(
low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.zeros(
shape=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_As.append(tensor_A)
tensor_Bs.append(tensor_B)
tensor_Cs.append(tensor_C)
tensor_Ds.append(tensor_D)
tensor_D_ref = reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size,
args.alpha, args.beta, args.bias)
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
tensor_D_refs.append(tensor_D_ref)
problem_sizes_coord.append(problem_size)
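# Construct the grouped GEMM arguments and run the operation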
arguments = GemmGroupedArguments(
operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
)
operation.run(arguments)
arguments.sync()
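# Compare the CUTLASS result to the host reference result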
for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
try:
assert np.array_equal(tensor_d, tensor_d_ref)
except:
assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5)
print("Passed.")


@@ -29,417 +29,110 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import cutlass
from bfloat16 import bfloat16
"""
Basic example of using the CUTLASS Python interface to run a GEMM
"""
import argparse
import numpy as np
import sys
import cutlass
import pycutlass
from pycutlass import *
import util
# parse the arguments
parser = argparse.ArgumentParser(
description="Launch CUTLASS GEMM kernels from python: 'D = alpha * A * B + beta * C'")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'],
help="This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
parser.add_argument("-epv", "--epilogue_visitor", default=None,
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# Argument
parser.add_argument("-p", "--problem_size",
default=[128, 128, 128], nargs=3, type=int,
help="GEMM problem size M, N, K")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
help="Scaling factor of A * B")
parser.add_argument("-beta", "--beta", default=0.0, type=float,
help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"],
help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'")
parser.add_argument("--m", default=128, type=int, help="M dimension of the GEMM")
parser.add_argument("--n", default=128, type=int, help="N dimension of the GEMM")
parser.add_argument("--k", default=128, type=int, help="K dimension of the GEMM")
parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
# Check that the device is of a sufficient compute capability
cc = util.get_device_cc()
assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70."
alignment = 8
assert args.m % alignment == 0, "M dimension of size {} is not divisible by alignment of {}".format(args.m, alignment)
assert args.n % alignment == 0, "N dimension of size {} is not divisible by alignment of {}".format(args.n, alignment)
assert args.k % alignment == 0, "K dimension of size {} is not divisible by alignment of {}".format(args.k, alignment)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
# Allocate a pool of device memory to be used by the kernel
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
# Set the compiler to use to NVCC
pycutlass.compiler.nvcc()
# Set up A, B, C and accumulator
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
[16, 8, 8], # Shape of the Tensor Core instruction
A.element, B.element, element_acc,
cutlass.OpClass.TensorOp,
MathOperation.multiply_add
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
[128, 128, 32], # Threadblock shape
2, # Number of stages
[2, 2, 1], # Number of warps within each dimension of the threadblock shape
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
if (args.activation_function == "identity"
or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
#
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
visitor = args.epilogue_visitor is not None
if args.epilogue_visitor == "ColumnReduction":
class ColumnReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + beta * c
reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
return D, reduction
epilogue_functor = ColumnReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowBroadcast":
class RowBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'row', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = alpha * T
Z = relu.numpy(scale_T + beta * c)
return Z, T
epilogue_functor = RowBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
class ColumnBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'column', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = leaky_relu.numpy(alpha * T, 0.2)
Z = scale_T + beta * c
return Z, T
epilogue_functor = ColumnBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
else:
epilogue_functor = epilogue_functor
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)
operation = GemmOperationUniversal(
arch=args.compute_capability, tile_description=tile_description,
arch=cc, tile_description=tile_description,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
visitor=visitor
)
epilogue_functor=epilogue_functor)
if args.print_cuda:
print(operation.rt_module.emit())
operations = [operation, ]
if args.gemm_mode == "GemmSplitKParallel":
if (args.activation_function == "identity"):
epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
count=C.alignment
)
operations.append(reduction_operation)
# Compile the operation
pycutlass.compiler.add_module(operations)
# User-provide inputs
# Randomly initialize tensors
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.k,))).astype(np.float16)
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,))).astype(np.float16)
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32)
tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32)
problem_size = cutlass.gemm.GemmCoord(
args.problem_size[0], args.problem_size[1], args.problem_size[2])
tensor_a_size = args.batch * problem_size.m() * problem_size.k()
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(bfloat16)
else:
tensor_A = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,))
).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(
low=-2, high=2,size=(tensor_a_size,)
).astype(getattr(np, args.element_a))
tensor_b_size = args.batch * problem_size.k() * problem_size.n()
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(bfloat16)
else:
tensor_B = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,))
).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(
low=-2, high=2, size=(tensor_b_size,)
).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.bias:
if args.layout_c == "RowMajor":
tensor_c_size = args.batch * problem_size.n()
elif args.layout_c == "ColumnMajor":
tensor_c_size = args.batch * problem_size.m()
else:
raise ValueError(args.layout_c)
else:
tensor_c_size = args.batch * problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(
low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.zeros(
shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
if args.epilogue_visitor == "RowReduction":
cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnReduction":
cta_m = args.threadblock_shape[0]
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "RowBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
else:
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
problem_size = cutlass.gemm.GemmCoord(args.m, args.n, args.k)
alpha = 1.
beta = 0.
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=output_op,
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
split_k_slices=args.split_k_slices, batch=args.batch
)
if args.gemm_mode == "GemmSplitKParallel":
reduction_arguments = ReductionArguments(
operation=reduction_operation,
problem_size=[problem_size.m(), problem_size.n()],
partitions=args.split_k_slices, workspace=arguments.ptr_D,
destination=tensor_D, source=tensor_C,
output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
bias = arguments.bias
)
output_op=operation.epilogue_type(alpha, beta))
# Run the operation
operation.run(arguments)
arguments.sync()
if args.gemm_mode == "GemmSplitKParallel":
reduction_operation.run(reduction_arguments)
reduction_arguments.sync()
else:
arguments.sync()
# run the host reference module
# Run the host reference module and compare to the CUTLASS result
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, reduction_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
args.alpha, args.beta
)
tensor_D_ref = tensor_D_ref.flatten()
reduction_ref = reduction_ref.flatten()
assert np.allclose(reduction_ref, reduction, atol=1e-2)
elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, tensor_T_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
vector, args.alpha, args.beta)
tensor_D_ref = tensor_D_ref.flatten()
tensor_T_ref = tensor_T_ref.flatten()
assert np.array_equal(tensor_t, tensor_T_ref)
tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta)
try:
assert np.array_equal(tensor_D, tensor_D_ref)
except:
assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5)
print("Passed.")


@@ -29,253 +29,125 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
import csv
"""
Basic example of using the CUTLASS Python interface to run a grouped GEMM
"""
import argparse
import numpy as np
import sys
# parse the arguments
parser = argparse.ArgumentParser(
description="Launch CUTLASS GEMM Grouped kernels from python")
import cutlass
import pycutlass
from pycutlass import *
import util
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[
4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU. \
NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \
This parameter is passed in at present to match the APIs of other kernels. The parameter \
is unused within the kernel")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
default="Device", type=str, choices=["Host", "Device"],
help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
help="path to the csv file contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector")
# Activation function
parser.add_argument("-activ", "--activation_function", default="identity",
choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function")
parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float,
help="addition arguments for activation")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python")
parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
# Check that the device is of a sufficient compute capability
cc = util.get_device_cc()
assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70."
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
# Allocate a pool of device memory to be used by the kernel
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
# Set the compiler to use to NVCC
pycutlass.compiler.nvcc()
# Set up A, B, C and accumulator
alignment = 1
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
[16, 8, 8], # Shape of the Tensor Core instruction
A.element, B.element, element_acc,
cutlass.OpClass.TensorOp,
MathOperation.multiply_add
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
[128, 128, 32], # Threadblock shape
2, # Number of stages
[2, 2, 1], # Number of warps within each dimension of the threadblock shape
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
if args.activation_function == "identity":
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)
operation = GemmOperationGrouped(
arch=args.compute_capability, tile_description=tile_description,
arch=cc, tile_description=tile_description,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
precompute_mode=precompute_mode
)
epilogue_functor=epilogue_functor,
precompute_mode=SchedulerMode.Device)
if args.print_cuda:
print(operation.rt_module.emit())
pycutlass.compiler.add_module([operation, ])
operations = [operation, ]
reference_module = ReferenceModule(A, B, C)
# get problems
problem_sizes = []
with open(args.problem_size_dir) as csv_file:
reader = csv.reader(csv_file)
for row in reader:
problem_sizes.append(
cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
)
# Compile the operation
pycutlass.compiler.add_module(operations)
# Initialize tensors for each problem in the group
problem_sizes = [
cutlass.gemm.GemmCoord(128, 128, 64),
cutlass.gemm.GemmCoord(512, 256, 128)
]
problem_count = len(problem_sizes)
alpha = 1.
beta = 0.
tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
problem_sizes_coord = []
tensor_D_refs = []
reference = ReferenceModule(A, B, C)
for problem_size in problem_sizes:
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(bfloat16)
else:
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.k(),)).astype(getattr(np, args.element_a))
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(bfloat16)
else:
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
* problem_size.n(),)).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.bias:
if args.layout_c == "RowMajor":
c_size = problem_size.n()
elif args.layout_c == "ColumnMajor":
c_size = problem_size.m()
else:
raise ValueError(args.layout_c)
else:
c_size = problem_size.m() * problem_size.n()
if args.element_c == "bfloat16":
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(bfloat16)
else:
tensor_C = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(c_size,))
).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(
low=-2, high=2, size=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
tensor_D = np.zeros(
shape=(problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
# Randomly initialize tensors
m = problem_size.m()
n = problem_size.n()
k = problem_size.k()
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16)
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16)
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32)
tensor_D = np.zeros(shape=(m * n,)).astype(np.float32)
tensor_As.append(tensor_A)
tensor_Bs.append(tensor_B)
tensor_Cs.append(tensor_C)
tensor_Ds.append(tensor_D)
tensor_D_ref = reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size,
args.alpha, args.beta, args.bias)
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
# Run the reference GEMM
tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta)
tensor_D_refs.append(tensor_D_ref)
problem_sizes_coord.append(problem_size)
arguments = GemmGroupedArguments(
operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
operation, problem_sizes, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
output_op=operation.epilogue_type(alpha, beta)
)
# Run the operation
operation.run(arguments)
arguments.sync()
# Compare the CUTLASS result to the host reference result
for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
try:
assert np.array_equal(tensor_d, tensor_d_ref)


@@ -0,0 +1,60 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for interacting with device
"""
from cuda import cudart
# Raises an exception if `result` returned an error. Otherwise returns the result.
def check_cuda_errors(result: list):
# `result` is of the format : (cudaError_t, result...)
err = result[0]
if err.value:
raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
# Returns the integer representation of the device compute capability
def get_device_cc(device: int = 0):
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
major = str(deviceProp.major)
minor = str(deviceProp.minor)
return int(major + minor)
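
A short usage sketch for these helpers (assuming the file is importable as util, as in the examples above, and that the cuda-python bindings are installed):

import util
from cuda import cudart

# Query the number of visible CUDA devices and the compute capability of device 0.
num_devices = util.check_cuda_errors(cudart.cudaGetDeviceCount())
print("Found {} CUDA device(s); device 0 is SM{}".format(num_devices, util.get_device_cc()))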


@@ -0,0 +1,44 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
41_fused_multi_head_attention_fixed_seqlen
fused_multihead_attention_fixed_seqlen.cu
)
cutlass_example_add_executable(
41_fused_multi_head_attention_variable_seqlen
fused_multihead_attention_variable_seqlen.cu
)
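# Convenience target that builds both fused multi-head attention examples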
add_custom_target(41_fused_multi_head_attention
DEPENDS 41_fused_multi_head_attention_fixed_seqlen
41_fused_multi_head_attention_variable_seqlen
)


@@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/functional.h"


@@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <float.h>
#include <stdio.h>


@@ -0,0 +1,284 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
#include "find_default_mma.h"
#include "attention_scaling_coefs_updater.h"
#include "mma_from_smem.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag_,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration,
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
>
struct DefaultFMHAGrouped {
using scalar_t = scalar_t_;
using accum_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
using ArchTag = ArchTag_;
static bool const kIsAligned = isAligned_;
static int const kWarpSize = 32;
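// Each warp covers a 32x32 tile of the kQueriesPerBlock x kKeysPerBlock attention block
// (see WarpShape in MM0), which fixes the number of warps per threadblock.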
static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
into a shared-memory ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1)
*/
using GemmType = gemm_kernel_utils::DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using ElementA = scalar_t;
using ElementB = scalar_t;
using ElementC = scalar_t;
using ElementAccumulator = accum_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
ElementA,
ElementB,
ElementC,
ElementAccumulator
>;
static int const kAlignmentA =
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
static int const kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
static int const kStages = DefaultConfig::kStages;
using Operator = typename GemmType::Operator;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
typename Mma::Operator::IteratorC,
ElementAccumulator,
kWarpSize>::Updater;
static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, "");
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/*
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
using GemmType = typename MM0::GemmType;
using OpClass = typename GemmType::OpClass;
using ElementA = scalar_t;
using ElementB = scalar_t;
using ElementC = output_accum_t;
using ElementAccumulator = accum_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
ElementA,
ElementB,
ElementC,
ElementAccumulator
>;
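// Operand A of this matmul is the attention block MM0 left in shared memory, so it keeps
// the default alignment; only operand B, loaded from global memory, may need the reduced alignment.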
static int const kAlignmentA = DefaultConfig::kAlignmentA;
static int const kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = typename MM0::ThreadblockShape;
using WarpShape = typename MM0::WarpShape;
using InstructionShape = typename MM0::InstructionShape;
using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp;
static int const kStages = DefaultConfig::kStages;
using Operator = typename GemmType::Operator;
using ThreadblockSwizzle = void; // Swizzling is unused
static bool const kSplitKSerial = false;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage>;
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(WarpCount::kCount == kNumWarpsPerBlock, "");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
/// Define the kernel in terms of the default kernel
using FMHAKernel = kernel::FMHAGrouped<
MM0,
MM1,
scalar_t,
accum_t,
output_t,
output_accum_t,
kSingleValueIteration,
GroupScheduleMode_
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

View File

@ -1,8 +1,39 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Cutlass provides helper template functions to figure out the right
data structures to instantiate to run a GEMM with various parameters (see
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
instanciation priority rules, it will only create an MmaMultiStage with
instantiation priority rules, it will only create an MmaMultiStage with
kStages=3 (otherwise creates an MmaPipelined - which is not compatible with
FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
so we just copy-pasted some code from `default_mma.h` and

View File

@ -0,0 +1,839 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped FMHA kernel
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "fmha_grouped_problem_visitor.h"
#include "gemm_kernel_utils.h"
#include "epilogue_rescale_output.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename MM0_, ///! Structure for computing P = Q @ K
typename MM1_, ///! Structure for computing O = P @ V
typename scalar_t_,
typename accum_t_,
typename output_t_,
typename output_accum_t_,
bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
>
struct FMHAGrouped {
public:
using MM0 = MM0_;
using MM1 = MM1_;
using scalar_t = scalar_t_;
using accum_t = accum_t_;
using output_t = output_t_;
using output_accum_t = output_accum_t_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
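// An extra accumulator buffer for the output is only needed when partial results are written
// back between key-block iterations (output not kept in registers) and the final output type
// differs from the accumulation type, e.g. half-precision output with float accumulation.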
// Parameters to satisfy BaseGrouped
using ElementA = scalar_t;
using ElementB = scalar_t;
using ElementC = accum_t;
using LayoutA = typename MM0::LayoutA;
using LayoutB = typename MM0::LayoutB;
using LayoutC = typename MM1::LayoutC;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static int const kAlignmentA = MM0::kAlignmentA;
static int const kAlignmentB = MM0::kAlignmentB;
static int const kAlignmentC = 1;
using Mma = typename MM1::Mma;
using EpilogueOutputOp = typename MM1::EpilogueOutputOp;
using ThreadblockSwizzle = void;
using Operator = typename MM1::Operator;
using WarpShape = typename MM1::WarpShape;
using InstructionShape = typename MM1::InstructionShape;
using ElementQ = scalar_t;
using ElementK = scalar_t;
using ElementP = accum_t;
using ElementV = scalar_t;
using ElementO = output_t;
using ElementOAccum = output_accum_t;
using ElementAccumulator = accum_t;
using LayoutQ = typename MM0::LayoutA;
using LayoutK = typename MM0::LayoutB;
using LayoutP = typename MM0::LayoutC;
using LayoutV = typename MM1::LayoutB;
using LayoutO = typename MM1::LayoutC;
static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<ElementV>::value == 16);
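// V can be prefetched into shared memory while the first GEMM is still in flight. This is only
// enabled on SM80 and newer with 16-bit V elements, presumably because it relies on the
// asynchronous-copy (multistage) mainloop.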
static int const kAlignmentQ = MM0::kAlignmentA;
static int const kAlignmentK = MM0::kAlignmentB;
static int const kAlignmentV = 1;
using ThreadblockShape = typename MM0::ThreadblockShape;
static int const kQueriesPerBlock = ThreadblockShape::kM;
static int const kKeysPerBlock = ThreadblockShape::kN;
/// Warp count (concept: GemmShape)
using WarpCount = typename MM1::WarpCount;
static int const kThreadsPerWarp = 32;
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
using ProblemVisitor = FMHAGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
kThreadCount,
kThreadCount>;
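// The grouped problem visitor hands out (problem, tile) pairs to threadblocks in a persistent
// loop; both its prefetch tile count and thread count are kThreadCount here, and tiles are
// partitioned only along the M (query) dimension (see FMHAGroupedProblemSizeHelper below).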
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord *problem_sizes0;
GemmCoord *problem_sizes1;
int problem_count;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementV ** ptr_V;
ElementO ** ptr_O;
ElementOAccum ** ptr_O_accum;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;
// Whether causal masking is to be performed
bool causal;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_O_accum(nullptr),
ldq(nullptr),
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
causal(false),
host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord *problem_sizes0,
GemmCoord *problem_sizes1,
int problem_count,
int threadblock_count,
ElementQ ** ptr_Q,
ElementK ** ptr_K,
ElementP ** ptr_P,
ElementV ** ptr_V,
ElementO ** ptr_O,
ElementOAccum ** ptr_O_accum,
typename LayoutQ::Stride::LongIndex *ldq,
typename LayoutK::Stride::LongIndex *ldk,
typename LayoutP::Stride::LongIndex *ldp,
typename LayoutV::Stride::LongIndex *ldv,
typename LayoutO::Stride::LongIndex *ldo,
bool causal,
GemmCoord *host_problem_sizes=nullptr
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
threadblock_count(threadblock_count),
ptr_Q(ptr_Q),
ptr_K(ptr_K),
ptr_P(ptr_P),
ptr_V(ptr_V),
ptr_O(ptr_O),
ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O),
ldq(ldq),
ldk(ldk),
ldv(ldv),
ldo(ldo),
causal(causal),
host_problem_sizes(host_problem_sizes)
{
}
bool __host__ check_supported() {
CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ);
CHECK_ALIGNED_PTR(ptr_K, kAlignmentK);
CHECK_ALIGNED_PTR(ptr_V, kAlignmentV);
XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned");
return true;
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementV ** ptr_V;
ElementO ** ptr_O;
ElementOAccum ** ptr_O_accum;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;
bool causal;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_O_accum(nullptr),
ldq(nullptr),
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
causal(false)
{ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args,
void *workspace = nullptr,
int tile_count = 0):
problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
ptr_Q(args.ptr_Q),
ptr_K(args.ptr_K),
ptr_P(args.ptr_P),
ptr_V(args.ptr_V),
ptr_O(args.ptr_O),
ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O),
ldq(args.ldq),
ldk(args.ldk),
ldv(args.ldv),
ldo(args.ldo),
causal(args.causal)
{
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr,
int tile_count = 0) {
problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0,
args.problem_sizes1,
args.problem_count,
workspace, tile_count);
threadblock_count = args.threadblock_count;
ptr_Q = args.ptr_Q;
ptr_K = args.ptr_K;
ptr_P = args.ptr_P;
ptr_V = args.ptr_V;
ptr_O = args.ptr_O;
ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O;
ldq = args.ldq;
ldk = args.ldk;
ldv = args.ldv;
ldo = args.ldo;
causal = args.causal;
}
};
// Shared storage - depends on kernel params
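// Per-query-row softmax statistics carried across key-block iterations:
//   m_prime - running row maximum from previously processed key blocks
//   s_prime - running sum of exponentials (the softmax denominator used by the rescale epilogue)
//   mi      - row maximum including the current key block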
struct ScalingCoefs {
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return epilogue;
}
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return after_mm0.epilogue;
}
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
using SharedStorage = typename cutlass::platform::conditional<
kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
private:
// Parameters to be used by an individual tile
struct TileParams {
CUTLASS_HOST_DEVICE
static int query_start(int threadblock_idx) {
return threadblock_idx * kQueriesPerBlock;
}
// Returns whether this threadblock computes within the number of queries,
// which is determined by the M dimension of problem 0
CUTLASS_HOST_DEVICE
static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) {
return query_start(threadblock_idx) < problem_size0.m();
}
CUTLASS_HOST_DEVICE
static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) {
return problem_size0.m() - query_start(threadblock_idx);
}
CUTLASS_HOST_DEVICE
static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) {
int nk = problem_size0.n();
if (causal) {
nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk);
}
return nk;
}
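// Illustration: with kQueriesPerBlock = 64, threadblock_idx = 2 covers queries [128, 192).
// Under causal masking that tile only needs keys [0, min(192, problem_size0.n())), since any
// key beyond the last query of the tile is masked out anyway.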
};
public:
//
// Methods
//
CUTLASS_DEVICE
FMHAGrouped() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
static CUTLASS_DEVICE int16_t thread_id() {
return threadIdx.x;
}
static CUTLASS_DEVICE int8_t warp_id() {
return threadIdx.x / kThreadsPerWarp;
}
static CUTLASS_DEVICE int8_t lane_id() {
return threadIdx.x % kThreadsPerWarp;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
ProblemVisitor problem_visitor(
params.problem_visitor,
shared_storage.problem_visitor,
blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size0 = problem_visitor.problem_size0();
GemmCoord problem_size1 = problem_visitor.problem_size1();
const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
if (!TileParams::can_compute(threadblock_idx, problem_size0)) {
problem_visitor.advance(gridDim.x);
continue;
}
const int32_t problem_idx = problem_visitor.problem_index();
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = ElementAccumulator(0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
}
ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0);
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]},
ptr_O,
typename OutputTileIterator::TensorCoord{
num_queries, problem_size1.n()},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) ->
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]},
ptr_O_accum,
typename OutputTileIteratorAccum::TensorCoord{
num_queries, problem_size1.n()},
thread_id(),
{0, col});
};
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal);
for (int32_t iter_key_start = 0; iter_key_start < num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries);
int32_t problem_size_0_n = cutlass::fast_min(
(int32_t)kKeysPerBlock, num_keys - iter_key_start);
int32_t const& problem_size_0_k = problem_size0.k();
int32_t const& problem_size_1_n = problem_size1.n();
int32_t const& problem_size_1_k = problem_size_0_n;
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
iterator_V,
thread_id(),
problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx];
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(
typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])),
ptr_Q,
{problem_size_0_m, problem_size_0_k},
thread_id(),
{0, 0});
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])),
params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx],
{problem_size_0_k, problem_size_0_n},
thread_id(),
{0, 0});
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(
shared_storage.mm0, thread_id(), warp_id(), lane_id());
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
iteratorC_tile_offset = {
(warp_id() % MM0::Mma::WarpCount::kM),
(warp_id() / MM0::Mma::WarpCount::kM)
};
// Mask out last if causal
if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::ScalingCoefsUpdater::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
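// Entries whose key index exceeds the query index were set to -infinity above, so they
// contribute exp(-inf) = 0 once the softmax update below is applied.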
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<
kQueriesPerBlock,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / cutlass::fast_sqrt(float(problem_size0.k())));
}));
}));
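// The dispatch above selects a ScalingCoefsUpdater specialization that applies the streaming
// softmax recurrence over key blocks: mi becomes the row maximum of the scaled scores seen so
// far, `accum` is overwritten with exp(score * scale - mi), and m_prime / s_prime are updated
// so that s_prime keeps tracking the running softmax denominator. The scale passed as the last
// argument is 1 / sqrt(head_size).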
// Output results to shared-memory
int warp_idx_mn_0 = warp_id() %
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords = cutlass::MatrixCoord{
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords);
__syncthreads();
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage_si`)
// `V` is read from global memory (with iterator_B)
//
const int64_t nBlockN = kKeepOutputInRF ? 1
: ceil_div(
(int64_t)problem_size_1_n,
int64_t(MM1::ThreadblockShape::kN));
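// When the output is kept in registers (kSingleValueIteration), the whole N extent of the
// second GEMM must fit in a single iteration; otherwise the N dimension is tiled by
// MM1::ThreadblockShape::kN.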
// Iterate over the N dimension of GEMM1
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
}
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
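// The prefetch of the next V tile above overlaps with the epilogue below, which writes out
// the current block of output.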
if (!kKeepOutputInRF) {
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= num_keys,
kIsLast,
([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
output_accum_t,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = gemm_kernel_utils::call_conditional<
kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kKeepOutputInRF) {
__syncthreads();
}
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
const bool kIsFirst = true;
const bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o);
}
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,178 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Scheduler for grouped FMHA
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Helper for correctly representing problem sizes in grouped kernels
template <typename ThreadblockShape>
struct FMHAGroupedProblemSizeHelper {
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) {
// FMHA only partitions tiles across the M dimension.
return cutlass::gemm::GemmCoord(
((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1);
}
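// Example: with ThreadblockShape::kM = 64 and problem.m() = 130, this yields
// ceil(130 / 64) = 3 tiles along M and a single tile along N and K.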
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {}
CUTLASS_HOST_DEVICE
static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) {
return grid.m() * grid.n();
}
};
} // namespace detail
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount,
bool Transposed = false>
struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor<
detail::FMHAGroupedProblemSizeHelper<ThreadblockShape>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {
using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper<ThreadblockShape>;
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using BaseParams = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
struct Params {
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
int32_t problem_count;
void const *workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
problem_count(0), workspace(nullptr), tile_count(0) { }
/// Ctor
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const *problem_sizes0,
cutlass::gemm::GemmCoord const *problem_sizes1,
int32_t problem_count,
void const *workspace = nullptr,
int32_t tile_count = 0
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
workspace(workspace),
tile_count(tile_count)
{}
/// Convert the FMHA-specific parameters to those used by the base class
CUTLASS_HOST_DEVICE
BaseParams to_base() const {
return BaseParams(// Set problem_sizes to problem_sizes1 because these determine
// the shape of the final output of FMHA
problem_sizes1,
problem_count,
workspace,
tile_count);
}
};
//
// Methods
//
CUTLASS_DEVICE
FMHAGroupedProblemVisitor(
Params const &params_,
SharedStorage &shared_storage_,
int32_t block_idx
): Base (
params_.to_base(),
shared_storage_, block_idx),
problem_sizes0(params_.problem_sizes0),
problem_sizes1(params_.problem_sizes1)
{}
/// Returns the problem size 0 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size0() const {
GemmCoord problem = problem_sizes0[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
/// Returns the problem size 1 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size1() const {
GemmCoord problem = problem_sizes1[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -77,21 +77,17 @@
Examples:
# Run an attention example with default setup
$ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention
$ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen
# Run an attention example with custom setup
$ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
$ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers).
*/
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <map>
#include <unordered_map>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
@ -241,8 +237,8 @@ struct Options {
for (int i = 0; i < batch_size; ++i) {
// problems belonging to the same batch share the same seq len
int m_real = seq_length; // (rand() % seq_length);
int mkv_real = seq_length_kv; // (rand() % seq_length_kv);
int m_real = seq_length;
int mkv_real = seq_length_kv;
int m = (m_real + alignment - 1) / alignment * alignment;
int mkv = (mkv_real + alignment - 1) / alignment * alignment;
int k0 = head_size;
@ -260,7 +256,6 @@ struct Options {
problem_sizes0_real.push_back(problem0_real);
problem_sizes1_real.push_back(problem1_real);
}
}
}
}
@ -268,7 +263,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "42_fused_multi_head_attention\n\n"
out << "41_fused_multi_head_attention_fixed_seqlen\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --head_number=<int> Head number in multi-head attention (default: --head_number=12)\n"
@ -276,7 +271,7 @@ struct Options {
<< " --head_size=<int> Head size in multi-head attention (default: --head_size=64)\n"
<< " --head_size_v=<int> Head size in multi-head attention for V (default: --head_size_v=head_size)\n"
<< " --seq_length=<int> Sequence length in multi-head attention for Q (default: --seq_length=1024)\n"
<< " --seq_length_kv=<int> Sequence length in multi-head attention for K/V(default: --seq_length_kv=seq_length)\n"
<< " --seq_length_kv=<int> Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n"
<< " --use_mask=<bool> If true, performs padding-like masking in softmax.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --reference-check=<bool> If true, performs reference check.\n"
@ -342,8 +337,7 @@ public:
using ElementSoftmaxCompute = typename Attention::accum_t;
using LayoutQ = cutlass::layout::RowMajor;
using LayoutK = cutlass::layout::RowMajor;
using LayoutK_T = cutlass::layout::ColumnMajor; // transposed
using LayoutK = cutlass::layout::ColumnMajor;
using LayoutP = cutlass::layout::RowMajor;
using LayoutV = cutlass::layout::RowMajor;
using LayoutO = cutlass::layout::RowMajor;
@ -516,7 +510,7 @@ private:
auto problem1 = options.problem_sizes1.at(i);
ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0);
ldk_host.at(i) = LayoutK::packed({problem0.n(), problem0.k()}).stride(0);
ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0);
ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0);
ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0);
ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0);
@ -541,7 +535,6 @@ private:
total_elements_P += elements_P;
total_elements_V += elements_V;
total_elements_O += elements_O;
}
problem_sizes_device0.reset(problem_count());
@ -641,7 +634,7 @@ private:
float abs_diff = fabs(diff);
float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f);
float relative_diff = abs_diff / abs_ref;
if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) {
if ( (isnan(vector_Input_Ref.at(i)) || isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) {
printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i)));
return false;
}
@ -661,7 +654,7 @@ private:
cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i);
LayoutQ layout_Q(ldq_host.at(i));
LayoutK_T layout_K(ldk_host.at(i));
LayoutK layout_K(ldk_host.at(i));
LayoutP layout_P(ldp_host.at(i));
LayoutV layout_V(ldv_host.at(i));
LayoutO layout_O(ldo_host.at(i));
@ -673,7 +666,7 @@ private:
MatrixCoord extent_O{problem1.m(), problem1.k()};
cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
cutlass::TensorView<ElementK, LayoutK_T> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
cutlass::TensorView<ElementP, LayoutP> view_P(block_P.get() + offset_P.at(i), layout_P, extent_P);
cutlass::TensorView<ElementV, LayoutV> view_V(block_V.get() + offset_V.at(i), layout_V, extent_V);
@ -686,7 +679,7 @@ private:
// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementQ, LayoutQ,
ElementK, LayoutK_T,
ElementK, LayoutK,
ElementP, LayoutP,
ElementCompute, ElementAccumulator
>(
@ -988,6 +981,40 @@ public:
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration
>
int run_attention(Options& options) {
using Attention = AttentionKernel<
cutlass::half_t, // scalar_t
cutlass::arch::Sm80, // ArchTag
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration
>;
//
// Test and profile
//
TestbedAttention<Attention> testbed(options);
Result result = testbed.profile_grouped();
if (!result.passed) {
std::cout << "Profiling CUTLASS attention has failed.\n";
std::cout << "\nFailed\n";
return -1;
}
std::cout << "\nPassed\n";
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
@ -1041,52 +1068,25 @@ int main(int argc, char const **args) {
std::cerr << "--alignment=1 is the only supported value\n";
return -2;
}
using ArchTag = cutlass::arch::Sm80;
constexpr bool kIs64x64 = true;
// Set grid size
constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32;
constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128;
if (kIs64x64 && options.head_size_v > kKeysPerBlock) {
std::cerr << "WARNING: you will get better performance with `kIs64x64=false`\n";
// Determine kernel configuration based on head size.
// If the head size is less than or equal to 64, each block operates over 64 queries and
// 64 keys, and partial results can be kept in the register file.
// If the head size is greater than 64, each block operates over 32 queries and 128 keys;
// partial results stay in the register file only while the head size for V fits within the
// 128-key tile, and otherwise go through an output accumulator buffer.
if (options.head_size_v > 64) {
static int const kQueriesPerBlock = 32;
static int const kKeysPerBlock = 128;
if (options.head_size_v <= kKeysPerBlock) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
} else {
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
}
} else {
static int const kQueriesPerBlock = 64;
static int const kKeysPerBlock = 64;
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
}
constexpr bool kSingleValueIteration = true;
if (kSingleValueIteration && options.head_size_v > kKeysPerBlock) {
std::cerr << "ERROR : Use kSingleValueIteration to keep output in RF. " \
"This requires to have `head_size <= kKeysPerBlock` " \
"but head_size_v=" << options.head_size_v << " and kKeysPerBlock=" << kKeysPerBlock << "\n";
return -2;
}
if (!kSingleValueIteration && options.head_size_v <= kKeysPerBlock) {
std::cerr << "WARNING: you will get better performance with `kSingleValueIteration=true` (keeps the output in RF rather than GMEM)\n";
}
using Attention = AttentionKernel<
cutlass::half_t, // scalar_t
ArchTag,
true, // memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration
>;
//
// Test and profile
//
TestbedAttention<Attention> testbed(options);
Result result = testbed.profile_grouped();
if (!result.passed) {
std::cout << "Profiling CUTLASS attention has failed.\n";
std::cout << "\nFailed\n";
return -1;
}
std::cout << "\nPassed\n";
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "custom_mma_multistage.h"

View File

@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"

View File

@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"
namespace cutlass {
namespace transform {
namespace threadblock {
template <typename BaseIterator>
struct MakeIteratorResidualLast;
template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>> {
using Iterator = PredicatedTileIteratorResidualLast<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>;
};
template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>> {
using Iterator = PredicatedTileAccessIteratorResidualLast<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>;
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass

View File

@ -1,3 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#ifdef HAS_PYTORCH

View File

@ -1,626 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines the FusedMultiHeadAttention Class
The class contains the following:
1) GEMM0 with epilogue fusion,
2) GEMM1 with mainloop fusion, and
3) A lightweight full softmax reduction kernel.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
#include "gemm_grouped_with_softmax_visitor.h"
namespace cutlass {
template <
typename ElementQ_,
typename LayoutQ_,
typename ElementK_,
typename LayoutK_,
typename ElementP_,
typename LayoutP_,
typename ElementCompute_,
typename OperatorClass_,
typename ArchTag_,
typename ThreadblockShape0_,
typename ThreadblockShape1_,
typename WarpShape0_,
typename WarpShape1_,
typename InstructionShape_,
int kStages0_,
int kStages1_,
bool UseMasking_ = false,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
int Alignment = 128 / cutlass::sizeof_bits<ElementQ_>::value,
typename ElementSoftmax_ = ElementP_
>
class FusedMultiHeadAttention {
public:
using ElementQ = ElementQ_;
using ElementK = ElementK_;
using ElementP = ElementP_;
using ElementV = ElementK;
using ElementOutput = ElementP;
using ElementAccumulator = ElementCompute_;
using LayoutQ = LayoutQ_;
using LayoutK = LayoutK_;
using LayoutP = LayoutP_;
using LayoutV = LayoutK;
using LayoutO = LayoutP;
using ElementNorm = cutlass::half_t;
using ElementSum = cutlass::half_t;
using ElementSoftmaxCompute = float;
using LayoutNorm = cutlass::layout::RowMajor;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape0 = ThreadblockShape0_;
using WarpShape0 = WarpShape0_;
using ThreadblockShape1 = ThreadblockShape1_;
using WarpShape1 = WarpShape1_;
static int const Stages0 = kStages0_;
static int const Stages1 = kStages1_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>;
using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration<
OperatorClass, ArchTag, ElementQ, ElementK, ElementP,
ElementAccumulator>::Operator;
static bool const kInternalTranspose = cutlass::platform::is_same<LayoutP, cutlass::layout::ColumnMajor>::value;
static bool const kUseMasking = UseMasking_;
static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_;
static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_;
using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
ElementQ,
LayoutQ,
cutlass::ComplexTransform::kNone,
8,
ElementK,
LayoutK,
cutlass::ComplexTransform::kNone,
8,
LayoutP,
kInternalTranspose
>;
using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
typename MapArguments::ElementA,
typename MapArguments::LayoutA,
MapArguments::kAlignmentA,
typename MapArguments::ElementB,
typename MapArguments::LayoutB,
MapArguments::kAlignmentB,
ElementP,
typename MapArguments::LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape0,
WarpShape0,
InstructionShape,
EpilogueOutputOp0,
ThreadblockSwizzle,
Stages0,
true,
Operator,
cutlass::gemm::SharedMemoryClearOption::kNone
>::GemmKernel;
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
ThreadblockShape0,
DefaultGemmKernel::kThreadCount,
typename DefaultGemmKernel::Epilogue::OutputTileIterator,
typename EpilogueOutputOp0::ElementCompute,
ElementNorm,
ElementSum,
ElementSoftmaxCompute,
EpilogueOutputOp0,
kUseMasking
>;
using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
EpilogueVisitor,
typename DefaultGemmKernel::Epilogue
>::Epilogue;
using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor<
typename DefaultGemmKernel::Mma,
Epilogue,
ThreadblockSwizzle,
kGroupScheduleMode0,
kInternalTranspose,
kUseMasking
>;
using GemmGrouped0 = cutlass::gemm::device::GemmGrouped<GemmKernel0>;
using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
ElementNorm,
ElementSum,
typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute,
typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape,
true
>;
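// Second grouped GEMM: O = softmax(S) * V. The softmax normalization (using the norm/sum buffers
// above) is fused into the mainloop via DefaultGemmGroupedSoftmaxMainloopFusion.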
using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion<
ElementP,
LayoutP,
cutlass::ComplexTransform::kNone,
128 / cutlass::sizeof_bits<ElementQ>::value,
ElementV,
LayoutV,
cutlass::ComplexTransform::kNone,
128 / cutlass::sizeof_bits<ElementK>::value,
ElementNorm,
LayoutNorm,
ElementOutput,
LayoutO,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape1,
WarpShape1,
InstructionShape,
EpilogueOutputOp1,
ThreadblockSwizzle,
Stages1,
kGroupScheduleMode1
>::GemmKernel;
using GemmGrouped1 = cutlass::gemm::device::GemmGrouped<GemmKernel1>;
public:
/// Arguments class
struct Arguments {
cutlass::gemm::GemmCoord *problem_sizes0;
cutlass::gemm::GemmCoord *problem_sizes0_real;
cutlass::gemm::GemmCoord *problem_sizes1;
int problem_count;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementP ** ptr_V;
ElementP ** ptr_O;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
ElementP *block_P;
ElementNorm *block_Norm;
ElementSum *block_Sum;
int64_t *offset_P;
int64_t *offset_Norm_Device;
int64_t *offset_Sum_Device;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldp;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutP::Stride::LongIndex *ldo;
cutlass::gemm::GemmCoord *problem_sizes0_host;
cutlass::gemm::GemmCoord *problem_sizes1_host;
ElementAccumulator alpha0;
ElementAccumulator alpha1;
ElementAccumulator beta;
int head_number;
int batch_size;
int seq_length;
typename ApplyFinalReductionDevice::Arguments reduction;
//
// Methods
//
Arguments():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
block_P(nullptr),
block_Norm(nullptr),
block_Sum(nullptr),
offset_P(nullptr),
offset_Norm_Device(nullptr),
offset_Sum_Device(nullptr),
ldq(nullptr),
ldk(nullptr),
ldp(nullptr),
ldv(nullptr),
ldo(nullptr),
head_number(0),
batch_size(0),
seq_length(0)
{
}
Arguments(
cutlass::gemm::GemmCoord *problem_sizes0,
cutlass::gemm::GemmCoord *problem_sizes1,
int problem_count,
int threadblock_count,
ElementQ ** ptr_Q,
ElementK ** ptr_K,
ElementP ** ptr_P,
ElementP ** ptr_V,
ElementP ** ptr_O,
ElementNorm **ptr_Max,
ElementSum **ptr_Sum,
ElementP *block_P,
ElementNorm *block_Norm,
ElementSum *block_Sum,
int64_t *offset_P,
int64_t *offset_Norm_Device,
int64_t *offset_Sum_Device,
typename LayoutQ::Stride::LongIndex *ldq,
typename LayoutK::Stride::LongIndex *ldk,
typename LayoutP::Stride::LongIndex *ldp,
typename LayoutP::Stride::LongIndex *ldv,
typename LayoutP::Stride::LongIndex *ldo,
ElementAccumulator alpha0,
ElementAccumulator alpha1,
ElementAccumulator beta,
int head_number,
int batch_size,
int seq_length,
cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr,
cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr,
cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
threadblock_count(threadblock_count),
ptr_Q(ptr_Q),
ptr_K(ptr_K),
ptr_P(ptr_P),
ptr_V(ptr_V),
ptr_O(ptr_O),
ptr_Max(ptr_Max),
ptr_Sum(ptr_Sum),
block_P(block_P),
block_Norm(block_Norm),
block_Sum(block_Sum),
offset_P(offset_P),
offset_Norm_Device(offset_Norm_Device),
offset_Sum_Device(offset_Sum_Device),
ldq(ldq),
ldk(ldk),
ldp(ldp),
ldv(ldv),
ldo(ldo),
alpha0(alpha0),
alpha1(alpha1),
beta(beta),
head_number(head_number),
batch_size(batch_size),
seq_length(seq_length),
problem_sizes0_host(problem_sizes0_host),
problem_sizes1_host(problem_sizes1_host),
problem_sizes0_real(problem_sizes0_real),
reduction(
problem_sizes0,
block_Norm,
block_Sum,
offset_Norm_Device,
offset_Sum_Device
)
{
}
};
struct Params {
cutlass::gemm::GemmCoord *problem_sizes0;
cutlass::gemm::GemmCoord *problem_sizes0_real;
cutlass::gemm::GemmCoord *problem_sizes1;
int problem_count;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementP ** ptr_V;
ElementP ** ptr_O;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
ElementP *block_P;
ElementNorm *block_Norm;
ElementSum *block_Sum;
int64_t *offset_P;
int64_t *offset_Norm_Device;
int64_t *offset_Sum_Device;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldp;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutP::Stride::LongIndex *ldo;
cutlass::gemm::GemmCoord *problem_sizes0_host;
cutlass::gemm::GemmCoord *problem_sizes1_host;
ElementAccumulator alpha0;
ElementAccumulator alpha1;
ElementAccumulator beta;
int head_number;
int batch_size;
int seq_length;
typename ApplyFinalReductionDevice::Params reduction;
Params():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
block_P(nullptr),
block_Norm(nullptr),
block_Sum(nullptr),
offset_P(nullptr),
offset_Norm_Device(nullptr),
offset_Sum_Device(nullptr),
ldq(nullptr),
ldk(nullptr),
ldp(nullptr),
ldv(nullptr),
ldo(nullptr),
problem_sizes0(nullptr),
problem_sizes1(nullptr),
problem_sizes0_real(nullptr),
head_number(0),
batch_size(0),
seq_length(0)
{
}
Params(Arguments const &args, void *workspace = nullptr):
problem_sizes0(args.problem_sizes0),
problem_sizes1(args.problem_sizes1),
problem_count(args.problem_count),
threadblock_count(args.threadblock_count),
ptr_Q(args.ptr_Q),
ptr_K(args.ptr_K),
ptr_P(args.ptr_P),
ptr_V(args.ptr_V),
ptr_O(args.ptr_O),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
block_P(args.block_P),
block_Norm(args.block_Norm),
block_Sum(args.block_Sum),
offset_P(args.offset_P),
offset_Norm_Device(args.offset_Norm_Device),
offset_Sum_Device(args.offset_Sum_Device),
ldq(args.ldq),
ldk(args.ldk),
ldp(args.ldp),
ldv(args.ldv),
ldo(args.ldo),
problem_sizes0_host(args.problem_sizes0_host),
problem_sizes1_host(args.problem_sizes1_host),
problem_sizes0_real(args.problem_sizes0_real),
alpha0(args.alpha0),
alpha1(args.alpha1),
beta(args.beta),
head_number(args.head_number),
batch_size(args.batch_size),
seq_length(args.seq_length),
reduction(args.reduction)
{
}
};
private:
Params params_;
GemmGrouped0 gemm_grouped0;
GemmGrouped1 gemm_grouped1;
public:
/// Ctor
FusedMultiHeadAttention() {
}
/// Initialize
Status initialize(Arguments const &args,
void *workspace0 = nullptr,
void *workspace1 = nullptr) {
params_ = Params(args);
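// Arguments for the first grouped GEMM (Q x K^T). P is passed as both the C and D operands, and the
// epilogue visitor receives the per-row max/sum buffers it fills during the epilogue.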
typename GemmGrouped0::Arguments args_gemm0(
params_.problem_sizes0,
params_.problem_count,
params_.threadblock_count,
params_.ptr_Q,
params_.ptr_K,
params_.ptr_P,
params_.ptr_P,
params_.ptr_Max,
params_.ptr_Sum,
params_.ldq,
params_.ldk,
params_.ldp,
params_.ldp,
typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments(
{
params_.alpha0,
params_.beta
}
),
params_.problem_sizes0_host,
params_.problem_sizes0_real
);
Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0);
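// Arguments for the second grouped GEMM (softmax(P) x V). The max/sum buffers computed by the first
// GEMM and the reduction are consumed by the mainloop, and O is passed as both C and D.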
typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta);
typename GemmGrouped1::Arguments args_gemm1(
params_.problem_sizes1,
params_.problem_count,
params_.threadblock_count,
epilogue_op1,
params_.ptr_P,
params_.ptr_V,
params_.ptr_O,
params_.ptr_O,
(void**)params_.ptr_Max,
(void**)params_.ptr_Sum,
params_.ldp,
params_.ldv,
params_.ldo,
params_.ldo,
params_.problem_sizes1_host
);
Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1);
if (result0 != cutlass::Status::kSuccess) {
return result0;
}
if (result1 != cutlass::Status::kSuccess) {
return result1;
}
return cutlass::Status::kSuccess;
}
/// Run
Status run(cudaStream_t stream = nullptr) {
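// The forward pass is three launches: the first grouped GEMM, the softmax final-reduction kernel,
// and the second grouped GEMM.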
Status result = gemm_grouped0.run(stream);
cudaError_t error_info;
if (result != cutlass::Status::kSuccess) {
return cutlass::Status::kErrorInternal;
}
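// Launch the softmax final reduction with one threadblock per (head, batch) pair; it folds the
// per-tile partial max/sum values into the final normalizers read by the second GEMM.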
int thread_per_block = 1024;
dim3 final_reduction_grid(params_.head_number * params_.batch_size);
dim3 final_reduction_block(thread_per_block);
cutlass::Kernel<ApplyFinalReductionDevice><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream
>>>(params_.reduction);
error_info = cudaGetLastError();
if (error_info != cudaSuccess) {
return cutlass::Status::kErrorInternal;
}
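// The second grouped GEMM consumes the normalized probabilities together with V to produce the output O.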
result = gemm_grouped1.run(stream);
if (result != cutlass::Status::kSuccess) {
return cutlass::Status::kErrorInternal;
}
return cutlass::Status::kSuccess;
}
/// Function call operator
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
};
}

View File

@ -1,522 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped GEMM kernel with epilogue visitor customized for softmax
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
bool Transposed_ = false,
bool UseMask_ = false
>
struct GemmGroupedWithEpilogueVistor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
using EpilogueVisitor = typename Epilogue::Visitor;
using EpilogueOutputOp = typename EpilogueVisitor::ElementwiseFunctor;
static bool const kTransposed = Transposed_;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<
typename Mma::IteratorA::Element,
typename Mma::IteratorA::Layout,
Mma::kTransformA,
Mma::IteratorA::AccessType::kElements,
typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout,
Mma::kTransformB,
Mma::IteratorB::AccessType::kElements,
typename Mma::LayoutC,
kTransposed
>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename MapArguments::LayoutC;
using ElementNorm = typename EpilogueVisitor::ElementNorm;
using ElementSum = typename EpilogueVisitor::ElementSum;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor = GemmGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
kThreadCount,
kThreadCount,
kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord *problem_sizes;
// When masking is used, the real problem sizes may not be aligned,
// so the softmax must mask out elements beyond the real (unpadded) extent
GemmCoord *problem_sizes_real;
int problem_count;
int threadblock_count;
ElementA ** ptr_A;
ElementB ** ptr_B;
ElementC ** ptr_C;
ElementC ** ptr_D;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
typename LayoutA::Stride::LongIndex *lda;
typename LayoutB::Stride::LongIndex *ldb;
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
typename EpilogueVisitor::Arguments epilogue_visitor;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
problem_count(0),
threadblock_count(0),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr),
host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord *problem_sizes,
int problem_count,
int threadblock_count,
ElementA ** ptr_A,
ElementB ** ptr_B,
ElementC ** ptr_C,
ElementC ** ptr_D,
ElementNorm **ptr_Max,
ElementSum **ptr_Sum,
typename LayoutA::Stride::LongIndex *lda,
typename LayoutB::Stride::LongIndex *ldb,
typename LayoutC::Stride::LongIndex *ldc,
typename LayoutC::Stride::LongIndex *ldd,
typename EpilogueVisitor::Arguments epilogue_visitor_,
GemmCoord *host_problem_sizes=nullptr,
GemmCoord *problem_sizes_real=nullptr
):
problem_sizes(problem_sizes),
problem_count(problem_count),
threadblock_count(threadblock_count),
ptr_A(ptr_A),
ptr_B(ptr_B),
ptr_C(ptr_C),
ptr_D(ptr_D),
ptr_Max(ptr_Max),
ptr_Sum(ptr_Sum),
lda(lda),
ldb(ldb),
ldc(ldc),
ldd(ldd),
epilogue_visitor(epilogue_visitor_),
host_problem_sizes(host_problem_sizes),
problem_sizes_real(problem_sizes_real)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
typename ProblemVisitor::Params problem_visitor;
GemmCoord *problem_sizes_real;
int threadblock_count;
ElementA ** ptr_A;
ElementB ** ptr_B;
ElementC ** ptr_C;
ElementC ** ptr_D;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
typename LayoutA::Stride::LongIndex *lda;
typename LayoutB::Stride::LongIndex *ldb;
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr),
problem_sizes_real(nullptr)
{ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args, void *workspace = nullptr, int32_t tile_count = 0):
problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
ptr_A(args.ptr_A),
ptr_B(args.ptr_B),
ptr_C(args.ptr_C),
ptr_D(args.ptr_D),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
lda(args.lda),
ldb(args.ldb),
ldc(args.ldc),
ldd(args.ldd),
epilogue_visitor(args.epilogue_visitor),
problem_sizes_real(args.problem_sizes_real)
{
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr,
int32_t tile_count = -1) {
problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
ptr_Max = args.ptr_Max;
ptr_Sum = args.ptr_Sum;
lda = args.lda;
ldb = args.ldb;
ldc = args.ldc;
ldd = args.ldd;
problem_sizes_real = args.problem_sizes_real;
}
};
/// Shared memory storage structure
struct SharedStorage {
union {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmGroupedWithEpilogueVistor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Mma::LayoutC;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(
params.problem_visitor,
shared_storage.problem_visitor,
blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN,
0);
// Load element pointers. Exchange pointers and strides if working on the transpose
ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{
0,
threadblock_offset.n()
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A),
ptr_A,
{problem_size.m(), problem_size.k()},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
ptr_B,
{problem_size.k(), problem_size.n()},
thread_idx,
tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
ElementC *ptr_C = params.ptr_C[problem_idx];
ElementC *ptr_D = params.ptr_D[problem_idx];
ElementNorm *ptr_Max = params.ptr_Max[problem_idx];
ElementSum *ptr_Sum = params.ptr_Sum[problem_idx];
LayoutC layout_C(params.ldc[problem_idx]);
LayoutC layout_D(params.ldd[problem_idx]);
int column_offset = (threadblock_offset.n() / ThreadblockShape::kN) * problem_size.m();
typename EpilogueVisitor::OutputTileIterator::Params params_C(layout_C);
typename EpilogueVisitor::OutputTileIterator::Params params_D(layout_D);
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.kernel.epilogue.visitor,
problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params_C,
params_D,
ptr_C,
ptr_D,
ptr_Max,
ptr_Sum,
threadblock_offset.mn(),
column_offset,
params.problem_sizes_real[problem_idx].mn()
);
// Construct the epilogue
Epilogue epilogue(
shared_storage.kernel.epilogue.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor
epilogue(epilogue_visitor, accumulators);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -30,7 +30,7 @@
cutlass_example_add_executable(
42_fused_multi_head_attention
fused_multihead_attention.cu
42_ampere_tensorop_group_conv
ampere_tensorop_group_conv.cu
)

View File

@ -0,0 +1,706 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run group convolution kernels using functions and data structures
provided by CUTLASS using tensor cores, which we run on an NVIDIA Ampere GPU.
There are two group conv modes:
1. cutlass::conv::GroupMode::kSingleGroup
This mode is for large K problem sizes: k_per_group (K/groups) is equal to or larger than
threadblock_tile_N. One or multiple threadblocks compute the data of one group.
2. cutlass::conv::GroupMode::kMultipleGroup
This mode is for small K problem sizes: k_per_group (K/groups) is smaller than threadblock_tile_N.
One threadblock computes data from more than one group.
Function profile_convolution_selecter() shows how to choose a kernel with the appropriate group mode
according to the problem size and the threadblock_tile size; a small worked example of this rule
follows this comment.
*/
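// A small worked example of the selection rule above (hypothetical sizes, assuming the
// ThreadblockShape defined below with kN = 64):
//   K = 128, groups = 8  ->  k_per_group = 128 / 8 = 16  <  64  ->  kMultipleGroup (analytic kernel only)
//   K = 128, groups = 2  ->  k_per_group = 128 / 2 = 64  >= 64  ->  kSingleGroup (analytic or optimized)
// See profile_convolution_selecter() near the bottom of this file for the actual dispatch.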
#include <iostream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_conv2d_group_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/device/convolution.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
using ElementOutput = float; // Data type of elements in output tensor
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Number of pipelines you want to use
constexpr int NumStages = 3;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
// Analytic kernel and operation for single group problem size
using AnalyticSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::GroupMode::kSingleGroup,
cutlass::conv::IteratorAlgorithm::kAnalytic
>::Kernel;
using AnalyticSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution<AnalyticSingleGroupKernel>;
// Analytic kernel and operation for multiple group problem size
using AnalyticMultipleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::GroupMode::kMultipleGroup,
cutlass::conv::IteratorAlgorithm::kAnalytic
>::Kernel;
using AnalyticMultipleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution<AnalyticMultipleGroupKernel>;
// Optimized kernel and operation for single group problem size
using OptimizedSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::GroupMode::kSingleGroup,
cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;
using OptimizedSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution<OptimizedSingleGroupKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::Tensor4DCoord input_size;
cutlass::Tensor4DCoord filter_size;
cutlass::Tensor4DCoord padding;
cutlass::MatrixCoord conv_stride;
cutlass::MatrixCoord dilation;
int groups;
bool reference_check;
bool measure_performance;
int iterations;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool optimized;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
groups(1),
reference_check(false),
measure_performance(false),
iterations(20),
alpha(1),
beta(0),
optimized(false) { }
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
bool valid() {
//
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
cutlass::Tensor4DCoord input_size,
cutlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("optimized")) {
optimized = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
cmd.get_cmd_line_argument("g", groups);
filter_size.c() = input_size.c() / groups;
cmd.get_cmd_line_argument("u", conv_stride.row());
cmd.get_cmd_line_argument("v", conv_stride.column());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.h() == 3 && filter_size.w() == 3) {
padding = {1, 1, 1, 1};
}
else {
filter_size.h() = 1;
filter_size.w() = 1;
padding = {0, 0, 0, 0};
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "42_ampere_tensorop_group_conv example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward grouped convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --g=<int> Conv groups G\n\n"
<< " --u=<int> Conv stride_h\n\n"
<< " --v=<int> Conv stride_w\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --optimized If set (true), use optimized kernel, otherwise use analytic kernel.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n"
<< "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n"
<< "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
cutlass::Tensor4DCoord output_size() const {
return cutlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
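// (For example, with the default 1x32x32x32 input, 32x3x3x32 filter, and unit stride defined in
// Options, NPQK = 1*32*32*32 = 32768 and CRS = 32*3*3 = 288, i.e. roughly 2 * 32768 * 288 ~= 18.9 MFLOP.)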
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cutlass::Status reference_check;
cudaError_t error;
Result():
runtime_ms(0),
gflops(0),
status(cutlass::Status::kSuccess),
reference_check(cutlass::Status::kInvalid),
error(cudaSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.groups << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
template <typename Conv2dOperation>
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the CUTLASS Utilities.
//
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(7),
ElementInputA(-8),
0);
// Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(7),
ElementInputB(-8),
0);
// Fill tensor C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0);
// Fill tensor D on host with zeros
cutlass::reference::host::TensorFill(
tensor_d.host_view());
// Fill tensor D for reference on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for CUTLASS Convolution
//
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
// Split K dimension into 1 partitions
int split_k_slices = 1;
// Construct Conv2dProblemSize with user defined output size
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices,
options.groups
);
// Construct Conv2dOperation::Argument structure with conv2d
// problem size, data pointers, and epilogue values
typename Conv2dOperation::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
{options.alpha, options.beta},
};
//
// Initialize CUTLASS Convolution
//
Conv2dOperation implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(result.status);
//
// Launch initialized CUTLASS kernel
//
result.status = implicit_gemm_op();
CUTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on device...\n";
// Compute with reference implementation
cutlass::reference::device::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_ref_d.device_ref(),
options.alpha,
options.beta
);
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
if (!passed) {
result.reference_check = cutlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
} else {
result.reference_check = cutlass::Status::kSuccess;
std::cout << "Passed.\n";
}
} else {
result.reference_check = cutlass::Status::kInvalid;
}
//
// Performance measurement
//
if (options.measure_performance) {
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
CUTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
Result profile_convolution_selecter(Options const &options) {
int k_per_group = options.filter_size.n() / options.groups;
// In group conv, if k_per_group < threadblock_N, one threadblock computes data from multiple groups
if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode
if (options.optimized) {
std::cerr << "Invalid problem: the optimized group conv kernel doesn't support MultipleGroup (one CTA computes multiple groups) mode" << std::endl;
exit(-1);
} else {
std::cout << "Select AnalyticMultipleGroupOperation\n";
return profile_convolution<AnalyticMultipleGroupOperation>(options);
}
} else { // SingleGroup mode
if (options.optimized) {
std::cout << "Select OptimizedSingleGroupOperation\n";
return profile_convolution<OptimizedSingleGroupOperation>(options);
} else {
std::cout << "Select AnalyticSingleGroupOperation\n";
return profile_convolution<AnalyticSingleGroupOperation>(options);
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
bool notSupported = false;
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;
}
if (notSupported) {
return 0;
}
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution_selecter(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,66 +0,0 @@
#pragma once
#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"
namespace cutlass {
namespace transform {
namespace threadblock {
template <typename BaseIterator>
struct MakeIteratorResidualLast;
template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>> {
using Iterator = PredicatedTileIteratorResidualLast<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>;
};
template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>> {
using Iterator = PredicatedTileAccessIteratorResidualLast<
Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>;
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass

View File

@ -1,4 +1,3 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
@ -28,9 +27,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
41_multi_head_attention
fused_multihead_attention.cu
43_ell_block_sparse_gemm
ell_block_sparse_gemm.cu
)

View File

@ -0,0 +1,740 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Block-Ell sparse gemm example.
This example performs a Sparse-matrix dense-matrix multiplication (SpMM) operation.
Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format.
Details about the Blocked-Ellpack (Blocked-ELL) storage format can be found here:
https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell
Whereas matrix B is a dense matrix.
The Blocked-Ellpack (Blocked-ELL) storage format comprises two matrices.
The first is a packed matrix (the ellValue matrix) that stores non-zero values in consecutive blocks,
represented by tensor_a in this example. The second is a matrix of indices (the ellColInd matrix),
represented by tensor_ell_idx in this example, that stores the column indices of the
corresponding non-zero blocks. All block-rows must have the same number of blocks.
ellColInd can contain -1 values to indicate empty blocks. Both matrices store their elements in
row-major order. A small worked example of the format follows this comment block.
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
for this example:
a_rows - Rows in the sparse matrix.
a_cols - Columns in the sparse matrix.
a_ell_blocksize - Size of the ELL-Blocks.
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is
(a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize)
tensor_b - Input dense matrix whose size is (a_cols * n)
tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n)
{a_rows, n, a_cols} - Problem size
*/
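// A small worked example of the format (hypothetical sizes, not used by the code below):
//   a_rows = 4, a_cols = 8, a_ell_blocksize = 2, a_ell_num_columns = 4
//   The sparse matrix is a 2 x 4 grid of 2x2 blocks. Suppose block-row 0 has non-zero blocks in block
//   columns 0 and 2, and block-row 1 in block columns 1 and 3. Then:
//     tensor_a       (ellValue)  is 4 x 4 and stores the two non-zero 2x2 blocks of each block-row
//                                packed next to each other.
//     tensor_ell_idx (ellColInd) is 2 x 2 and holds {0, 2} in its first row and {1, 3} in its second;
//                                an entry of -1 would mark an empty block.
//   tensor_b is then a dense 8 x n matrix, and tensor_c / tensor_d are dense 4 x n matrices.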
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/ell_gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_uncompress.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool reference_check;
int iterations;
int cuda_streams;
int a_rows, n, a_cols;
int a_ell_num_columns;
int a_ell_blocksize;
int a_base;
float alpha;
float beta;
//
// Methods
//
Options():
help(false),
reference_check(true),
iterations(20),
cuda_streams(0),
a_rows(1024),
n(1024),
a_cols(1024),
a_ell_num_columns(512),
a_ell_blocksize(16),
a_base(0),
alpha(1),
beta()
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
cmd.get_cmd_line_argument("beta", beta, 0.0f);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("streams", cuda_streams, 0);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("a_rows", a_rows, 1024);
cmd.get_cmd_line_argument("n", n, 1024);
cmd.get_cmd_line_argument("a_cols", a_cols, 1024);
cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512);
cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16);
cmd.get_cmd_line_argument("a_base", a_base, 0);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "43_ell_block_sparse_gemm\n\n"
<< "  This example profiles the performance of an ELL block sparse GEMM kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --a_rows=<int> Sets the number of the rows of the sparse matrix.\n"
<< " --n=<int> Sets the N dimension.\n"
<< " --a_cols=<int> Sets the number of columns of the sparse matrix.\n"
<< " --a_ell_num_columns=<int> Sets the actual number of columns of the Blocked-Ellpack format.\n"
<< " --a_ell_blocksize=<int> Sets the size of the ELL-Block.\n"
<< " --a_base=<int> Sets the base index.\n"
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --reference-check=<bool> If true, performs reference check.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n"
<< "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n;
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
class Testbed {
public:
//
// Type definitions
//
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using ElementAccumulator = typename Gemm::ElementAccumulator;
using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using MatrixCoord = typename LayoutC::TensorCoord;
private:
//
// Data members
//
Options options;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_ELL;
uint32_t seed;
cutlass::HostTensor<ElementA, LayoutA> tensor_a;
cutlass::HostTensor<ElementB, LayoutB> tensor_b;
cutlass::HostTensor<ElementC, LayoutC> tensor_c;
cutlass::HostTensor<ElementC, LayoutC> tensor_d;
cutlass::HostTensor<ElementA, LayoutA> tensor_a_uncompressed;
cutlass::HostTensor<ElementC, LayoutC> reference_d;
cutlass::HostTensor<int32_t, LayoutA> tensor_ell_idx;
public:
//
// Methods
//
Testbed(
Options const &options_,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_ELL_ = cutlass::Distribution::Uniform,
uint32_t seed_ = 3080
):
options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { }
private:
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
void initialize_tensor_(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint32_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(
view, seed, Element(), Element(0.5f));
}
else if (dist_kind == cutlass::Distribution::Sequential) {
// Fill with increasing elements
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity(), Element(1), Element());
} else {
// Fill with all 1s
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity(), Element(), Element(1));
}
}
/// Initializes data structures
void initialize_() {
tensor_a.resize(cutlass::make_Coord(options.a_rows, options.a_ell_num_columns));
tensor_b.resize(cutlass::make_Coord(options.a_cols, options.n));
tensor_c.resize(cutlass::make_Coord(options.a_rows, options.n));
tensor_d.resize(cutlass::make_Coord(options.a_rows, options.n));
tensor_a_uncompressed.resize(cutlass::make_Coord(options.a_rows, options.a_cols));
reference_d.resize(cutlass::make_Coord(options.a_rows, options.n));
tensor_ell_idx.resize(cutlass::make_Coord(options.a_rows / options.a_ell_blocksize,
options.a_ell_num_columns / options.a_ell_blocksize));
//
// Initialize the problems of the workspace
//
initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021);
initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022);
initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023);
if (init_ELL == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomEllIdx(
tensor_ell_idx.host_view(), seed,
options.a_rows / options.a_ell_blocksize,
options.a_ell_num_columns / options.a_ell_blocksize,
options.a_cols / options.a_ell_blocksize);
} else {
for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) {
for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) {
tensor_ell_idx.at({i, j}) = j+3;
}
}
}
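    // A brief sketch of the Blocked-ELL bookkeeping above (the concrete numbers are assumptions
    // for illustration, not values taken from the command line): with a_rows = 4, a_cols = 8,
    // a_ell_blocksize = 2 and a_ell_num_columns = 4, tensor_a holds a 4x4 compressed strip of A
    // and tensor_ell_idx is a 2x2 grid of block-column indices. An entry
    // tensor_ell_idx.at({i, j}) == k means the j-th stored block of block-row i corresponds to
    // dense columns [k * a_ell_blocksize, (k + 1) * a_ell_blocksize) of A. The deterministic
    // j + 3 fill above simply gives verify_() a known pattern to decompress.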
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ell_idx.sync_device();
}
  /// Verifies the computed result against a host reference GEMM on the uncompressed A matrix
bool verify_() {
bool passed = true;
tensor_d.sync_host();
cutlass::uncompress_ell_block_sparse(
tensor_a_uncompressed.host_ref(),
tensor_a.host_ref(),
tensor_ell_idx.host_ref(),
options.a_rows,
options.a_cols,
options.a_ell_num_columns,
options.a_ell_blocksize
);
cutlass::reference::host::Gemm<
typename Gemm::ElementA, typename Gemm::LayoutA,
typename Gemm::ElementB, typename Gemm::LayoutB,
typename Gemm::ElementC, typename Gemm::LayoutC,
ElementCompute,
ElementAccumulator, typename Gemm::Operator>
reference_gemm;
reference_gemm(
{options.a_rows, options.n, options.a_cols},
options.alpha,
tensor_a_uncompressed.host_ref(),
tensor_b.host_ref(),
options.beta,
reference_d.host_ref(),
ElementAccumulator(0)
);
// Reference check
passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view());
if (!passed) {
std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl;
std::stringstream fname;
fname << "error_43_ell_block_sparse_gemm"
<< "mnk_"
<< options.a_rows << "x"
<< options.n << "x"
<< options.a_cols << "_"
<< options.a_ell_num_columns << "_"
<< options.a_ell_blocksize << ".txt";
std::cout << fname.str() << std::endl;
std::ofstream results(fname.str());
results
<< "alpha: " << ElementCompute(options.alpha) << "\n"
<< "beta: " << ElementCompute(options.beta) << "\n"
<< "block size: " << options.a_ell_blocksize << "\n"
<< "\nA:\n" << tensor_a.host_view() << "\n"
<< "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n"
<< "\nB:\n" << tensor_b.host_view() << "\n"
<< "\nC:\n" << tensor_c.host_view() << "\n"
<< "\nD reference:\n" << reference_d.host_view() << "\n"
<< "\nD computed:\n" << tensor_d.host_view() << "\n";
return passed;
}
return passed;
}
public:
  /// Returns true if the target device provides sufficient hardware resources (in particular,
  /// opt-in shared memory capacity) to run the kernel; otherwise, returns false.
bool sufficient() const {
//
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
cudaDeviceProp properties;
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
throw std::runtime_error("cudaGetDevice() API call failed.");
}
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
if (properties.sharedMemPerBlockOptin < smem_size) {
return false;
}
return true;
}
/// Executes a BlockedEll SpMM kernel and measures runtime.
Result profile() {
Result result;
// Early exit
if (!sufficient()) {
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
result.passed = false;
// Initialize the problem
initialize_();
// Configure the GEMM arguments
typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
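    // The linear combination epilogue computes D = alpha * (A x B) + beta * C, where A is the
    // Blocked-ELL compressed operand.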
// Configure GEMM arguments
typename Gemm::Arguments args(
{options.a_rows, options.n, options.a_cols},
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
tensor_ell_idx.device_data(),
options.a_ell_num_columns,
options.a_ell_blocksize,
options.a_base,
epilogue_op
);
// Initialize the GEMM object
Gemm gemm;
result.status = gemm.initialize(args);
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize CUTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
// Run the BlockedEll SpMM object
result.status = gemm.run();
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
// Wait for completion
result.error = cudaDeviceSynchronize();
if (result.error != cudaSuccess) {
std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error);
return result;
}
//
// Verify correctness
//
result.passed = true;
if (options.reference_check) {
result.passed = verify_();
}
//
// Warm-up run
//
result.status = gemm.run();
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
        return result;
}
}
// Record an event at the start of a series of GEMM operations
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
gemm();
}
//
// Stop profiling loop
//
// Record an event when the GEMM operations have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
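    // Illustrative sanity check (assumed sizes, not measured output): with a_rows = a_cols = n = 1024
    // the kernel performs 2 * 1024^3 ~= 2.15e9 flops, so an average runtime of 1 ms corresponds to
    // roughly 2147 GFLOP/s.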
//
// Cleanup
//
for (auto event : events) {
(void)cudaEventDestroy(event);
}
std::cout << std::endl;
std::cout << "ELL Block Sparse GEMM (CUTLASS):\n"
<< "====================================================" << std::endl;
std::cout << std::endl;
std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " " << " GFLOPs: " << result.gflops << std::endl;
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
//
// This example requires an NVIDIA Ampere-architecture GPU.
//
std::cout
<< "CUTLASS's BlockedEll SpMM example requires a GPU of NVIDIA's Ampere Architecture or "
<< "later (compute capability 80 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Define the BlockedEll type
//
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
constexpr int32_t kAlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
constexpr int32_t kAlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
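  // For 16-bit half_t operands, 128 / 16 = 8 elements per 128-bit (16-byte) vectorized access.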
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
constexpr int32_t kStages = 4;
using Gemm = typename cutlass::gemm::device::EllGemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementOutput,
LayoutC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
kStages, kAlignmentA, kAlignmentB>;
//
// Profile it
//
Testbed<Gemm> testbed(options);
if (!testbed.sufficient()) {
std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n";
return 0;
}
Result result = testbed.profile();
if (!result.passed) {
std::cout << "Profiling CUTLASS ELL block sparse GEMM has failed.\n";
std::cout << "\nFailed\n";
return -1;
}
std::cout << "\nPassed\n";
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,63 @@
This example provides utilities for generating back-to-back (B2B) GEMMs using CUTLASS.
## Quick start
A configuration file containing the GEMMs to be fused together is located in [config.json](config.json). Edit
this to change the configuration that you would like to run.
```shell
cd ir_gen
# Set up basic variables
out_dir=directory_to_emit_files
cutlass_dir=$(pwd)/../../..
config_file=$(pwd)/../config.json
# Generate code for GEMMs described in `config_file`
./generate.sh $config_file $out_dir $cutlass_dir
# Build the generated code
cd $out_dir
mkdir build && cd build
cmake .. -DGPU_ARCHS="75;80"
make -j
# Run the generated code with M=1024 K0=32 and Batch=1
./sample 1024 32 1
```
## Current restrictions
This experimental example has the following restrictions:
1. N tile should not exceed 256, or register spilling will occur.
2. Only FP16 is currently supported.
3. Matrix A must be row major, matrix B must be column major, and matrices C and D must be row major.
## Copyright
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -0,0 +1,32 @@
{
"0": {
"A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16",
"A_format": "Row", "B_format": "Col", "C_format": "Row",
"mnk": [15000, 256, 32],
"epilogue": {
"tp": "LeakyRelu",
"bias": {"addbias": false, "bias_tp": "mat"},
"args": [["float", "leaky_alpha", 1.3]]
}
},
"1": {
"A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16",
"A_format": "Row", "B_format": "Col", "C_format": "Row",
"mnk": [15000, 128, 256],
"epilogue": {
"tp": "LeakyRelu",
"bias": {"addbias": false, "bias_tp": "mat"},
"args": [["float", "leaky_alpha", 1.3]]
}
},
"2": {
"A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16",
"A_format": "Row", "B_format": "Col", "C_format": "Row",
"mnk": [15000, 64, 128],
"epilogue": {
"tp": "LeakyRelu",
"bias": {"addbias": false, "bias_tp": "mat"},
"args": [["float", "leaky_alpha", 1.3]]
}
}
}

View File

@ -0,0 +1,154 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
// #include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "fused_bias_act_epilogue.h"
#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h"
#include "output_tile_thread_map_for_fused_bias.h"
#include "default_thread_map_tensor_op_for_fused_bias.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for TensorOps.
template <
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultFusedBiasActEpilogueTensorOp {
using Shape = Shape_;
using WarpMmaTensorOp = WarpMmaTensorOp_;
static int const kPartitionsK = PartitionsK;
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
//
// Thread map
//
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias<
Shape,
typename WarpMmaTensorOp::Shape,
kPartitionsK,
ElementOutput,
kElementsPerAccess
>::Type;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
OutputTileThreadMap,
ElementOutput
>;
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC>,
cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC> >::type;
//
// Define the epilogue
//
using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue<
Shape,
WarpMmaTensorOp,
kPartitionsK,
OutputTileIterator,
AccumulatorFragmentIterator,
OutputOp
>;
};
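// A minimal usage sketch (illustrative only): `WarpMma` stands in for the warp-level tensor-op
// MMA type produced elsewhere by the generated kernel, and the output-op choice and the
// elements-per-access value of 2 (the tensor-op accumulator access granularity) are assumptions,
// not values taken from the code generator.
//
//   using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
//       cutlass::half_t, 2, cutlass::half_t, cutlass::half_t>;
//
//   using FusedEpilogue = typename DefaultFusedBiasActEpilogueTensorOp<
//       cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
//       WarpMma,                                  // warp-level tensor-op MMA
//       1,                                        // partitions of K
//       EpilogueOp,                               // fused bias + activation functor
//       2                                         // elements per access
//   >::Epilogue;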
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,113 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
*/
#pragma once
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/pitch_linear.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Defines the optimal thread map for TensorOp accumulator layouts
template <
typename ThreadblockShape_,
typename WarpShape_,
int PartitionsK,
typename Element_,
int ElementsPerAccess
>
struct DefaultThreadMapTensorOpForFusedBias {
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
static int const kPartitionsK = PartitionsK;
using Element = Element_;
static int const kElementsPerAccess = ElementsPerAccess;
//
// Definitions
//
struct Detail {
/// Tensor Operations fundamentally perform operations on 8 rows
static int const kTensorOpRows = 8;
static int const kWarpSize = 32;
static_assert(
      !(ThreadblockShape::kM % WarpShape::kM) &&
      !(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
/// Number of warps
using WarpCount = gemm::GemmShape<
ThreadblockShape::kM / WarpShape::kM,
ThreadblockShape::kN / WarpShape::kN,
kPartitionsK
>;
/// Number of participating threads
static int const kThreads = WarpCount::kCount * kWarpSize;
};
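  // Worked example (assumed values, for illustration): for a 128x128 threadblock tile, a 64x64 warp
  // tile and kPartitionsK = 1, WarpCount is <2, 2, 1>, i.e. 4 warps, so kThreads = 4 * 32 = 128.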
//
// ThreadMap
//
/// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap
using Type = OutputTileOptimalThreadMapBiasAct <
OutputTileShape<ThreadblockShape::kN, Detail::kTensorOpRows, Detail::WarpCount::kM, 1, 1>,
OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>,
Detail::kThreads,
kElementsPerAccess,
sizeof_bits<Element>::value
>;
};
///////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,222 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Fused bias + activation epilogue for threadblock scoped GEMMs using Tensor Ops.
  Unlike the canonical epilogue, this epilogue applies the elementwise output operator directly to the
  accumulator fragments and writes the result back into an accumulator-shaped fragment kept in registers,
  so the activated result can feed a subsequent fused GEMM without a round trip through shared memory.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator without splitk
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename OutputOp_ ///< Output operator
>
class FusedBiasActEpilogue {
public:
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using OutputOp = OutputOp_;
/// Output layout is always row-major
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
public:
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
public:
/// Constructor
CUTLASS_DEVICE
FusedBiasActEpilogue(
){ }
  /// Applies the bias/activation output operator and stores the result into an accumulator-shaped fragment held in registers
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
AccumulatorTile & fused_bias_act_accumlators,
    OutputTileIterator source_iterator) {   ///< Tile iterator reading the source (bias) tensor
bool need_bias = output_op.is_source_needed();
if (need_bias)
compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator);
else
compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators);
}
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
    AccumulatorTile & fused_bias_act_accumlators) {   ///< Destination accumulator-shaped fragment receiving the fused bias/activation result
compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators);
}
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
AccumulatorTile & fused_bias_act_accumlators,
    OutputTileIterator source_iterator) {   ///< Tile iterator reading the source (bias) tensor
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators);
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
source_iterator.load(source_fragment);
++source_iterator;
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
fused_bias_act_fragment = output_op(accum_fragment, source_fragment);
fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
++fused_bias_act_fragment_iterator;
}
}
CUTLASS_DEVICE
void compute_source_no_needed_(
OutputOp const &output_op, ///< Output operator
AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile
    AccumulatorTile & fused_bias_act_accumlators) {   ///< Destination accumulator-shaped fragment receiving the fused bias/activation result
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators);
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment;
fused_bias_act_fragment = output_op(accum_fragment);
fused_bias_act_fragment_iterator.store(fused_bias_act_fragment);
++fused_bias_act_fragment_iterator;
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,311 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/fast_math.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// RowArrangement determines how one or more warps cover a region of consecutive rows.
template <
typename Shape,
int WarpsRemaining,
int ElementsPerAccess,
int ElementSize,
bool Is2dTile
>
struct RowArrangementBiasAct;
/// RowArrangement in which each warp's access is a 1D tiled arrangement.
template <
typename Shape,
int WarpsRemaining,
int ElementsPerAccess,
int ElementSize
>
struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
static int const kWarpSize = 32;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kElementSize = ElementSize;
static int const kIterationsRow = 1;
static int const kDeltaRow = 1;
static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
static int const kDeltaColumn = kWarpSize * kElementsPerAccess;
static int const kAccessWidth = kWarpSize;
static int const kAccessRows = 1;
static int const kWarpPartitionsRow = 1;
static int const kWarpPartitionsColumn = WarpsRemaining;
};
/// RowArrangement in which each warp's access is a 2D tiled arrangement.
template <
typename Shape,
int WarpsRemaining,
int ElementsPerAccess,
int ElementSize
>
struct RowArrangementBiasAct<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {
  static int const kMemoryAccessSize = 4;   // target bytes per memory access (the original value was 128)
static int const kWarpSize = 32;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kElementSize = ElementSize;
struct Detail {
static int const kShapeRow = Shape::kRow / WarpsRemaining;
static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;
static int const kTargetMemoryAccessWidth =
kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);
static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
};
static int const kAccessWidth =
(Detail::kTargetAccessRows > Detail::kShapeRow ?
kWarpSize / Detail::kShapeRow
: const_min(
Detail::kShapeWidth,
const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
));
static int const kAccessRows =
(Detail::kTargetAccessRows > Detail::kShapeRow ?
Detail::kShapeRow
: const_min(Shape::kRow, kWarpSize / kAccessWidth));
static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
static int const kDeltaRow = kAccessRows;
static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;
static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );
static int const kWarpPartitionsRow = 1;
static int const kWarpPartitionsColumn = 1;
};
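// Worked example (assumed values, for illustration): with Shape = <kColumn 128, kRow 8>,
// WarpsRemaining = 2, ElementsPerAccess = 2 and 16-bit elements, each access moves 4 bytes, so
// kTargetMemoryAccessWidth = 1 and kTargetAccessRows = 32 > kShapeRow = 4; the warp is therefore
// arranged as kAccessWidth = 8 lanes across by kAccessRows = 4 rows, giving kIterationsRow = 1,
// kIterationsColumn = 8, kDeltaRow = 4 and kDeltaColumn = 16.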
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
/// Template metaprogram for partitioning a 4D space across warps to achieve several performance
/// objectives:
///
/// - coalesced memory accesses in units of 16 Byte lines
/// - minimal address arithmetic
/// - minimal predicate calculations
///
template <
typename Shape_,
typename Count_,
int Threads,
int ElementsPerAccess,
int ElementSize
>
struct OutputTileOptimalThreadMapBiasAct {
using Shape = Shape_;
using Count = Count_;
static int const kWarpSize = 32;
static int const kThreads = Threads;
static int const kWarpCount = kThreads / kWarpSize;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kElementSize = ElementSize;
//
// Metaprogram computation
//
struct Detail {
// Clusters
static int const kIterationsCluster =
((Shape::kCluster > kWarpCount) ?
Shape::kCluster / kWarpCount
: 1);
static int const kDeltaCluster =
((Shape::kCluster > kWarpCount) ?
Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
: 1);
static int const kCompactedDeltaCluster =
((Shape::kCluster > kWarpCount) ?
Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
: 1);
static int const kWarpPartitionsCluster =
((Shape::kCluster > kWarpCount) ?
kWarpCount
: kWarpCount / Shape::kCluster);
static int const kWarpsRemainingForGroups =
((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);
// Groups
static int const kIterationsGroup =
((Shape::kGroup > kWarpsRemainingForGroups) ?
Shape::kGroup / kWarpsRemainingForGroups
: 1);
static int const kDeltaGroup =
((Shape::kGroup > kWarpsRemainingForGroups) ?
Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
: 1);
static int const kCompactedDeltaGroup =
((Shape::kGroup > kWarpsRemainingForGroups) ?
Shape::kRow * Shape::kGroup / kIterationsGroup
: 1);
static int const kWarpPartitionsGroup =
((Shape::kGroup > kWarpsRemainingForGroups) ?
1
: kWarpsRemainingForGroups / Shape::kGroup);
static int const kWarpsRemainingForRows =
((Shape::kGroup > kWarpsRemainingForGroups) ?
1
: kWarpsRemainingForGroups / Shape::kGroup);
// Rows
using RowArrangement = detail::RowArrangementBiasAct<
Shape,
kWarpsRemainingForRows,
kElementsPerAccess,
kElementSize,
(Shape::kRow > kWarpsRemainingForRows)
>;
// Warp partitions
using WarpPartitions = OutputTileShape<
RowArrangement::kWarpPartitionsColumn,
RowArrangement::kWarpPartitionsRow,
kWarpPartitionsGroup,
kWarpPartitionsCluster,
1>;
static int const kAccessWidth = RowArrangement::kAccessWidth;
static int const kAccessRows = RowArrangement::kAccessRows;
};
//
// Output
//
using Iterations = OutputTileShape<
Detail::RowArrangement::kIterationsColumn,
Detail::RowArrangement::kIterationsRow,
Detail::kIterationsGroup,
Detail::kIterationsCluster,
1>;
using Delta = OutputTileShape<
Detail::RowArrangement::kDeltaColumn,
Detail::RowArrangement::kDeltaRow,
Detail::kDeltaGroup,
Detail::kDeltaCluster,
1>;
/// Initial offset function
CUTLASS_HOST_DEVICE
static MatrixCoord initial_offset(int thread_idx) {
int warp_idx = thread_idx / kWarpSize;
int lane_idx = thread_idx % kWarpSize;
// Compute warp location
int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
int row_idx = residual_group / Detail::WarpPartitions::kRow;
int col_idx = residual_group % Detail::WarpPartitions::kRow;
// Compute per-lane offset
int lane_row_offset = lane_idx / Detail::kAccessWidth;
int lane_col_offset = lane_idx % Detail::kAccessWidth;
// Compute coordinate in output space
int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
int group_offset = group_idx * Shape::kRow * Count::kRow;
int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
return MatrixCoord(
cluster_offset + group_offset + row_offset + lane_row_offset,
(column_offset + lane_col_offset) * kElementsPerAccess
);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

View File

@ -0,0 +1,189 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile
that participate in one warp-level store operation.
Typically, the accumulator tile is the largest single block of register-backed storage
within the kernel. Storing it to memory is best accomplished by partitioning it into
smaller tiles and storing these sequentially.
Round trips through shared memory during the Epilogue phase require partitioning, as
shared memory capacity is typically insufficient for a threadblock's total accumulator
size.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/warp/tensor_op_policy.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace warp {
////////////////////////////////////////////////////////////////////////////////
///
template <
typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape)
typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape)
typename OperatorElementC, ///< matrix multiply operation data type (concept: data type)
typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array)
typename Layout ///< target shared memory layout
>
class FusedBiasActFragmentIteratorTensorOp;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for row-major shared memory
template <
typename WarpShape_, ///< shape of the warp-level GEMM tile
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type)
typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array)
>
class FusedBiasActFragmentIteratorTensorOp<WarpShape_, OperatorShape_, OperatorElementC_, OperatorFragmentC_, layout::RowMajor> {
public:
using WarpShape = WarpShape_;
using OperatorShape = OperatorShape_;
using OperatorElementC = OperatorElementC_;
using OperatorFragmentC = OperatorFragmentC_;
using Layout = layout::RowMajor;
using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;
/// This is the fragment size produced by one access of the iterator.
using Fragment = Array<
OperatorElementC,
Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
/// This is the complete warp-level accumulator tile.
using AccumulatorTile = Array<
OperatorElementC,
OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>;
using OutputAccumulatorTile = AccumulatorTile;
/// Number of times this iterator can be incremented
static int const kIterations = Policy::kIterations;
private:
/// Internal access type
using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;
private:
//
// Data members
//
/// Accumulator tile
AccessType *accumulators_;
/// Internal index
int index_;
public:
/// Constructs an iterator
CUTLASS_HOST_DEVICE
FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum):
accumulators_(reinterpret_cast<AccessType *>(&accum)),
index_(0) {
}
/// Increments
CUTLASS_HOST_DEVICE
FusedBiasActFragmentIteratorTensorOp &operator++() {
++index_;
return *this;
}
/// Decrements
CUTLASS_HOST_DEVICE
FusedBiasActFragmentIteratorTensorOp &operator--() {
--index_;
return *this;
}
/// Loads a fragment from the referenced part of the accumulator tile
CUTLASS_HOST_DEVICE
void load(Fragment &frag, int index_offset = 0) const {
int index = index_ + index_offset;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
int accumulator_access_offset =
index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
frag_ptr[n] = accumulators_[accumulator_access_offset];
}
}
  /// Stores a fragment to the referenced part of the accumulator tile
CUTLASS_HOST_DEVICE
void store(Fragment &frag, int index_offset = 0) const {
int index = index_ + index_offset;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
int accumulator_access_offset =
index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
accumulators_[accumulator_access_offset] = frag_ptr[n];
}
}
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,427 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_conversion.h"
namespace cutlass {
namespace gemm {
namespace warp {
////////////////////////////////////////////////////////////////////////////////
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Size of the accumulation tile shape (concept: MatrixShape)
typename AccumulatorShape_,
/// KBlocks columns to compute residual
int KBlocksColumn_,
/// Accumulator Element type
typename ElementAccumulator_,
/// Element type
typename Element_,
/// Layout of operand in memory
typename Layout_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_,
/// Whether beta is zero
bool IsBetaZero_ >
class MmaTensorOpPureFragmentIterator;
// Partial specialization for col-major accumulator tile
// And Element type is the same as Accumulator Element type
template <
/// Shape of warp tile to load (concept: MatrixShape)
typename Shape_,
/// Shape of the warp accumulation tile (concept: MatrixShape)
typename AccumulatorShape_,
/// KBlocks columns to compute residual
int KBlocksColumn_,
/// Element type
typename Element_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_>
class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
cutlass::layout::ColumnMajor,
InstructionShape_, true> {
public:
/// Shape of warp tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Shape of the warp accumulation tile (concept: MatrixShape)
using AccumulatorShape = AccumulatorShape_;
/// KBlocks columns to compute residual
static int const kKBlockColumn = KBlocksColumn_;
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = cutlass::layout::ColumnMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = InstructionShape_;
/// Whether beta is zero
static bool const IsBetaZero = true;
/// Number of participating threads
static int const kThreads = 32;
/// Internal structure of iterator - made public to enable introspection
struct Policy {
static_assert(
!(Shape::kRow % InstructionShape::kM) &&
!(Shape::kColumn % InstructionShape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
static_assert(
!(AccumulatorShape::kRow % Shape::kRow) &&
!(AccumulatorShape::kColumn % Shape::kColumn),
"Shape of Warp Accumulator must be divisible by warp shape.");
static_assert(
!(kKBlockColumn % Shape::kColumn),
"KBlock size must be divisible by warp shape.");
/// Number of times this iterator can be incremented
static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
};
private:
static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
/// Number of mma operations performed by a warp
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
Shape::kColumn / InstructionShape::kN>;
/// Number of mma operations performed by the entire accumulator
using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
AccumulatorShape::kColumn / InstructionShape::kN>;
/// Number of K iterations
static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
* (AccumulatorShape::kRow / Shape::kRow);
static int const kResidualIndex = kResidualColumn / Shape::kColumn
* (AccumulatorShape::kRow / Shape::kRow);
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
/// This is the fragment size produced by one access of the iterator.
using Fragment = Array<Element, Shape::kCount / kThreads>;
/// Accumulator Fragment object
using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;
private:
/// Internal access type
using AccessType = Array<Element, kElementsPerAccess>;
private:
//
// Data members
//
/// Accumulator tile
AccessType const *accumulators_;
/// Internal index
int index_;
/// Used to access residual tile first
bool is_residual_tile_;
public:
/// Constructs an iterator
CUTLASS_HOST_DEVICE
MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
: accumulators_(reinterpret_cast<AccessType const *>(&accum)),
index_(0), is_residual_tile_(true) {}
/// Add offset
CUTLASS_HOST_DEVICE
void add_offset(int index_offset) {
index_ += index_offset;
if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
index_ = index_ - kKBlockColumnIterations + kResidualIndex;
is_residual_tile_ = false;
}
}
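  // Note on the residual handling above (a descriptive sketch of the logic, matching the
  // behaviour of load()): the partial (residual) K block of the accumulator is visited first.
  // While is_residual_tile_ is set, sub-tiles at index_ >= kResidualIndex are returned as zero
  // fragments by load(); once the caller has advanced a full K block's worth of iterations, the
  // index is folded back so that subsequent accesses continue immediately after the residual
  // columns.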
/// Increments
CUTLASS_HOST_DEVICE
MmaTensorOpPureFragmentIterator &operator++() {
add_offset(1);
return *this;
}
/// Decrements
CUTLASS_HOST_DEVICE
MmaTensorOpPureFragmentIterator &operator--() {
add_offset(-1);
return *this;
}
/// Loads a fragment from the referenced part of the accumulator tile
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
AccessType src_fragment;
src_fragment.clear();
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
* MmaIterations::kColumn;
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; n++) {
for (int m = 0; m < MmaIterations::kRow; m++) {
int accumulator_access_offset =
(n + index_n) * AccumulatorIterations::kRow + m + index_m;
frag_ptr[n * MmaIterations::kRow + m].clear();
if(!(is_residual_tile_ && index_ >= kResidualIndex))
frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset];
// frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment);
}
}
}
};
// Partial specialization for row-major accumulator tile
template <
/// Shape of warp tile to load (concept: MatrixShape)
typename Shape_,
/// Shape of the warp accumulation tile (concept: MatrixShape)
typename AccumulatorShape_,
/// KBlocks columns to compute residual
int KBlocksColumn_,
/// Accumulator Element type
typename ElementAccumulator_,
/// Element type
typename Element_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_>
class MmaTensorOpPureFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
cutlass::layout::RowMajor,
InstructionShape_, true> {
public:
/// Shape of warp tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Shape of the warp accumulation tile (concept: MatrixShape)
using AccumulatorShape = AccumulatorShape_;
/// KBlocks columns to compute residual
static int const kKBlockColumn = KBlocksColumn_;
/// Accumulator Element type
using ElementAccumulator = ElementAccumulator_;
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = cutlass::layout::RowMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = InstructionShape_;
/// Whether beta is zero
static bool const IsBetaZero = true;
/// Number of participating threads
static int const kThreads = 32;
/// Internal structure of iterator - made public to enable introspection
struct Policy {
static_assert(
!(Shape::kRow % InstructionShape::kM) &&
!(Shape::kColumn % InstructionShape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
static_assert(
!(AccumulatorShape::kRow % Shape::kRow) &&
!(AccumulatorShape::kColumn % Shape::kColumn),
"Shape of Warp Accumulator must be divisible by warp shape.");
static_assert(
!(kKBlockColumn % Shape::kColumn),
"KBlock size must be divisible by warp shape.");
/// Number of times this iterator can be incremented
static int const kIterations = AccumulatorShape::kCount / Shape::kCount;
};
private:
static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
/// Number of mma operations performed by a warp
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
Shape::kColumn / InstructionShape::kN>;
/// Number of mma operations performed by the entire accumulator
using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
AccumulatorShape::kColumn / InstructionShape::kN>;
/// Number of K iterations
static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
* (AccumulatorShape::kRow / Shape::kRow);
static int const kResidualIndex = kResidualColumn / Shape::kColumn
* (AccumulatorShape::kRow / Shape::kRow);
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
/// This is the fragment size produced by one access of the iterator.
using Fragment = Array<Element, Shape::kCount / kThreads>;
/// Accumulator Fragment object
using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;
private:
/// Internal access type
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentAccessType = Array<Element, kElementsPerAccess>;
private:
//
// Data members
//
/// Accumulator tile
AccessType const *accumulators_;
/// Internal index
int index_;
/// Used to access residual tile first
bool is_residual_tile_;
public:
/// Constructs an iterator
CUTLASS_HOST_DEVICE
MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum)
: accumulators_(reinterpret_cast<AccessType const *>(&accum)),
index_(0), is_residual_tile_(true) {}
/// Add offset
CUTLASS_HOST_DEVICE
void add_offset(int index_offset) {
index_ += index_offset;
if(is_residual_tile_ && index_ >= kKBlockColumnIterations) {
index_ = index_ - kKBlockColumnIterations + kResidualIndex;
is_residual_tile_ = false;
}
}
/// Increments
CUTLASS_HOST_DEVICE
MmaTensorOpPureFragmentIterator &operator++() {
add_offset(1);
return *this;
}
/// Decrements
CUTLASS_HOST_DEVICE
MmaTensorOpPureFragmentIterator &operator--() {
add_offset(-1);
return *this;
}
/// Loads a fragment from the referenced part of the accumulator tile
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
FragmentAccessType src_fragment;
src_fragment.clear();
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
* MmaIterations::kColumn;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; m++) {
for (int n = 0; n < MmaIterations::kColumn; n++) {
int accumulator_access_offset =
(m + index_m) * AccumulatorIterations::kColumn + n + index_n;
frag_ptr[m * MmaIterations::kColumn + n].clear();
if(!(is_residual_tile_ && index_ >= kResidualIndex))
frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]);
}
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,129 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import gen_turing_and_volta as api_generator
import gen_sample as sample_creater
import gen_cmake as cmake_creater
import gen_verify as verify_creater
import gen_device as b2b_fused_generator
import replace_fix_impl_header
import argparse
import os
import json
parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels")
parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate")
parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name")
parser.add_argument("--output-dir", default="", help="Specifies the output dir")
parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir")
parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.")
args = parser.parse_args()
gen_name = args.gen_name
cutlass_deps_dir = args.cutlass_dir
output_dir = args.output_dir
output_dir += "/"
cutlass_deps_root = args.gen_include_cutlass_dir
if cutlass_deps_root == '':
cutlass_deps_root = cutlass_deps_dir + "/include/"
cutlass_deps_root +='/'
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if not os.path.exists(output_dir + "/" + "auto_gen"):
os.mkdir(output_dir + "/" + "auto_gen")
if not os.path.exists(output_dir + "/" + "fixed_impl"):
os.mkdir(output_dir + "/" + "fixed_impl" )
if not os.path.exists(output_dir + "/" + "sample"):
os.mkdir(output_dir + "/" + "sample" )
if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"):
os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device")
if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"):
os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel")
if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"):
os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock")
with open(args.config_file, 'r') as infile:
gemm_info_dict = json.load(infile)
keys = sorted(gemm_info_dict.keys())
fuse_gemm_info = [gemm_info_dict[k] for k in keys]
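# For reference, each value in config.json is expected to look roughly like the
# sketch below. The field names are inferred from how gen_device consumes
# fuse_gemm_info (A_tp/A_format, B_tp/B_format, C_tp/C_format, Acc_tp, mnk,
# epilogue); the type/format strings and shapes are placeholders, not a
# documented schema.
_example_gemm_entry = {
    "A_tp": "fp16", "A_format": "Row",
    "B_tp": "fp16", "B_format": "Col",
    "C_tp": "fp16", "C_format": "Row",
    "Acc_tp": "fp16",
    "mnk": [1024, 64, 576],
    "epilogue": {
        "tp": "LeakyRelu",
        "bias": {"addbias": False, "bias_tp": "mat"},
        "args": [["float", "leaky_alpha", 1.3]],
        "func": "y = max(leaky_alpha * x, x)\n",
    },
}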
for_cutlass_gen_user_include_header_file = [
cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
]
for_fused_wrapper = [
cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h",
cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h",
"auto_gen/device/" + gen_name + ".h",
cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h",
cutlass_deps_root + "cutlass/cutlass.h",
]
# Copy fixed implementation to the output directory
fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root)
fix_impl.gen_code()
auto_gen_output_dir = output_dir + "/auto_gen/"
project_root = ""
turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir)
turing_plus.gen_code(75, 'hmma1688', False)
api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
api.gen_code()
# Generate C++ sample
os.system("cp ../leaky_bias.h " + output_dir + "/sample/")
os.system("cp ../utils.h " + output_dir + "/sample/")
sample_dir = output_dir + "/sample/"
sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir)
sample.gen_cpp_sample()
cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir)
cmake_gen.gen_code()
verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir)
verify.gen_code()
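# Typical invocation of this generator script (the script and directory names
# below are placeholders, not taken from the original example):
#
#   python3 gen.py --config-file config.json \
#                  --gen-name FusedMultiGemmForward \
#                  --output-dir ./generated \
#                  --cutlass-dir /path/to/cutlass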

View File

@ -0,0 +1,131 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
class gen_build_sys:
def __init__(self, cutlass_deps_dir, output_dir = "../"):
self.output_dir = output_dir
self.cutlass_deps_dir = cutlass_deps_dir
def gen_top(self):
code = ""
code += '''\
# Auto Generated code - Do not edit.
cmake_minimum_required(VERSION 3.8)
project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA)
find_package(CUDAToolkit)
set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}})
set(CUTLASS_PATH \"{cutlass_deps_dir}/include\")
set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\")
list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}})
'''.format(cutlass_deps_dir=self.cutlass_deps_dir)
code += '''\
set(GPU_ARCHS \"\" CACHE STRING
\"List of GPU architectures (semicolon-separated) to be compiled for.\")
if(\"${GPU_ARCHS}\" STREQUAL \"\")
set(GPU_ARCHS \"70\")
endif()
foreach(arch ${GPU_ARCHS})
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\")
if(SM STREQUAL 70 OR SM STREQUAL 75)
set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\")
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\")
endif()
endforeach()
set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\")
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\")
set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\")
set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\")
set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\")
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
if(CMAKE_CXX_STANDARD STREQUAL \"11\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\")
endif()
set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\")
set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\")
set(COMMON_HEADER_DIRS
${PROJECT_SOURCE_DIR}
${CUDAToolkit_INCLUDE_DIRS}
)
set(COMMON_LIB_DIRS
${CUDAToolkit_LIBRARY_DIR}
)
list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH})
list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH})
'''
code += '''\
include_directories(
${COMMON_HEADER_DIRS}
)
link_directories(
${COMMON_LIB_DIRS}
)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-DGOOGLE_CUDA=1)
add_executable(sample
sample/sample.cu
one_api.cu
)
target_link_libraries(sample PRIVATE
-lcudart
-lnvToolsExt
${CMAKE_THREAD_LIBS_INIT}
)
if(NOT DEFINED LIB_INSTALL_PATH)
set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR})
endif()
'''
return code
def gen_code(self):
top_code = self.gen_top()
with open(self.output_dir + "CMakeLists.txt", "w") as f:
f.write(top_code)
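# Illustrative stand-alone usage, guarded so it only runs when this module is
# executed directly; the paths are placeholders and the output directory must
# already exist.
if __name__ == "__main__":
    builder = gen_build_sys("/path/to/cutlass", output_dir="./generated/")
    builder.gen_code()  # writes ./generated/CMakeLists.txt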

View File

@ -0,0 +1,120 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ast
fuse_gemm_info = [
{
'epilogue': {
'tp': 'LeakyRelu', #'CustomizedLeaky_RELU'
'bias': {'addbias': False, 'bias_tp': 'mat'},
'args': [('float', 'leaky_alpha', 1.3), ],
'func': '''
y = max(leaky_alpha * x, x)
y = y * x
'''
}
},
]
class AnalysisNodeVisitor(ast.NodeVisitor):
def visit_Import(self,node):
ast.NodeVisitor.generic_visit(self, node)
def visit_ImportFrom(self,node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Assign(self,node):
print('Node type: Assign and fields: ', node._fields)
# print('Node type: Assign and targets value: ', node.targets, node.value)
ast.NodeVisitor.generic_visit(self, node)
def visit_BinOp(self, node):
print('Node type: BinOp and fields: ', node._fields)
print('node op: ', type(node.op).__name__)
ast.NodeVisitor.generic_visit(self, node)
def visit_Expr(self, node):
print('Node type: Expr and fields: ', node._fields)
ast.NodeVisitor.generic_visit(self, node)
def visit_Num(self,node):
print('Node type: Num and fields: ', node._fields)
print('Node type: Num: ', node.n)
def visit_Name(self,node):
print('Node type: Name and fields: ', node._fields)
print('Node type: Name and fields: ', type(node.ctx).__name__, node.id)
ast.NodeVisitor.generic_visit(self, node)
def visit_Str(self, node):
print('Node type: Str and fields: ', node._fields)
class CodeVisitor(ast.NodeVisitor):
def visit_BinOp(self, node):
if isinstance(node.op, ast.Add):
node.op = ast.Sub()
self.generic_visit(node)
def visit_Assign(self, node):
print('Assign %s' % node.value)
self.generic_visit(node)
def visit_Name(self, node):
print("Name:", node.id)
self.generic_visit(node)
def visit_FunctionDef(self, node):
    print('Function Name: %s' % node.name)
    self.generic_visit(node)
    # Prepend a print() call to the function body so instrumented functions
    # log when they are entered.
    func_log_stmt = ast.Expr(
        value=ast.Call(
            func=ast.Name(id='print', ctx=ast.Load()),
            args=[ast.Constant(value='calling func: %s' % node.name)],
            keywords=[],
        ),
    )
    ast.fix_missing_locations(func_log_stmt)
    node.body.insert(0, func_log_stmt)
visitor = AnalysisNodeVisitor()
code = \
'''
a=max(leaky_alpha * x, x +1)
'''
visitor.visit(ast.parse(code))
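# A second, guarded run that parses the epilogue body defined in fuse_gemm_info
# above; textwrap.dedent is applied defensively in case the snippet carries
# leading indentation.
if __name__ == "__main__":
    import textwrap
    epilogue_body = textwrap.dedent(fuse_gemm_info[0]['epilogue']['func'])
    AnalysisNodeVisitor().visit(ast.parse(epilogue_body))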

View File

@ -0,0 +1,477 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from typing import *
import helper
import gen_ir
import gen_kernel as gen_ker
class gen_device:
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"):
self.fuse_gemm_info = fuse_gemm_info
self.raw_gemm_info = fuse_gemm_info
self.b2b_num = len(fuse_gemm_info)
self.user_header_file = user_header_file
self.args = {}
# device arg struct member
self.arg_member = []
self.gen_class_name = gen_class_name
self.gen_kernel_name = gen_class_name + "Kernel"
self.tempalte_args = []
self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int}
self.file_name = output_dir + "/device/" +gen_class_name +".h"
self.sample_dir = output_dir
self.cutlass_deps_root = cutlass_deps_root
self.project_root = project_root
self.this_file_root = output_dir + "/device/"
self.first_use_1stage = False
## gen kernel
self.gen_kernel = gen_ker.gen_kernel(self.tempalte_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root)
def __check_arg_type(self, temp_arg):
if temp_arg in self.__tempalate_arg_list.keys():
return self.__tempalate_arg_list[temp_arg]
find_sub = False
for candidate_arg in self.__tempalate_arg_list.keys():
if (temp_arg.find(candidate_arg) != -1):
return self.__tempalate_arg_list[candidate_arg]
return 'typename'
# def gen_B2b2bGemm_class():
def set_arch(self, sm_cap, mma_tp):
if sm_cap == 75 or sm_cap == 80 or sm_cap == 86:
self.arch = "cutlass::arch::Sm" + str(sm_cap)
if mma_tp == 'hmma1688':
self.mma_shape = [16, 8, 8]
self.mma_tp = 'hmma'
elif mma_tp == 'imma8816':
self.mma_tp = 'imma'
self.mma_shape = [8, 8, 16]
else:
return 0
def gen_include_header(self):
code = '''\
/* Auto Generated code - Do not edit.*/
#pragma once
#include \"{cutlass_root}cutlass/cutlass.h\"
#include \"{cutlass_root}cutlass/numeric_types.h\"
#include \"{cutlass_root}cutlass/arch/arch.h\"
#include \"{cutlass_root}cutlass/device_kernel.h\"
#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\"
#include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\"
#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\"
#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\"
#include \"{project_root}../kernel/b2b_gemm.h\"
#include \"{project_root}../kernel/default_b2b_gemm.h\"
'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root)
include_user_header = ""
for header in self.user_header_file:
include_user_header += "#include \"" + header + "\"\n"
return code + include_user_header
def gen_code(self, sm_cap, mma_tp, ifprint = True):
self.set_arch(sm_cap, mma_tp)
self.update_b2b_args()
print(self.fuse_gemm_info)
self.update_b2b_class_template_args()
func_code = self.gen_all_func()
member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n"
gen_code = gen_ir.gen_template_class(self.gen_class_name, self.tempalte_args, func_code + member_var_code)
code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code)))
if ifprint:
print(code)
print("[INFO]: Gen device code output Dir: is ", self.file_name)
with open(self.file_name, 'w+') as f:
f.write(code)
gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage)
print(gen_kernel)
def update_b2b_class_template_args(self):
for arg in self.args.keys():
self.tempalte_args.append([self.__check_arg_type(arg), arg, self.args[arg]])
def update_b2b_args(self):
self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp'])
self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format'])
cnt = 0
warp_M_tile = 32
# Determine the maximum N tile
Max_Ntile = 0
for layer in self.fuse_gemm_info:
n_tile = layer['mnk'][1]
if n_tile > Max_Ntile:
Max_Ntile = n_tile
if Max_Ntile >= 256:
warp_M_tile = 16
stages_temp = []
for layer in self.fuse_gemm_info:
cnt_str = str(cnt)
B_tp_str= 'ElementB' + cnt_str
B_format_str = 'LayoutB' + cnt_str
C_tp_str= 'ElementC' + cnt_str
C_format_str = 'LayoutC' + cnt_str
Acc_str = 'ElementAccumulator' + cnt_str
self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp'])
self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format'])
self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp'])
self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format'])
self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp'])
mnk = layer['mnk'][:]
tile_mnk = mnk[:]
tile_mnk[2] = 32  # force the K tile to 32
#N tile gen
if mnk[1] > 1024:
assert(0)
elif mnk[1] > 512:
tile_mnk[1] = 1024
elif mnk[1] > 256:
tile_mnk[1] = 512
elif mnk[1] > 128:
tile_mnk[1] = 256
elif mnk[1] > 64:
tile_mnk[1] = 128
elif mnk[1] > 32:
tile_mnk[1] = 64
else :
tile_mnk[1] = 32
if tile_mnk[1] == 512:
stages_temp.append(1)
else:
stages_temp.append(2)
tile_mnk[0] = 4 * warp_M_tile
epilogue_setted_type = helper.get_epilogue_tp(layer)
cutlass_epilogue_name = "LinearCombinationRelu"
if epilogue_setted_type.lower() == 'leakyrelu':
cutlass_epilogue_name = "LinearCombinationLeakyRelu"
elif epilogue_setted_type.lower() == 'identity':
cutlass_epilogue_name = "LinearCombination"
epilogue_str = 'EpilogueOutputOp' + cnt_str
if cnt != len(self.fuse_gemm_info) - 1:
n = layer['mnk'][1]
Fragments = tile_mnk[1] // 8 * 2
self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "<ElementC0_, " + str(Fragments) +", ElementAccumulator0_, ElementAccumulator0_>"
else:
n = layer['mnk'][1]
n_mod_8 = n % 8  # epilogue alignment is decided by N mod 8 (see the branches below)
N_align_elements = 1
if n_mod_8 == 0:
N_align_elements = 8
elif n_mod_8 == 4:
N_align_elements = 4
elif n_mod_8 == 2 or n_mod_8 == 6:
N_align_elements = 2
self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<ElementC0_, " + str(N_align_elements) + ", ElementAccumulator0_, ElementAccumulator0_>"
ThreadBlockShape_str = 'ThreadblockShape' + cnt_str
self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
WarpShape_str = 'WarpShape' + cnt_str
tile_mnk[0] = warp_M_tile
self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk)
cnt += 1
self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp'])
self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format'])
self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape)
self.args['OperatorClass'] = 'arch::OpClassTensorOp'
self.args['ArchTag'] = self.arch
self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle'
for i in range(self.b2b_num):
self.args[helper.var_idx('Stages', i)] = "2"
self.args['AlignmentA'] = str(8)
self.args['AlignmentB'] = str(8)
self.args['SplitKSerial'] = 'false'
self.args['Operator'] = 'typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB0_, ElementC0_, ElementAccumulator0_>::Operator'
self.args['IsBetaZero'] = 'false'
def gen_using_kernel(self):
code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n"
code += " " + "ElementA,\n"
code += " " + "LayoutA,\n"
for i in range(self.b2b_num):
code += " " + helper.var_idx("ElementB", i) + ",\n"
code += " " + helper.var_idx("LayoutB", i) + ",\n"
code += " " + helper.var_idx("ElementC", i) + ",\n"
code += " " + helper.var_idx("LayoutC", i) + ",\n"
code += " " + helper.var_idx("ElementAccumulator", i) + ",\n"
code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n"
code += " " + helper.var_idx("ThreadblockShape", i) + ",\n"
code += " " + helper.var_idx("WarpShape", i) + ",\n"
code += " " + "ElementD,\n"
code += " " + "LayoutD,\n"
code += " " + "InstructionShape,\n"
code += " " + "OperatorClass,\n"
code += " " + "ArchTag,\n"
code += " " + "ThreadblockSwizzle,\n"
for i in range(self.b2b_num):
code += " " + helper.var_idx("Stages", i) + ",\n"
code += " " + "AlignmentA,\n"
code += " " + "AlignmentB,\n"
code += " " + "SplitKSerial,\n"
code += " " + "Operator,\n"
code += " " + "IsBetaZero_\n"
code += ">::B2bGemmKernel;\n\n"
return code
def gen_args(self):
def gen_arg_member(b2b_num):
data_members = []
for i in range(b2b_num):
member_type = "GemmCoord"
member_name = "problem_size_" + str(i)
data_members.append((member_type, member_name))
member_type = "TensorRef<ElementA const, LayoutA>"
member_name = "ref_A0"
data_members.append((member_type, member_name))
for i in range(b2b_num):
member_type = "TensorRef<ElementB" + str(i) + " const, LayoutB" + str(i) +">"
member_name = "ref_B" + str(i)
data_members.append((member_type, member_name))
member_type = "TensorRef<ElementC" + str(i) + " const, LayoutC" + str(i) +">"
member_name = "ref_C" + str(i)
data_members.append((member_type, member_name))
member_type = "TensorRef<ElementD, LayoutD>"
member_name = helper.var_idx("ref_D", b2b_num - 1)
data_members.append((member_type, member_name))
for i in range(b2b_num):
member_type = "typename EpilogueOutputOp" + str(i) + "::Params"
member_name = "epilogue" + str(i)
data_members.append((member_type, member_name))
data_members.append(('int', 'batch_count'))
return data_members
def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value):
constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
gen_ir.indentation + struct_name + " (): "
for i in range(inital_param_num):
final_param = ','
if i == inital_param_num - 1:
final_param = '{ }'
constructs_code += data_members[i][1] + inital_value + final_param
constructs_code += "\n"
return constructs_code
def gen_arg_struct_ctor(struct_name, data_members):
constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \
gen_ir.indentation + struct_name + " (\n"
cnt = 0
param_num = len(data_members)
for param in data_members:
final = ',\n'
if cnt == param_num - 1:
final = '\n):\n'
constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final
cnt += 1
cnt = 0
for param in data_members:
final = '),\n'
if cnt == param_num - 1:
final = ") { }\n"
constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final
cnt += 1
constructs_code += "\n"
return constructs_code
# (variable type, variable name)
struct_member = gen_arg_member(self.b2b_num)
self.arg_member = struct_member
codeBody = ""
for each_member in struct_member:
codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n"
codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n"
codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n"
struct_code = gen_ir.gen_struct("Arguments", codeBody)
return struct_code
def gen_func_constructs(self):
code = self.gen_class_name +"() {}"
return code
def gen_func_initialize(self):
code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \
"// Determine grid shape\n" + \
"ThreadblockSwizzle threadblock_swizzle;\n" + \
"cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \
" args.problem_size_0, \n" + \
" { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \
" args.batch_count);\n" + \
"// Initialize the Params structure\n" + \
"params_ = typename B2bGemmKernel::Params{\n"
for i in range(self.b2b_num):
code += helper.var_idx(" args.problem_size_", i) + ",\n"
code += " grid_shape,\n" + \
" args.ref_A0.non_const_ref(),\n"
for i in range(self.b2b_num):
code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n"
code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n"
code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n"
for i in range(self.b2b_num):
code += helper.var_idx(" args.epilogue", i) + ",\n"
code += " args.batch_count\n"
code += "};\n" + \
"return Status::kSuccess;\n" + \
"}\n"
return code
def gen_func_run(self):
code = "Status run(cudaStream_t stream = nullptr) {\n" + \
"\n" + \
" ThreadblockSwizzle threadblock_swizzle;\n" + \
"\n" + \
" dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \
" dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \
"\n" + \
" cudaError_t result;\n" + \
"\n" + \
" int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \
" if (smem_size >= (48 << 10)) {\n" + \
" result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \
"\n" + \
" if (result != cudaSuccess) {\n" + \
" return Status::kErrorInternal;\n" + \
" }\n" + \
"\n" + \
" result = cudaFuncSetAttribute(\n" + \
" Kernel<B2bGemmKernel>,\n" + \
" cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n" + \
"\n" + \
" if (result != cudaSuccess) {\n" + \
" return Status::kErrorInternal;\n" + \
" }\n" + \
" }\n" + \
" cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);\n" + \
" result = cudaGetLastError();\n" + \
" return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;\n" + \
" }\n"
return code
def gen_func_operator(self):
opeartor_with_arg_code = "Status operator()(\n" + \
" Arguments const &args,\n" + \
" void *workspace = nullptr,\n" + \
" cudaStream_t stream = nullptr) {\n" + \
" Status status = initialize(args, workspace);\n" + \
" \n" + \
" if (status == Status::kSuccess) {\n" + \
" status = run(stream);\n" + \
" }\n" + \
" return status;\n" + \
"}\n"
operator_code = "Status operator()(\n" + \
" cudaStream_t stream = nullptr) {\n" + \
" Status status = run(stream);\n" + \
" return status;\n" + \
"}\n"
return opeartor_with_arg_code + "\n" + operator_code
def gen_all_func(self):
return self.gen_using_kernel() + "\n" + \
self.gen_args() + "\n" + \
self.gen_func_constructs() + "\n" + \
self.gen_func_initialize() + "\n" + \
self.gen_func_run() + "\n" + \
self.gen_func_operator()
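# Illustrative stand-alone driver, mirroring how the top-level generator script
# uses this class. The paths are placeholders; config.json must exist and
# output_dir must already contain device/ and kernel/ sub-directories before
# gen_code() is called.
if __name__ == "__main__":
    import json
    with open("config.json") as infile:
        cfg = json.load(infile)
    info = [cfg[k] for k in sorted(cfg.keys())]
    device_gen = gen_device(info, "FusedMultiGemmForward",
                            user_header_file=[],
                            cutlass_deps_root="/path/to/cutlass/include/",
                            project_root="",
                            output_dir="./generated/auto_gen")
    # device_gen.gen_code(75, 'hmma1688', ifprint=False)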

View File

@ -0,0 +1,249 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import helper
indentation = " "
def append_word(word):
code = ""
code += word
code += " "
return code
def gen_namespace(namespace, codeBody):
code_gen = "namespace " + namespace + " {\n"
code_gen += codeBody
code_gen += "} // namespace " + namespace + "\n"
return code_gen
def gen_expression(type, lval, rval = None):
code_gen = ""
code_gen += append_word(type)
code_gen += append_word(lval)
if rval is not None:
code_gen += append_word("=")
code_gen += append_word(rval)
return code_gen
def gen_class(name, codeBody, inheritance_code = None):
code_gen = ""
if inheritance_code is None:
code_gen = "class " + name + "{\n"
else:
code_gen = "class " + name + " : "+ inheritance_code + "{\n"
code_gen += codeBody
code_gen += "}; // class " + name + "\n"
return code_gen
def gen_struct(name, codeBody, specialized = None):
specialized_code = ""
if specialized is not None:
specialized_code = "<" + specialized + ">"
code_gen = "struct " + name + specialized_code + "{\n"
code_gen += codeBody
code_gen += "}; // struct " + name + "\n"
return code_gen
def gen_template_arg(arg_type, arg_name, default_val = None):
rval = None
if default_val is not None:
rval = str(default_val)
arg_typename = ""
if arg_type is int:
arg_typename = "int"
elif arg_type is bool:
arg_typename = "bool"
else:
arg_typename = "typename"
internal_arg_name = arg_name + "_"
code_gen = indentation
code_gen += gen_expression(arg_typename, internal_arg_name, rval)
return code_gen
def gen_template_args(args, set_default = True):
arg_len = len(args)
cnt = 1
code_gen = ""
for arg_tuple in args:
arg_type = arg_tuple[0]
arg_name = arg_tuple[1]
arg_default_val = None
if len(arg_tuple) == 3 and set_default:
arg_default_val = arg_tuple[2]
code_gen += gen_template_arg(arg_type, arg_name, arg_default_val)
if cnt != arg_len:
code_gen += ",\n"
cnt += 1
return code_gen
def gen_template_head(args, set_default = True):
code_gen = "template <\n"
code_gen += gen_template_args(args, set_default)
code_gen += ">\n"
return code_gen
def export_template_args(args):
code_gen = "public:\n"
for arg_tuple in args:
code_gen += indentation
arg_type = arg_tuple[0]
arg_name = arg_tuple[1]
internal_arg_name = arg_name + "_"
typename = ""
if arg_type is int:
typename = "static int const"
elif arg_type is bool:
typename = "static bool const"
else:
typename = "using"
code_gen += gen_expression(typename, arg_name, internal_arg_name)
code_gen += ";\n"
return code_gen
def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None):
code_gen = ""
code_gen += gen_template_head(args, set_default)
code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code)
return code_gen
def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True):
code_gen = ""
code_gen += gen_template_head(args, set_default)
code = export_template_args(args) + codeBody
if export_args is False:
code = codeBody
code_gen += gen_struct(struct_name, code , speicalized)
return code_gen
def gen_declare_template_struct(name, *params):
code = name + "<"
cnt = 0
param_num = len(params)
for param in params:
final = ", "
if cnt == param_num - 1:
final = ""
code += param + final
cnt += 1
code += ">;\n"
return code
def filtered_param(params, name_and_value_pair, keep_ = False):
rtn_template_args = []
speicalized_template_args = []
for param in params:
param_name = ""
if len(param) >= 2:
param_name = param[1]
else:
param_name = param[0]
hit_flag = False
set_value = ""
for n_v_pair in name_and_value_pair:
filter_name = n_v_pair[0]
set_value = n_v_pair[1]
if param_name == (filter_name + "_") or param_name == filter_name :
hit_flag = True
break
if hit_flag is False:
rtn_template_args.append(param)
if hit_flag is True:
speicalized_template_args.append(set_value)
else:
if keep_ is True:
speicalized_template_args.append(param_name + "_")
else:
speicalized_template_args.append(param_name)
specialized_template_arg_str = helper.list_2_string(speicalized_template_args)
return rtn_template_args, specialized_template_arg_str
def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True):
code = "void " + func_name + "(\n"
for arg in arg_lists:
arg_tp = arg[0]
arg_nm = arg[1]
code += " " + arg_tp + " " + arg_nm + ",\n"
code += "cudaStream_t stream)"
if only_declare :
return code
code += "{\n"
code += code_body + "\n"
code += "}\n"
return code
def indent_level(code, level = 0):
rtn_code = ""
for i in range(level):
rtn_code += " "
rtn_code += code
return rtn_code
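# Small illustrative self-check of the emit helpers above; run this module
# directly to see the generated C++ text (exact spacing depends on the
# `indentation` constant).
if __name__ == "__main__":
    demo_args = [(None, "ElementA"), (int, "kStages", 2)]
    # A "template <...>" head declaring `typename ElementA_` and `int kStages_ = 2`.
    print(gen_template_head(demo_args))
    # A "struct DemoTraits" that re-exports the template arguments as public
    # typedefs/constants, followed by the supplied body.
    print(gen_template_struct("DemoTraits", demo_args, "  // body goes here\n"))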

View File

@ -0,0 +1,476 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import gen_ir
import helper
import gen_threadblock as gen_tb
class gen_default_Gemm:
def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
self.gen_class_name = "B2bGemm"
self.template_param = template_param
self.b2b_num = b2b_num
self.cutlass_deps_root = cutlass_deps_root
self.project_root = project_root
def gen_B2bMma(self, specialized_template_args):
code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n"
code += specialized_template_args
code += ">::ThreadblockB2bMma;\n"
# print(code)
return code
def gen_epilogue(self):
epilogue_code = ""
epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n"
epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n"
epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n"
epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n"
epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n"
epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n"
epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n"
epilogue_code += ">::Epilogue;\n"
epilogue_code += "using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;\n\n"
return epilogue_code
def gen_include_header(self):
code = '''
/* Auto Generated code - Do not edit.*/
#pragma once
#include \"{cutlass_dir}cutlass/cutlass.h\"
#include \"{cutlass_dir}cutlass/layout/matrix.h\"
#include \"{cutlass_dir}cutlass/numeric_types.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\"
#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\"
#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\"
#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\"
#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\"
#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\"
#include \"../kernel/b2b_gemm.h\"
#include \"../threadblock/default_b2b_mma.h\"
'''.format(cutlass_dir=self.cutlass_deps_root)
return code
def gen_code(self):
gen_using = ''
# Generate default template struct
gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param,"", speicalized = None, set_default=False)
filter_list = []
filter_list.append(('Stages', 2))
filter_list.append(("OperatorClass", "arch::OpClassTensorOp"))
filter_list.append(("ArchTag", "arch::Sm75"))
for i in range(self.b2b_num):
filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor"))
rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True)
B2bMma_code = self.gen_B2bMma(speicalized_template_args)
epilogue_and_rest_code = self.gen_epilogue()
gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, speicalized = speicalized_template_args, set_default=False)
code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code)))
return self.gen_include_header() + code
class gen_Kernel:
def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root):
self.gen_class_name = "B2bGemm"
self.template_param = template_param
self.b2bnum = b2b_num
self.cutlass_deps_root = cutlass_deps_root
self.project_root = project_root
def gen_include_header(self):
code = '''
#pragma once
#include \"{cutlass_dir}cutlass/cutlass.h\"
#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root)
return code
def gen_Params(self):
gen_param = ""
for i in range(self.b2bnum):
gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n"
gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n"
gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n"
gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n"
for i in range(self.b2bnum):
gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n"
gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n"
if i == self.b2bnum - 1:
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n"
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n"
else:
gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n"
gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n"
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n"
gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n"
for i in range(self.b2bnum):
gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n"
gen_param += " " + 'int batch_count' + ";\n"
gen_param += " " + 'int gemm_k_iterations_0' + ";\n"
return gen_param
def gen_Memberfunc(self):
code_default = "\nCUTLASS_HOST_DEVICE\n"
code_default += "Params()"
code_default += " { } \n\n"
code_construct = "\nCUTLASS_HOST_DEVICE\n"
code_construct += "Params(\n"
for i in range(self.b2bnum):
code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n"
code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n"
code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n"
for i in range(self.b2bnum):
code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n"
if i == self.b2bnum - 1:
code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n"
else:
code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n"
code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n"
for i in range(self.b2bnum):
code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n"
code_construct += " " + "int batch_count = 1\n"
code_construct += "):\n"
for i in range(self.b2bnum):
code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n"
code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n"
code_construct += " " + "params_A0(ref_A0.layout()),\n"
code_construct += " " + "ref_A0(ref_A0),\n"
for i in range(self.b2bnum):
code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n"
code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n"
code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n"
code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n"
code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n"
code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n"
for i in range(self.b2bnum):
code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n"
code_construct += " " + "batch_count(batch_count) {\n"
code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n"
code_construct += "}\n"
return code_default + code_construct
def gen_using(self):
code_using = ""
for i in range(self.b2bnum - 1):
code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n"
code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n"
for i in range(self.b2bnum - 1):
code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n"
code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n"
code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n"
code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc())
code_using += "union SharedStorage {\n"
code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n"
code_using += " " + "typename Epilogue::SharedStorage epilogue;\n"
code_using += "};\n"
return code_using
def gen_can_implement(self):
gen_code = ""
return gen_code
def gen_operator_and_constr(self):
ctr_code = "CUTLASS_HOST_DEVICE\n"
ctr_code += self.gen_class_name + "() { } \n\n"
operator_code = "CUTLASS_DEVICE\n"
operator_code += "void operator()(Params const &params, SharedStorage &shared_storage) {\n"
operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n"
operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n"
operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n"
operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n"
operator_code += " " + " " + "return;\n"
operator_code += " " + "}\n"
operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n"
operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n"
operator_code += " " + " " + "0\n"
operator_code += " " + "};\n"
for i in range(self.b2bnum):
operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n"
operator_code += " " + " " + "0,\n"
operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n"
operator_code += " " + "};\n"
operator_code += " " + "int thread_idx = threadIdx.x;\n\n"
operator_code += " " + "MatrixCoord threadblock_offset(\n"
operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n"
operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n"
operator_code += " " + ");\n"
operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n"
operator_code += " " + " " + "params.params_A0,\n"
operator_code += " " + " " + "params.ref_A0.data(),\n"
operator_code += " " + " " + "params.problem_size_0.mk(),\n"
operator_code += " " + " " + "thread_idx,\n"
operator_code += " " + " " + "tb_offset_A0);\n"
operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n"
for i in range (self.b2bnum):
operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n"
operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n"
operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n"
operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n"
operator_code += " " + " " + "thread_idx,\n"
operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n"
operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n"
for i in range (self.b2bnum - 1):
operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n"
operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n"
operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n"
operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n"
operator_code += " " + " " + "thread_idx,\n"
operator_code += " " + " " + "threadblock_offset" + ");\n"
operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n"
operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", i) + ".m()));\n\n"
for i in range (self.b2bnum - 1):
operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n"
operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
for i in range (self.b2bnum - 1):
operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n"
operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n"
operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n"
operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n"
operator_code += " " + "src_accum.clear();\n"
operator_code += " " + "accumulators.clear();\n"
operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, "
for i in range(self.b2bnum):
operator_code += helper.var_idx("iterator_B", i) + ", "
operator_code += "src_accum"
if self.b2bnum != 1:
operator_code += ", "
for i in range(self.b2bnum - 1):
operator_code += helper.var_idx("output_op_", i) + ", "
for i in range(self.b2bnum - 1):
operator_code += helper.var_idx("epilogue_", i) + ", "
for i in range(self.b2bnum - 1):
final = ", "
if i == self.b2bnum - 2:
final =""
operator_code += helper.var_idx("iterator_C", i) + final
operator_code += ");\n"
operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n"
operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n"
operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n"
operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n"
operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n"
operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
operator_code += " " + " " + "thread_idx,\n"
operator_code += " " + " " + "threadblock_offset\n"
operator_code += " " + ");\n"
operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n"
operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n"
operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n"
operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n"
operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n"
operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n"
operator_code += " " + " " + "thread_idx,\n"
operator_code += " " + " " + "threadblock_offset\n"
operator_code += " " + ");\n"
operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n"
operator_code += " " + "Epilogue epilogue(\n"
operator_code += " " + " " + "shared_storage.epilogue,\n"
operator_code += " " + " " + "thread_idx,\n"
operator_code += " " + " " + "warp_idx,\n"
operator_code += " " + " " + "lane_idx\n"
operator_code += " " + ");\n"
operator_code += " " + "epilogue("
operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", "
operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", "
operator_code += "accumulators, "
operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n"
operator_code += "}\n"
return ctr_code + operator_code
def gen_include_header(self):
code = '''
#pragma once
#include \"{cutlass_dir}cutlass/cutlass.h\"
#include \"{cutlass_dir}cutlass/gemm/gemm.h\"
#include \"{cutlass_dir}cutlass/matrix_coord.h\"
#include \"{cutlass_dir}cutlass/semaphore.h\"
'''.format(cutlass_dir=self.cutlass_deps_root)
return code
def gen_code(self):
template_param = []
template_param.append(("typename", "B2bMma"))
template_param.append(("typename", "Epilogue"))
template_param.append(("typename", "ThreadblockSwizzle"))
template_param.append((bool, "SplitKSerial"))
code_body = ""
code_body += self.gen_using()
code_body += self.gen_operator_and_constr()
struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body)
code = self.gen_include_header()
code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code)))
return self.gen_include_header() + code
class gen_kernel:
def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root):
self.template_param = template_param
self.gen_class_name = "B2bGemm"
self.gen_kernel_name = gen_class_name + "Kernel"
self.template_args = []
self.cutlass_deps_root = cutlass_deps_root
self.project_root = project_root
self.gen_default_b2b_gemm = gen_default_Gemm(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
self.gen_Kernel = gen_Kernel(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root)
# Include gen_threadBlock
self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root)
self.file_dir = output_dir + "/kernel/"
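# gen_code() below writes the generated kernel-level headers (default_b2b_gemm.h and b2b_gemm.h) into <output_dir>/kernel/ and then delegates to the threadblock-level generator.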
def gen_code(self, first_use_1stage):
default_b2b_gemm = self.gen_default_b2b_gemm.gen_code()
print("[INFO]: Gen kernel code [default_b2b_gemm.h]output Dir: is ", self.file_dir)
with open(self.file_dir + "default_b2b_gemm.h", "w+") as f:
f.write(default_b2b_gemm)
kernel = self.gen_Kernel.gen_code()
print("[INFO]: Gen kernel code [b2b_gemm.h]output Dir: is ", self.file_dir)
with open(self.file_dir + "b2b_gemm.h", "w+") as f:
f.write(kernel)
# Call code to gen threadblock
self.gen_threadBlock.gen_code(first_use_1stage)

View File

@ -0,0 +1,232 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import helper
import gen_ir as ir
class gen_test:
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
self.fuse_gemm_info = fuse_gemm_info
self.gen_class_name = gen_class_name
self.user_header_file = user_header_file
self.sample_dir = output_dir
self.b2b_num = len(fuse_gemm_info)
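# gen_cpp_sample() emits a standalone sample.cu that times the fused path (one_api) and the unfused per-GEMM reference path over 100 iterations each, then compares the final outputs with check_result().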
def gen_cpp_sample(self):
code = "/* Auto Generated code - Do not edit.*/\n"
code += "#include <stdio.h> \n"
code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n"
code += "#include \"cutlass/cutlass.h\" \n"
code += "#include \"../cutlass_irrelevant.h\" \n"
code += "#include \"../cutlass_verify.h\" \n"
code += "#include \"leaky_bias.h\" \n"
code += "#include \"utils.h\" \n"
code += "int main(int args, char * argv[]) {\n"
code += " " + "int M = atoi(argv[1]);\n"
code += " " + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][0]) + ";\n"
code += " " + "if(args == 3);\n"
code += " " + " " + "K0 = atoi(argv[2]);\n"
code += " " + "int B = 1;\n"
code += " " + "if(args == 4);\n"
code += " " + " " + "B = atoi(argv[3]);\n"
code += " " + "srand(1234UL);\n"
code += " " + "int device_id = 0;\n"
code += " " + "cudaGetDevice(&device_id);\n"
code += " " + "cudaDeviceProp prop;\n"
code += " " + "cudaGetDeviceProperties(&prop, device_id);\n"
code += " " + "int sm = prop.major *10 + prop.minor;\n"
code += "using ElementCompute = cutlass::half_t;\n"
for i in range(self.b2b_num):
code += " " + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n"
addbias = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i])
if addbias:
code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n"
else:
code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n"
code += " " + "size_t flops = 0;\n"
for i in range(self.b2b_num):
m = self.fuse_gemm_info[i]['mnk'][0]
n = self.fuse_gemm_info[i]['mnk'][1]
k = self.fuse_gemm_info[i]['mnk'][2]
bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])
this_k = "K0"
if (i > 0):
this_k = str(k)
code += " " + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n"
code += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n"
code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n"
code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n"
code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n"
code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n"
code += " " + helper.var_idx("Mat_A", i) + ".init();\n"
code += " " + helper.var_idx("Mat_B", i) + ".init();\n"
code += " " + helper.var_idx("Mat_C", i) + ".init();\n"
code += " " + helper.var_idx("memory_unit<cutlass::half_t> Mat_D", self.b2b_num - 1) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_",self.b2b_num - 1) + ".n());\n"
params = []
params.append("M")
params.append("B")
params.append("Mat_A0.device_ptr")
for i in range(self.b2b_num):
params.append(helper.var_idx("Mat_B", i) + ".device_ptr")
params.append(helper.var_idx("Mat_C", i) + ".device_ptr")
if i != self.b2b_num-1:
params.append(helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr")
params.append(helper.var_idx("Mat_D", self.b2b_num - 1) + ".device_ptr")
code += " " + "Param arguments = {\n"
code += " " + " " + "M,\n"
code += " " + " " + "K0,\n"
code += " " + " " + "B,\n"
code += " " + " " + "reinterpret_cast<const void*>(Mat_A0.device_ptr),\n"
cnt = 1
for i in range(self.b2b_num):
bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i])
code += " " + " " + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n"
cnt += 1
if bias_flag:
code += " " + " " + "reinterpret_cast<const void*>(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n"
cnt += 1
else:
code += " " + " " + "reinterpret_cast<const void*>(NULL),\n"
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
for arg in epilogue_args:
arg_value = str(arg[2])
code += " " + " " + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n"
if i != self.b2b_num - 1:
code += " " + " " + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n"
else:
code += " " + " " + "reinterpret_cast<void*>(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n"
code += " " + "TI(FUSED_CUTLASS);\n"
code += " " + "for(int i = 0; i < 100; i++){\n"
code += " " + " " + "one_api(arguments, sm, NULL);\n"
code += " " + "}\n"
code += " " + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n"
code += "\n"
for i in range(self.b2b_num):
code_this = ""
N_str = str(self.fuse_gemm_info[i]['mnk'][1])
code_this += " " + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
code_this += " " + " " + helper.var_idx("problem_size_", i) + ",\n"
ldmA = str(self.fuse_gemm_info[i]['mnk'][2])
if i == 0:
ldmA = "K0"
ldmB = str(self.fuse_gemm_info[i]['mnk'][2])
if i == 0:
ldmB = "K0"
ldmC = str(self.fuse_gemm_info[i]['mnk'][1])
ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))
if self.fuse_gemm_info[i]['A_format'] == 'Col':
ldmA = "M"
if self.fuse_gemm_info[i]['B_format'] == 'Row':
ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
if self.fuse_gemm_info[i]['C_format'] == 'Col':
ldmC = "M"
if i == 0:
code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"
else:
code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n"
code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"
M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0])
code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n"
code_this += " " + " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
arg_value = str(epilogue_arg[2])
code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_value) + ")"
code_this += " " + " },\n"
code_this += " " + " " + "B};\n"
code += code_this
code += " " + "TI(UNFUSED_CUTLASS);\n"
code += " " + "for(int i = 0; i < 100; i++){\n"
code += " " + " " + self.gen_class_name + "_verify(\n"
for i in range(self.b2b_num):
code += " " + " " + " " + helper.var_idx("arguments_", i) + ",\n"
code += " " + " " + " " + "NULL);\n"
code += " " + "}\n"
code += " " + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n"
code += " " + helper.var_idx("Mat_D_cutlass_ref", self.b2b_num - 1) + ".d2h();\n"
code += " " + helper.var_idx("Mat_D", self.b2b_num - 1) + ".d2h();\n"
code += " " + helper.var_idx("check_result(Mat_D_cutlass_ref", self.b2b_num - 1) + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) \
+ helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) + ".elements);\n"
code += "\n\n}\n"
with open(self.sample_dir + "sample.cu", "w+") as f:
f.write(code)

File diff suppressed because it is too large

View File

@ -0,0 +1,456 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import helper
import gen_ir as ir
class gen_turing_impl:
def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
self.fuse_gemm_info = fuse_gemm_info
self.class_name = gen_class_name
self.gen_class_name = gen_class_name + "_turing_impl"
self.user_header_file = ""
for header in user_header_file:
self.user_header_file += "#include \"" + header + "\"\n"
self.output_dir = output_dir
self.b2b_num = len(fuse_gemm_info)
self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
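# gen_wrapper() instantiates the fused b2b GEMM device type; when only a single GEMM is present it falls back to the unfused batched-GEMM implementation generated by gen_volta_turing_fuse_act_impl.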
def gen_using(self):
code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + "<cutlass::half_t>;"
return code_using + "\n"
def gen_initialize(self):
code = ""
for i in range(self.b2b_num):
code_this = ""
code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n"
beta = "(1)"
if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False:
beta = "(0)"
code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n"
k_str = str(self.fuse_gemm_info[i]['mnk'][2])
if i == 0:
k_str = "K0"
code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n"
code += code_this
code += "typename b2b_gemm::Arguments arguments{\n"
for i in range(self.b2b_num):
code += " " + helper.var_idx("problem_size_", i) + ",\n"
code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n"
for i in range(self.b2b_num):
ldmB = str(self.fuse_gemm_info[i]['mnk'][2])
if i == 0:
ldmB = "K0"
if self.fuse_gemm_info[i]['B_format'] == 'Row':
ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))
code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n"
code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n"
code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num -1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n"
for i in range(self.b2b_num):
code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")"
code += "},\n"
code += " " + "Batch};\n\n"
code += " " "b2b_gemm gemm_op;\n"
code += " " + "gemm_op.initialize(arguments);\n"
return code + "\n"
def gen_run(self):
code = " " + "gemm_op(stream);\n"
return code
def gen_wrapper(self):
code_body = ""
arg_lists = []
arg_lists.append(["int", "M"])
arg_lists.append(["int", "K0"])
arg_lists.append(["int", "Batch"])
arg_lists.append(["void*", helper.var_idx("A", 0)])
for i in range(self.b2b_num):
arg_lists.append(["void*", helper.var_idx("B", i)])
arg_lists.append(["void*", helper.var_idx("C", i)])
arg_lists.append(["void*", helper.var_idx("D", i)])
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
for arg in epilogue_args:
arg_tp = arg[0]
arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
arg_lists.append([arg_tp, arg_name])
if self.b2b_num == 1:
code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta
code_body += self.gen_turing_unfused.gen_initialize()
code_body += self.gen_turing_unfused.gen_run()
else:
code_body += self.gen_using()
code_body += self.gen_initialize()
code_body += self.gen_run()
code = ir.gen_func(self.gen_class_name, arg_lists, code_body)
return code
def gen_code(self):
code = self.gen_wrapper()
helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code)
class gen_volta_turing_fuse_act_impl:
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
self.fuse_gemm_info = fuse_gemm_info
self.gen_class_name = gen_class_name + "_volta_impl"
self.user_header_file = ""
for header in user_header_file:
self.user_header_file += "#include \"" + header + "\"\n"
self.output_dir = output_dir
self.b2b_num = len(fuse_gemm_info)
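# perf_tiling() picks a simple threadblock/warp tile: the M and K tiles are fixed at 32 and the N tiles are derived from the layer's N dimension.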
def perf_tiling(self, layer_mnk):
mnk = layer_mnk[:]
block_tile = mnk[:]
block_tile[2] = 32 # force the K tile to be 32
# M tile gen
block_tile[0] = 32
# N tile gen
if mnk[1] > 128:
block_tile[1] = 256
elif mnk[1] > 64:
block_tile[1] = 128
elif mnk[1] > 32:
block_tile[1] = 64
else :
block_tile[1] = 32
warp_tile = block_tile[:]
if block_tile[1] == 256:
warp_tile[1] = 64
elif block_tile[1] == 128:
warp_tile[1] = 32
elif block_tile[1] == 64:
warp_tile[1] = 32
else :
warp_tile[1] = 32
warp_tile[0] = 32
return block_tile, warp_tile
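# process_epilogue() maps the configured activation to a CUTLASS epilogue functor and selects the widest vector alignment (8/4/2/1 elements) compatible with N.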
def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp):
epilogue_setted_type = epilogue_tp
cutlass_epilogue_name = "LinearCombinationRelu"
if epilogue_setted_type.lower() == 'leakyrelu':
cutlass_epilogue_name = "LinearCombinationLeakyRelu"
elif epilogue_setted_type.lower() == 'identity':
cutlass_epilogue_name = "LinearCombination"
n_mod_8 = n % 8
N_align_elements = 1
if n_mod_8 == 0:
N_align_elements = 8
elif n_mod_8 == 4:
N_align_elements = 4
elif n_mod_8 == 2 or n_mod_8 == 6:
N_align_elements = 2
epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">"
return epilogue_str
def gen_using(self, volta = True):
code_using = ""
volta_arch = "cutlass::arch::Sm70"
volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>"
turing_arch = "cutlass::arch::Sm75"
turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>"
arch = ""
tc = ""
if volta:
arch = volta_arch
tc = volta_tc
else:
arch = turing_arch
tc = turing_tc
for i in range(self.b2b_num):
k = self.fuse_gemm_info[i]['mnk'][2]
k_mod_8 = k % 8
ab_ldm = 1
if k_mod_8 == 0:
ab_ldm = 8
elif k_mod_8 == 4:
ab_ldm = 4
elif k_mod_8 == 2 or k_mod_8 == 6:
ab_ldm = 2
block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk'])
this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n"
this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n"
this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n"
this_gemm_config += " " + arch + ",\n"
this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n"
this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n"
this_gemm_config += " " + tc + ",\n"
this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n"
this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n"
this_gemm_config += " " + "2,\n"
this_gemm_config += " " + str(ab_ldm) + ",\n"
this_gemm_config += " " + str(ab_ldm) + ">;\n"
code_using += this_gemm_config + "\n"
return code_using + "\n"
def gen_initialize(self):
code = ""
for i in range(self.b2b_num):
code_this = ""
N_str = str(self.fuse_gemm_info[i]['mnk'][1])
code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n"
beta = "(1)"
if helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) is False:
beta = "(0)"
code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n"
k_str = str(self.fuse_gemm_info[i]['mnk'][2])
if i == 0:
k_str = "K0"
code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n"
code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n"
code_this += " " + helper.var_idx("problem_size_", i) + ",\n"
ldmA = k_str
ldmB = k_str
ldmC = str(self.fuse_gemm_info[i]['mnk'][1])
ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i]))
if self.fuse_gemm_info[i]['A_format'] == 'Col':
ldmA = "M"
if self.fuse_gemm_info[i]['B_format'] == 'Row':
ldmB = str(self.fuse_gemm_info[i]['mnk'][1])
if self.fuse_gemm_info[i]['C_format'] == 'Col':
ldmC = "M"
if i == 0:
code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"
else:
code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n"
code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n"
M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0])
code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n"
code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n"
code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i)
for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]):
arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1]
code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")"
code_this += " },\n"
code_this += " " + "Batch};\n"
code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n"
code += code_this + "\n"
return code + "\n"
def gen_run(self):
code = ""
for i in range(self.b2b_num):
code_this = ""
code_this += " " + helper.var_idx("gemm_op_", i) + "(stream);\n"
code += code_this
return code
def gen_wrapper(self):
code_body = ""
arg_lists = []
arg_lists.append(["int", "M"])
arg_lists.append(["int", "K0"])
arg_lists.append(["int", "Batch"])
arg_lists.append(["void*", helper.var_idx("A", 0)])
for i in range(self.b2b_num):
arg_lists.append(["void*", helper.var_idx("B", i)])
arg_lists.append(["void*", helper.var_idx("C", i)])
arg_lists.append(["void*", helper.var_idx("D", i)])
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
for arg in epilogue_args:
arg_tp = arg[0]
arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
arg_lists.append([arg_tp, arg_name])
code_body += self.gen_using()
code_body += self.gen_initialize()
code_body += self.gen_run()
code = ir.gen_func(self.gen_class_name, arg_lists, code_body)
return code
def gen_code(self):
code = self.gen_wrapper()
helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code)
class gen_one_API:
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
self.fuse_gemm_info = fuse_gemm_info
self.gen_class_name = gen_class_name
self.user_header_file = ""
for header in user_header_file:
self.user_header_file += "#include \"" + header + "\"\n"
self.output_dir = output_dir
self.b2b_num = len(fuse_gemm_info)
self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
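# gen_code() produces three artifacts: cutlass_irrelevant.h (a CUTLASS-free Param struct plus the one_api declaration), api.h (the Volta and Turing wrappers), and one_api.cu, which dispatches at runtime on the SM version (70 -> Volta path, >= 75 -> Turing path, otherwise assert).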
def gen_CUTLASS_irrelevant_API(self):
code = ""
code += "#include <cuda_runtime.h>\n"
code += "#include <assert.h>\n"
param_name = "Fused" + str(self.b2b_num) + "xGemm_"
for i in range(self.b2b_num):
param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_"
param_name += "Params"
params = ""
params += " " + "int M;\n"
params += " " + "int K0;\n"
params += " " + "int Batch;\n"
params += " " + "const void* A0;\n"
for i in range(self.b2b_num):
params += " " + "const void* " + helper.var_idx("B", i) + ";\n"
params += " " + "const void* " + helper.var_idx("C", i) + ";\n"
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i])
for arg in epilogue_args:
arg_tp = arg[0]
arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
params += " " + arg_tp + " " + arg_name + ";\n"
params += " " + "void* " + helper.var_idx("D", i) + ";\n"
code += ir.gen_struct(param_name, params)
code += "using Param = " + param_name + ";\n"
code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n"
return code
def gen_one_api(self):
code = ""
code += "/* Auto Generated code - Do not edit.*/\n"
code += "#include \"cutlass_irrelevant.h\"\n"
code += "#include \"api.h\"\n"
code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n"
code += " " + "if (sm == 70) \n"
code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
for i in range(self.b2b_num):
code += helper.var_idx("const_cast<void*>(param.B", i) + "), "
code += helper.var_idx("const_cast<void*>(param.C", i) + "), "
code += helper.var_idx("param.D", i) + ", "
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
for arg in epilogue_args:
arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
code += "param." + arg_name + ", "
code += "stream);\n"
code += " " + "else if(sm >= 75) \n"
code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), "
for i in range(self.b2b_num):
code += helper.var_idx("const_cast<void*>(param.B", i) + "), "
code += helper.var_idx("const_cast<void*>(param.C", i) + "), "
code += helper.var_idx("param.D", i) + ", "
epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i])
for arg in epilogue_args:
arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1]
code += "param." + arg_name + ", "
code += "stream);\n"
code += " " + "else assert(0);\n"
code += "}\n"
return code
def gen_code(self):
turing_code = self.gen_turing.gen_wrapper()
volta_code = self.gen_volta.gen_wrapper()
cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API()
one_api_code = self.gen_one_api()
with open(self.output_dir + "one_api.cu", "w+") as f:
f.write(one_api_code)
helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code)
helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code)

View File

@ -0,0 +1,92 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import helper
import gen_ir as ir
import gen_turing_and_volta as gen_basic
class gen_verify:
def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"):
self.fuse_gemm_info = fuse_gemm_info
self.name = gen_class_name + "_verify"
self.b2b_num = len(fuse_gemm_info)
self.params = []
self.user_header_file = ""
for header in user_header_file:
self.user_header_file += "#include \"" + header + "\"\n"
self.seperate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir)
self.gen_params()
self.output_dir = output_dir
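# gen_code() writes cutlass_verify.h, which defines a <gen_class_name>_verify() function that runs each GEMM separately from caller-provided Gemm<i>::Arguments.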
def gen_code(self):
code = ""
code += self.user_header_file
code += self.seperate_cutlass.gen_using(False) #False -> Turing, True -> Volta
code_body = ""
for i in range(self.b2b_num):
code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n"
code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n"
code_body += self.seperate_cutlass.gen_run()
code += ir.gen_func(self.name, self.params, code_body)
helper.write_2_headfile("cutlass_verify.h", self.output_dir, code)
def gen_params(self):
for i in range(self.b2b_num):
self.params.append(
(
helper.var_idx("typename Gemm", i)+ "::Arguments",
helper.var_idx("Arguments_", i)
)
)
def get_params(self, declaration = True):
code = ""
if declaration:
for param in self.params:
code += param[0] + " " + param[1] + ";\n"
return code
def gen_initialize(self):
code = ""
initialize_code = self.seperate_cutlass.gen_initialize()
code = ir.gen_func("initialize", [[]])

View File

@ -0,0 +1,52 @@
#!/bin/bash
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
NUM_ARGS=3
if [ $# -ne $NUM_ARGS ]; then
echo "Usage: $0 <config_file> <output_directory> <cutlass_directory>"
echo " config_file: JSON file containing configuration to run"
echo " output_directory: directory to store results"
echo " cutlass_directory: directory containing cutlass source"
exit 1
fi
config_file=$1
output_dir=$2
cutlass_dir=$3
python3 gen_all_code.py \
--config-file $config_file \
--gen-name FusedMultiGemmForward \
--output-dir $output_dir \
--cutlass-dir $cutlass_dir

View File

@ -0,0 +1,135 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
def type_2_cutlass_type(input_type = "fp16"):
# float point type
if input_type == "fp32":
return "float"
if input_type == "bf16":
return "cutlass::bfloat16_t"
if input_type == "fp16":
return "cutlass::half_t"
# integer type
if(input_type == "int32"):
return "int32_t"
if(input_type == "int8"):
return "int8_t"
if input_type == 'Row':
return 'cutlass::layout::RowMajor'
if input_type == 'Col':
return 'cutlass::layout::ColumnMajor'
def cvt_2_cutlass_shape(gemm_shape):
# gemm shape
if len(gemm_shape) == 3:
val = "cutlass::gemm::GemmShape<" \
+ str(gemm_shape[0]) + ", " \
+ str(gemm_shape[1]) + ", " \
+ str(gemm_shape[2]) + ">"
return val
def write_2_headfile(filename, file_dir, string):
with open(file_dir + filename, 'w') as f:
f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
def var_idx(variable, index):
return variable + str(index)
def list_2_string(input_list):
rtn_string = ""
cnt = 0
for element in input_list:
final = ", \n"
if cnt == len(input_list) - 1:
final = "\n"
cnt += 1
rtn_string += str(element) + final
return rtn_string
def get_epilouge_info(layer_info):
return layer_info['epilogue']
def get_epilogue_tp(layer_info):
epilogue_info = get_epilouge_info(layer_info)
return epilogue_info['tp']
def get_epilogue_add_bias_or_not(layer_info):
epilogue_info = get_epilouge_info(layer_info)
return epilogue_info['bias']['addbias']
def get_epilogue_add_bias_tp(layer_info):
epilogue_info = get_epilouge_info(layer_info)
return epilogue_info['bias']['bias_tp']
def get_epilogue_args(layer_info):
epilogue_info = get_epilouge_info(layer_info)
return epilogue_info['args']
def get_epilogue_bias_shape(layer_info):
bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
mn_shape = layer_info['mnk'][:-1]
if bias_tp == 'mat':
mn_shape[0] = 'M'
return mn_shape
elif bias_tp == 'vec':
mn_shape[0] = 1
return mn_shape
else:
assert(0)
def get_epilogue_bias_ldm(layer_info):
bias_tp = get_epilogue_add_bias_tp(layer_info).lower()
mn_shape = layer_info['mnk'][:-1]
c_layout = layer_info['C_format'].lower()
if c_layout != 'row':
assert(0)
if bias_tp == 'mat':
return mn_shape[1]
elif bias_tp == 'vec':
return 0
else:
assert(0)
def get_epilogue_compute_tp(layer_info):
return layer_info['Acc_tp']
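# Usage sketch (illustrative only, not part of the generator): each entry of
# fuse_gemm_info is a plain dict; the field values below are hypothetical, but
# the keys match the accessors defined above.
#
#   layer = {
#       'mnk': [64, 128, 64],
#       'A_tp': 'fp16', 'B_tp': 'fp16', 'C_tp': 'fp16', 'Acc_tp': 'fp32',
#       'A_format': 'Row', 'B_format': 'Col', 'C_format': 'Row',
#       'epilogue': {'tp': 'leakyrelu',
#                    'bias': {'addbias': True, 'bias_tp': 'vec'},
#                    'args': [['float', 'leaky_alpha', 0.01]]},
#   }
#   type_2_cutlass_type(layer['A_tp'])     # -> 'cutlass::half_t'
#   var_idx('problem_size_', 0)            # -> 'problem_size_0'
#   get_epilogue_bias_shape(layer)         # -> [1, 128]
#   get_epilogue_bias_ldm(layer)           # -> 0 (vector bias, broadcast along M)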

View File

@ -0,0 +1,67 @@
#################################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import os
class replace_fix_impl:
def __init__(self, src_dir, dst_dir, cutlass_deps_root):
self.src_dir = src_dir
self.dst_dir = dst_dir
self.cutlass_deps_root = cutlass_deps_root
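# gen_code() mirrors src_dir into dst_dir, rewriting every '#include "cutlass...' line so that the copied files resolve CUTLASS headers from the configured cutlass_deps_root.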
def gen_code(self):
for sub_dir in os.walk(self.src_dir):
files_in_sub_dir = sub_dir[2]
src_dirs = sub_dir[0]
output_dirs = self.dst_dir + sub_dir[0][len(self.src_dir):]
if not os.path.exists(output_dirs):
os.mkdir(output_dirs)
for f in files_in_sub_dir:
with open(src_dirs +"/" + f, 'r') as current_file:
output_lines = []
lines = current_file.readlines()
for line in lines:
if(len(line) >= len("#include \"cutlass") and line[:len("#include \"cutlass")] == "#include \"cutlass"):
new_line = "#include \"" + self.cutlass_deps_root + line[len("#include \""):]
# print(new_line)
output_lines.append(new_line)
else:
output_lines.append(line)
with open(output_dirs + "/" + f, "w+") as dest_file:
dest_file.writelines(output_lines)

View File

@ -0,0 +1,292 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cuda_fp16.h>
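// Epilogue helper kernels included by the generated samples: an optional bias add followed by leaky ReLU, ReLU, or identity, vectorized with half2 whenever N is even.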
template <typename T>
__device__
T add(T const & a, T const &b){
return (a + b);
}
template <>
__device__
half2 add(half2 const & a, half2 const &b){
return (__hadd2(a,b));
}
template <typename T>
struct RELU{
__device__
T operator()(T const & a){
return a > T(0) ? a : T(0);
}
__device__
half2 operator()(half2 const & a){
float2 a_fp32x2 = __half22float2(a);
a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f;
a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f;
if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f)
printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y);
return __float22half2_rn(a_fp32x2);
}
};
template <typename T>
struct LEAKY_RELU{
__device__
T operator()(T const & a, T const & scale = half(1)){
return a > T(0) ? a : scale * a;
}
__device__
half2 operator()(half2 const & a, half const & scale = half(1)){
half2 zero = __half2half2(half(0));
half2 gt_zero = __hge2(a, zero);
half2 le_zero = __hle2(a, zero);
half2 scale_f16x2 = __half2half2(scale);
half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero);
return __hmul2(a, mask_scale_f16x2);
}
};
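// Launch shape used by the host-side wrappers below: gridDim.x = rows (M), gridDim.y = batch, and each thread processes 'iter' vectorized chunks of one N-wide row.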
template <int N, int BLOCKDIM>
__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){
constexpr bool N_MOD_2 = N & 1 ? false : true;
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
LEAKY_RELU<half> Act;
Access_tp src_v[iter];
Access_tp bias_v[iter];
int batch_id = blockIdx.y;
int batch_offset = batch_id * gridDim.x * N;
for(int i = 0; i < iter; i++){
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
if (idx < N){
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
if (mat_bias)
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
else
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale);
}
}
}
template <int N, int BLOCKDIM>
__global__ void leaky_and_activation(half* inout, half scale){
constexpr bool N_MOD_2 = N & 1 ? false : true;
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
int batch_id = blockIdx.y;
int batch_offset = batch_id * gridDim.x * N;
LEAKY_RELU<half> Act;
Access_tp src_v[iter];
for(int i = 0; i < iter; i++){
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
if (idx < N){
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale);
}
}
}
template <int N, int BLOCKDIM>
void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){
dim3 grid(m, b);
if (bias == nullptr)
leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, scale);
else
leaky_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, scale, mat_bias);
}
template <int N, int BLOCKDIM>
__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){
constexpr bool N_MOD_2 = N & 1 ? false : true;
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
RELU<half> Act;
Access_tp src_v[iter];
Access_tp bias_v[iter];
int batch_id = blockIdx.y;
int batch_offset = batch_id * gridDim.x * N;
for(int i = 0; i < iter; i++){
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
if (idx < N){
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
if (mat_bias)
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
else
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]));
}
}
}
template <int N, int BLOCKDIM>
__global__ void relu_and_activation(half* inout){
constexpr bool N_MOD_2 = N & 1 ? false : true;
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
int batch_id = blockIdx.y;
int batch_offset = batch_id * gridDim.x * N;
RELU<half> Act;
Access_tp src_v[iter];
for(int i = 0; i < iter; i++){
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
if (idx < N){
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]);
}
}
}
template <int N, int BLOCKDIM>
void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
dim3 grid(m, b);
if (bias == nullptr)
relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
else
relu_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
}
template <int N, int BLOCKDIM>
__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){
constexpr bool N_MOD_2 = N & 1 ? false : true;
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
int batch_id = blockIdx.y;
int batch_offset = batch_id * gridDim.x * N;
Access_tp src_v[iter];
Access_tp bias_v[iter];
for(int i = 0; i < iter; i++){
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
if (idx < N){
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
if (mat_bias)
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + blockIdx.x * N + idx + batch_offset);
else
bias_v[i] = *reinterpret_cast<Access_tp*>(bias + idx + batch_id * N);
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i]));
}
}
}
template <int N, int BLOCKDIM>
__global__ void identity_and_activation(half* inout){
constexpr bool N_MOD_2 = N & 1 ? false : true;
using Access_tp = typename std::conditional<N_MOD_2, half2, half>::type;
constexpr int Access_elements = sizeof(Access_tp) / sizeof(half);
constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements);
int batch_id = blockIdx.y;
int batch_offset = batch_id * gridDim.x * N;
Access_tp src_v[iter];
for(int i = 0; i < iter; i++){
int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements;
if (idx < N){
src_v[i] = *reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset);
*reinterpret_cast<Access_tp*>(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]);
}
}
}
template <int N, int BLOCKDIM>
void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){
dim3 grid(m, b);
if (bias == nullptr)
identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout);
else
identity_and_activation<N, BLOCKDIM><<<grid , BLOCKDIM>>>(inout, bias, mat_bias);
}

View File

@ -0,0 +1,94 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
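// TI/TO time a code region with CUDA events: TI(tag) records the start, TO(tag, str, times) records the end, prints the average time per iteration, and reports the last CUDA error.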
#define TI(tag) \
cudaEvent_t _event_start_ ##tag; \
cudaEvent_t _event_end_ ##tag; \
float _event_time_ ##tag; \
cudaEventCreate(& _event_start_ ##tag); \
cudaEventCreate(& _event_end_ ##tag); \
cudaEventRecord(_event_start_ ##tag);
#define TO(tag, str, times) \
cudaEventRecord(_event_end_ ##tag); \
cudaEventSynchronize(_event_end_ ##tag); \
cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \
float _event_time_once_ ##tag = _event_time_ ##tag / times; \
printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \
cudaDeviceSynchronize(); \
printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError()));
template<typename T>
struct memory_unit{
T* host_ptr;
T* device_ptr;
int size_bytes;
int elements;
void h2d(){
cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice);
}
void d2h(){
cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost);
}
void free_all(){
free(host_ptr);
cudaFree(device_ptr);
}
memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){
host_ptr = (T*) malloc(elements_ * sizeof(T));
cudaMalloc((void**)&device_ptr, elements_ * sizeof(T));
}
void init(int abs_range = 1){
for(int i = 0; i < elements; i++){
host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range);
}
h2d();
}
};
template<typename T>
int check_result(T * a, T * b, int N){
int cnt = 0;
for(int i = 0; i < N; i ++){
float std = float(a[i]);
float my = float(b[i]);
if(abs(std - my) / abs(std) > 1e-2)
{
// printf("my: %f , std: %f\n", my, std);
cnt++;
}
}
printf("total err: %d / %d\n", cnt, N);
return cnt;
}

View File

@ -30,7 +30,7 @@
cutlass_example_add_executable(
43_dual_gemm
45_dual_gemm
dual_gemm.cu
)

View File

@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
46_depthwise_simt_conv2dfprop
depthwise_simt_conv2dfprop.cu
)


@ -0,0 +1,672 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run depthwise 2d convolution kernels using functions and data structures
provided by CUTLASS, using SIMT instructions.

There are three implementations of depthwise 2d convolution:

1. kAnalytic
   The implicit GEMM 2d convolution algorithm.

2. kOptimized
   An optimized algorithm that supports arbitrary stride and dilation.

3. kFixedStrideDilation
   An optimized algorithm that bakes the stride and dilation into the template parameters to reduce
   runtime computation and enable further optimizations.

In general, kFixedStrideDilation performs better than kOptimized. However, if the filter size,
stride, or dilation is large, it may run into register spilling, which hurts performance; in that
case, use kOptimized instead.

For both kOptimized and kFixedStrideDilation, enable split-K when the output tensor is large in
order to fully utilize the GPU hardware resources and achieve better performance.

This example demonstrates how to construct and run a kFixedStrideDilation depthwise 2d
convolution kernel.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/conv/device/direct_convolution.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input tensors, the output tensor, and the
// computation between elements
using ElementAccumulator = cutlass::half_t; // Data type of accumulator
using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassSimt;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm60;
// This code section describes the number of groups a threadblock will compute
constexpr int groups_per_cta = 64;
// This code section describes the output tile <N, O, P, Q> a thread block will compute
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;
// This code section describes the filter shape <R, S>
using FilterShape = cutlass::MatrixShape<3, 3>;
// Threadblock tile shape
using ThreadblockShape =
cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;
// This code section describes the tile size a warp will compute
// WarpShape::kM = the number of output elements (P * Q) a warp processes
// WarpShape::kN = the number of groups (groups_per_cta) a warp processes
// WarpShape::kK = the filter size (R * S) a warp processes
using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;
// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock =
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>;
// Number of pipelines you want to use
constexpr int NumStages = 4;
// This code section describes the selected iterator algorithm: kFixedStrideDilation
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kFixedStrideDilation;
using StrideShape = cutlass::MatrixShape<1, 1>;
using DilationShape = cutlass::MatrixShape<1, 1>;
constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;
// This code section describes the epilogue part of the kernel; we use the default values
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
kEpilogueElementsPerAccess, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue, // Data type for alpha/beta in linear combination
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation.
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kFixed,
StrideShape,
DilationShape>::Kernel;
using Direct2dConv = cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::Tensor4DCoord input_size;
cutlass::Tensor4DCoord filter_size;
cutlass::Tensor4DCoord padding;
cutlass::MatrixCoord conv_stride;
cutlass::MatrixCoord dilation;
int groups;
int splitk;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
std::string tag;
Options()
: help(false),
input_size(1, 128, 128, 32),
filter_size(32, 3, 3, 1),
groups(32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
reference_check(false),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(0),
splitk(1) {}
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
bool valid() {
//
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// depthwise conv
if (groups != input_size.c()) {
return false;
}
if (filter_size.n() != groups) {
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
cmd.get_cmd_line_argument("g", groups);
filter_size.c() = 1;
filter_size.n() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("splitk", splitk);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
int32_t padding_h = filter_size.h() / 2;
int32_t padding_w = filter_size.w() / 2;
padding = {padding_h, padding_h, padding_w, padding_w};
}
/// Prints the usage statement.
std::ostream &print_usage(std::ostream &out) const {
out << "41_depthwise_gemm_fprop example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --g=<int> Groups\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --splitk=<int> Enable splitK\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results "
"table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 "
"--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n"
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 "
"--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
cutlass::Tensor4DCoord output_size() const {
return cutlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas =
output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cutlass::Status reference_check;
cudaError_t error;
Result()
: runtime_ms(0),
gflops(0),
status(cutlass::Status::kSuccess),
reference_check(cutlass::Status::kInvalid),
error(cudaSuccess) {}
static std::ostream &print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs";
return out;
}
std::ostream &print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
cutlass::Tensor4DCoord output_size = options.output_size();
out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << ","
<< options.input_size.w() << "," << options.input_size.c() << ","
<< options.filter_size.n() << "," << options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column()
<< ","
<< options.dilation.row() << "," << options.dilation.column() << ","
<< options.splitk << ","
<< runtime_ms << "," << gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one testcase
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the CUTLASS Utilities.
//
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b_transpose(options.filter_size);
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0);
// Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0);
// Fill tensor C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0);
// Fill tensor D on host with zeros
cutlass::reference::host::TensorFill(tensor_d.host_view());
// Fill tensor D for reference on host with zeros
cutlass::reference::host::TensorFill(tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_b_transpose.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for CUTLASS Convolution
//
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
// Split P*Q across multiple CTAs
int split_k_slices = options.splitk;
// Construct Conv2dProblemSize with user defined output size
cutlass::conv::Conv2dProblemSize problem_size(options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices,
options.groups);
// Construct Direct2dConv::Arguments structure with the conv2d
// problem size, data pointers, and epilogue values
typename Direct2dConv::Arguments arguments{problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
{options.alpha, options.beta},
tensor_b_transpose.device_ref()};
//
// Initialize CUTLASS Convolution
//
Direct2dConv implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(result.status);
//
// Launch initialized CUTLASS kernel
//
result.status = implicit_gemm_op();
CUTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
// Compute with reference implementation
cutlass::reference::host::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue> >(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha,
options.beta);
// Check if output from CUTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
bool passed =
cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view());
if (!passed) {
result.reference_check = cutlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
} else {
result.reference_check = cutlass::Status::kSuccess;
std::cout << "Passed.\n";
}
} else {
result.reference_check = cutlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
<< "x" << options.input_size.w() << "x" << options.input_size.c() << "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x"
<< options.filter_size.w() << "x" << options.filter_size.c() << ".dat";
std::ofstream output_workspace(ss.str());
output_workspace << "Input = \n"
<< tensor_a.host_view() << "\n\n"
<< "Filters = \n"
<< tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
cudaEvent_t events[2];
for (auto &event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
CUTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error)
<< std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
bool notSupported = false;
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major >= 6)) {
std::cerr << "Run on a machine with compute capability at least 60." << std::endl;
notSupported = true;
}
if (notSupported) {
return 0;
}
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
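A small sketch (not part of the example above) of how the same Options / profile_convolution machinery could sweep several depthwise problem sizes in one run; the sizes and the function name run_depthwise_sweep are illustrative assumptions.

int run_depthwise_sweep() {
  // Illustrative problem sizes: {N, H, W, C} inputs and {K, R, S, 1} filters, with K == C for depthwise.
  cutlass::Tensor4DCoord inputs[]  = {cutlass::Tensor4DCoord(1, 56, 56, 64),
                                      cutlass::Tensor4DCoord(1, 112, 112, 128)};
  cutlass::Tensor4DCoord filters[] = {cutlass::Tensor4DCoord(64, 3, 3, 1),
                                      cutlass::Tensor4DCoord(128, 3, 3, 1)};

  Options options;
  bool printed_header = false;

  for (int i = 0; i < 2; ++i) {
    options.update(inputs[i], filters[i]);
    options.groups = inputs[i].c();   // depthwise convolution: one filter per input channel

    if (!options.valid()) {
      continue;                       // skip sizes the kernel configuration cannot handle
    }

    Result result = profile_convolution(options);

    if (!printed_header) {
      Result::print_header(std::cout, options) << std::endl;
      printed_header = true;
    }
    result.print(std::cout, i, options) << std::endl;
  }
  return 0;
}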


@ -119,9 +119,11 @@ foreach(EXAMPLE
37_gemm_layernorm_gemm_fusion
38_syr2k_grouped
39_gemm_permute
41_multi_head_attention
42_fused_multi_head_attention
43_dual_gemm
41_fused_multi_head_attention
42_ampere_tensorop_group_conv
43_ell_block_sparse_gemm
45_dual_gemm
46_depthwise_simt_conv2dfprop
)
add_subdirectory(${EXAMPLE})


@ -80,9 +80,9 @@ public:
typedef value_type *pointer;
typedef value_type const * const_pointer;
using ArrayType = Array<T, N>;
using reference = typename ArrayType::reference;
using const_reference = typename ArrayType::const_reference;
using Array = Array<T, N>;
using reference = typename Array::reference;
using const_reference = typename Array::const_reference;
public:


@ -85,6 +85,10 @@ struct Sm86 {
static int const kMinComputeCapability = 86;
};
struct Sm90 {
static int const kMinComputeCapability = 90;
};
/// Triggers a breakpoint on the device
CUTLASS_DEVICE
void device_breakpoint() {


@ -451,7 +451,7 @@ template <>
CUTLASS_DEVICE
void shared_store<16>(uint32_t ptr, void const *src) {
uint4 const *dst_u128 = reinterpret_cast<uint4 const *>(src);
asm volatile("ld.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
: :
"r"(ptr),
"r"(dst_u128->x),


@ -223,4 +223,6 @@ struct SparseMma;
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sparse_sm80.h"
#include "cutlass/arch/mma_sm90.h"
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -1065,7 +1065,7 @@ struct Mma<
int const *C = reinterpret_cast<int const *>(&c);
int *D = reinterpret_cast<int *>(&d);
asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
asm volatile("_mma.m8n8k32.row.col.u4.s4.sat {%0,%1}, %2, %3, {%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
@ -1247,7 +1247,8 @@ struct Mma<
) const {
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED))
using WmmaFragmentA = nvcuda::wmma::fragment<
nvcuda::wmma::matrix_a,
Shape::kM,
@ -1279,6 +1280,7 @@ struct Mma<
nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
#else
CUTLASS_UNUSED(a);
@ -1289,14 +1291,7 @@ struct Mma<
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
};


@ -2156,6 +2156,7 @@ struct Mma<
int const *C = reinterpret_cast<int const *>(&c);
int *D = reinterpret_cast<int *>(&d);
asm volatile(
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
"{%4,%5,%6,%7}, "


@ -0,0 +1,131 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Matrix multiply
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define CUTLASS_ARCH_MMA_SM90_ENABLED
#endif
#endif
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////
/// Matrix Multiply-Add 16x8x4 fp64
////////////////////////////////////////////////////////////////////////////////
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
template <>
struct Mma<
gemm::GemmShape<16,8,4>,
32,
double,
layout::RowMajor,
double,
layout::ColumnMajor,
double,
layout::RowMajor,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<16,8,4>;
using ElementA = double;
using LayoutA = layout::RowMajor;
using FragmentA = Array<double, 2>;
using ElementB = double;
using LayoutB = layout::ColumnMajor;
using FragmentB = Array<double, 1>;
using ElementC = double;
using LayoutC = layout::RowMajor;
using FragmentC = Array<double, 4>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm90;
CUTLASS_HOST_DEVICE
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
FragmentC const &c) const {
#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED)
double const *A = reinterpret_cast<double const *>(&a);
double const *B = reinterpret_cast<double const *>(&b);
double const *C = reinterpret_cast<double const *>(&c);
double *D = reinterpret_cast<double *>(&d);
asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
: "d"(A[0]), "d"(A[1]),
"d"(B[0]),
"d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
#else
CUTLASS_UNUSED(d);
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
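As a hedged sketch (not taken from the CUTLASS sources), the new SM90 specialization above could be exercised directly from a warp roughly as follows; the kernel name and the way the fragments are seeded are illustrative only, since real code must follow the per-thread operand mapping that the mma.m16n8k4 PTX instruction prescribes.

#include "cutlass/arch/mma_sm90.h"
#include "cutlass/gemm/gemm.h"

__global__ void mma_f64_demo(double const *A, double const *B, double *D) {
  using MmaOp = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<16, 8, 4>, 32,
      double, cutlass::layout::RowMajor,
      double, cutlass::layout::ColumnMajor,
      double, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>;

  int lane = threadIdx.x % 32;

  MmaOp::FragmentA frag_A;   // 2 doubles held by each thread
  MmaOp::FragmentB frag_B;   // 1 double held by each thread
  MmaOp::FragmentC frag_C;   // 4 accumulator doubles held by each thread

  // Illustrative seeding only; it does not reproduce the real thread-data layout.
  frag_A[0] = A[lane * 2];
  frag_A[1] = A[lane * 2 + 1];
  frag_B[0] = B[lane];
  frag_C.clear();

  MmaOp mma;
  mma(frag_C, frag_A, frag_B, frag_C);

  for (int i = 0; i < 4; ++i) {
    D[lane * 4 + i] = frag_C[i];
  }
}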

File diff suppressed because it is too large.

include/cutlass/barrier.h (new file, 201 lines)

@ -0,0 +1,201 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide barrier for inter-CTA synchronization.
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
struct Barrier
{
public:
/// Flag type
using T = int;
/// Initial flag value
static const T INIT = 0;
protected:
/// Load flag, as a strong operation (int specialization)
CUTLASS_DEVICE
static int ld_strong(int *ptr)
{
int state = 0;
#if (__CUDA_ARCH__ >= 700)
/// SM70 and newer use memory consistency qualifiers
asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#else
asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif // (__CUDA_ARCH__ >= 700)
return state;
}
/// Store flag, as a strong operation (int specialization)
CUTLASS_DEVICE
static void st_strong(int *ptr, int val)
{
#if (__CUDA_ARCH__ >= 700)
/// SM70 and newer use memory consistency qualifiers
asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#endif // (__CUDA_ARCH__ >= 700)
}
/// Reduce into flag, with release pattern (int specialization)
CUTLASS_DEVICE
static void red_release(int *ptr, int val)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
#if (__CUDA_ARCH__ >= 700)
/// SM70 and newer use memory consistency qualifiers
asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
__threadfence();
atomicAdd(ptr, val);
#endif // (__CUDA_ARCH__ >= 700)
#endif
}
public:
/// Uses thread[0] to wait for at least the specified count of signals on the given flag counter
CUTLASS_DEVICE
static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
if (thread_idx == 0)
{
// Spin-loop
#pragma unroll 1
while(ld_strong(flag_ptr) < count) {}
}
__syncthreads();
#endif
}
/// Uses thread[0] to wait until the given flag counter equals the specified value
CUTLASS_DEVICE
static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
if (thread_idx == 0)
{
// Spin-loop
#pragma unroll 1
while(ld_strong(flag_ptr) != val) {}
}
__syncthreads();
#endif
}
/// Uses thread[0] to wait until the given flag counter equals the specified value, then atomically resets it to zero
CUTLASS_DEVICE
static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
if (thread_idx == 0)
{
// Spin-loop
#pragma unroll 1
while(atomicCAS(flag_ptr, val, 0) != val) {}
}
__syncthreads();
#endif
}
/// Increment the arrival count for a flag
CUTLASS_DEVICE
static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
__syncthreads();
if (thread_idx == 0) {
red_release(flag_ptr, 1);
}
#endif
}
/// Increment the arrival counts for a range of flags
CUTLASS_DEVICE
static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
int flag_idx = first_flag_idx + thread_idx;
T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
// Barrier to make sure all other threads in block have written their data
__syncthreads();
// Selected threads increment their flags
if (thread_idx < count) {
red_release(flag_ptr, 1);
}
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
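As a rough usage sketch (assumed, not taken from the CUTLASS kernels that use this header), the flags managed by Barrier can coordinate producer/consumer work between threadblocks: every CTA signals its arrival on a zero-initialized flag in global memory, and one CTA waits until all peers have arrived before consuming their results. The kernel below is illustrative; production code must also reason carefully about the memory ordering of the payload itself.

#include "cutlass/barrier.h"

__global__ void sum_partials(int *partials, int *flags, int *out, int num_ctas) {
  int cta = blockIdx.x;
  int tid = threadIdx.x;

  // Phase 1: each CTA publishes a partial result (an arbitrary payload for this sketch).
  if (tid == 0) {
    partials[cta] = cta + 1;
  }

  // Signal arrival on flag 0; arrive_inc() synchronizes the CTA before releasing the flag.
  cutlass::Barrier::arrive_inc(flags, tid, /*flag_idx=*/0);

  // Phase 2: CTA 0 waits until every CTA has arrived, then reduces the partials.
  if (cta == 0) {
    cutlass::Barrier::wait_eq(flags, tid, /*flag_idx=*/0, /*val=*/num_ctas);
    if (tid == 0) {
      int sum = 0;
      for (int i = 0; i < num_ctas; ++i) {
        sum += partials[i];
      }
      *out = sum;
    }
  }
}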

Some files were not shown because too many files have changed in this diff.