Compare commits

...

138 Commits

Author SHA1 Message Date
bbe579a9e3 Updates for CUTLASS 3.4.1 (#1346)
* Updates for CUTLASS 3.4.1

* minor epi change
2024-02-15 15:48:34 -05:00
47a3ebbea9 Add a missing platform include (#1328) 2024-02-03 01:30:32 -05:00
57e01e1a6b Fix missing include file (#1318) 2024-02-03 01:29:32 -05:00
6e3df975a2 Modify comments in code examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu (#1325) 2024-01-31 21:41:30 -05:00
8825fbf1ef fix unrecognized print format specifier for int8/uint8 (#1303)
* fix unrecognized print format specifier for int8/uint8

* use c++ static_cast instead of c cast style
2024-01-29 21:22:40 -05:00
092f14db05 fix tile_size_mnk compilation warning (#1294) 2024-01-29 21:21:15 -05:00
9385141f19 Update PUBLICATIONS.md
ptq paper from goog
2024-01-19 14:17:55 -05:00
b4b5b11070 Update PUBLICATIONS.md
add odyssey llm paper from metuan
2024-01-18 10:30:21 -05:00
139b93db61 update publications (#1308) 2024-01-17 14:06:46 -05:00
ca37d632c9 Remove sparse GEMM with row broadcasted bias vector (#1302)
This reverts commit d3e72719b4.

Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs>
2024-01-17 14:06:27 -05:00
362abbf274 Support ElementD to be void for tma (#1153)
* Support void D with AuxStore

* refine get_element_aux
2024-01-16 18:15:42 -05:00
751eb9a885 Update license year (#1306) 2024-01-16 14:37:22 -05:00
2f589ffa76 Updates for 3.4 release. (#1305) 2024-01-16 13:42:51 -05:00
acba5beee5 Fix flops calculation and tensor b stride calculation in the example 36 (#1278)
* Fix flops calculation and tensor b stride calculation in the example 36

* Fix datatype

* Update gather_scatter_fusion.cu
2024-01-08 17:27:30 -05:00
74d1f3e63a Fix cute::array<T, 0> iterator (#1273) 2024-01-08 17:10:09 -05:00
8ac2edc810 expose stream API in python kernel call interfaces (#1287)
* expose stream API in python kernel call interfaces

* add stream to ReductionArguments; document stream arg

* add stream argument to GemmGroupedArguments
2024-01-05 08:27:45 -05:00
d4be5ab5d7 Allow per-column bias in EpilogueTensorBroadcast (#1275)
* Allow per-column bias in EpilogueTensorBroadcast

EpilogueTensorBroadcast only supports per-row vector broadcast, because
the bias stride is hardcoded.

It can easily support both if the bias stride is made conditional, and
the original behavior is maintained by defaulting to per-row.

* Add unit test for EpilogueTensorBroadcast with per-col bias

---------

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Ali Hassani <ali@hippoml.com>
2024-01-04 12:48:31 -05:00
c9591a694d fix typo (#1279) 2024-01-04 12:41:39 -05:00
5c756eb774 Add support for sparse GEMM with visitor epilogue (#1189)
* Add support for sparse GEMM with visitor epilogue

* Refactor changes at the kernel level
2024-01-04 12:38:11 -05:00
8236f30675 CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
2023-12-29 15:21:31 -05:00
b7508e3379 Fix inline ptx escaping for predicates. (#1264)
* Fix inline ptx escaping for predicates.

Prevents `error: invalid % escape in inline assembly string` when compiling with clang.

* More double-quoting.
2023-12-14 11:16:15 -05:00
f60786b536 Remove undefined behavior from default constructor of PredicatedTileAccessIteratorParams. (#1258)
Currently, the default constructor of
`PredicatedTileAccessIteratorParams` will invoke undefined behavior in
its invocation of the `initialize` function. Specifically, it will
attempt to read from the uninitialized variables
`desc.element_size_bits` and `desc.advance_rank`. This commit changes
the default constructors of both `*Params` and `*Desc` to
zero-initialize all uninitialized members.
2023-12-11 23:01:53 -05:00
30ec1a4649 Use size_t index to iterate up to std::vector::size() (#1251)
Fixes a different signedness compare warning.
2023-12-09 08:44:31 -05:00
e1483d5fa0 Collection of changes to fix clang build. (#1200)
* Remove unused variables

* Qualify calls to make_fragment_? from templated base class.

Fixes clang build error.

* Add missing `#include <cstdio>`

* Various changes to fix clang compile errors.

* More changes to fix clang build.

Remaining issues:

- `params` initializer of `CollectiveEpilogue`.
- `ops` initializer of `Sm90VisitorImplBase`.
- `__usAtomicCAS` needs to be added to clang upstream.

* Fix remaining clang build issues.

* Qualify `cute::rank()` calls.

* Qualify some more calls that are otherwise ambiguous between `cute` and `std` namespace.

* Double-escape special registers in inline asm.

* small change

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-12-08 14:42:12 -05:00
f4a0216601 Fix bug in single source GEMM with residual + streamk (#1249)
Followup to #1224.

A change in the stream-k threadblock swizzle ctor since 3.3 breaks
single source GEMM with fused epilogue and stream-k. Multi-source was
already corrected.

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
2023-12-07 11:12:02 -05:00
f188f9b709 Fix typo in quickstart.md (#1257) 2023-12-07 09:49:52 -05:00
9c9b51d35c Update PUBLICATIONS.md 2023-12-07 00:02:36 -05:00
a75b4ac483 Fix Stream-K reduce bug in epilogue with broadcast (#1224)
Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
2023-12-05 15:35:41 -05:00
e9e30c2304 Updates and Bug fixes to CUTLASS 3.3 (#1232) 2023-12-05 09:50:49 -05:00
4a1709e17e Fixed illegal PTX syntax (#1225) 2023-12-01 12:29:48 -05:00
bef1fbcbe6 Add missing #include <cstdio> (#1197)
* Add missing `#include <cstdio>`

* move to non nvrtc part

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-12-01 11:58:53 -05:00
2375a07d01 Qualify calls to make_fragment_? from templated base class. (#1196)
Fixes clang build error.
2023-12-01 09:52:57 -05:00
60c8251b72 Remove unused variables (#1195) 2023-12-01 09:52:19 -05:00
10b850f9c7 Fix some sign conversion warnings (#1172)
* Fix sign conversion warnings

* Fix type conversion warnings

* Fix sign conversion warnings

* Change smem_size_ to constexpr

* clang warnings

* undo cast change

* one miss change

* missing part

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-11-30 00:28:40 -05:00
99c4eebe3b Explicitly cast blockIdx to uint3 (#1192)
This works around a clang issue where blockIdx is of a different type.
2023-11-30 00:26:23 -05:00
a759e85f5f Add subclass declarations to generated files. (#1193) 2023-11-30 00:25:40 -05:00
56fc3df03b Adding missing typename (#1191)
Fixes clang build failures.
2023-11-29 00:20:20 -05:00
eb01d5449d fix cp.async L2 prefetch typo (#1187) 2023-11-28 16:58:04 -05:00
8098336d51 Updates to Python interface for PyPI packaging (#1209)
* Updates

* Updates to notebooks
2023-11-28 13:52:12 -05:00
b5d8a5d9cc Allow SM90 pingpong kernel to use custom tile schedulers (#1194)
Co-authored-by: Sergey Klevtsov <sklevtsov@nvidia.com>
2023-11-15 13:45:17 -05:00
6e60b9b17c enable L2::128B prefetch for cp.async by default (#1177) 2023-11-13 13:30:13 -05:00
1ab6cc7b68 Fix std::abs overloading for bfloat16_t (#1179) 2023-11-13 13:29:45 -05:00
5ae8133cfa Doc only change changelog 3.3 (#1180) 2023-11-13 13:29:22 -05:00
39c6a83f23 fix missing return warning (#1173) 2023-11-03 22:42:59 -04:00
1d7f2a207e Fix several broken links (#1168)
Co-authored-by: isaacw <isaacw@nvidia.com>
2023-11-03 00:01:25 -04:00
557be3ab0e Fix several typos (#1169)
Co-authored-by: isaacw <isaacw@nvidia.com>
2023-11-02 23:54:46 -04:00
c008b4aea8 CUTLASS 3.3.0 (#1167)
* Release 3.3.0

Adds support for mixed precision GEMMs On Hopper and Ampere
Adds support for < 16B aligned GEMMs on Hopper
Enhancements to EVT
Enhancements to Python interface
Enhancements to Sub-byte type handling in CuTe
Several other bug-fixes and performance improvements.

* minor doc update
2023-11-02 11:09:05 -04:00
922fb5108b clean the format (#1140) 2023-10-24 22:59:06 -04:00
7a7796afae Fix is_zero (#1147)
* Fix is_zero

* Use constexpr

* Add CUTLASS_PRAGMA_UNROLL to loops

* Avoid if branches in is_zero
2023-10-23 12:09:37 -04:00
fb10fa5308 Fix broken pipeline link in docs (#1143) 2023-10-18 12:55:46 -04:00
5e1a0a5adb fix alignmentC for h16816_s8xf16 (#1146)
* fix alignmentC for h16816_s8xf16

* manish's change
2023-10-17 15:15:39 -04:00
757275f279 Adding more Threadblock Tiles for Mixed-input TensorOp (BF16 * S8) in cutlass_library (#1132)
* Adding more tiles in the cutlass_library for mixed-input support.

* fix rebase issue

* more tiles to upcast a
2023-10-13 11:33:15 -04:00
fa8dfe631f fix missing return warning for repeat and axpby (#1124) 2023-10-12 00:05:45 -04:00
112590114d Add config.yml issue template with Discord link. (#1135) 2023-10-10 12:13:04 -04:00
ff02da2667 Fx parallel split-k (#1116) 2023-10-06 12:02:40 -04:00
4082fed85a Add missing int64 and uint64 overloads for conj (#1127)
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
2023-10-05 20:01:44 -04:00
5f13dcad78 set kIsHeavy member variables (#1012)
* set kIsHeavy member variables

* correct kIsHeavy value for Tanh

* set kIsHeavy=false for HardSwish

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-10-04 12:38:36 -04:00
61a38f83dc Add #include <limits> to platform.h (#1121)
Closes #1118
2023-10-02 21:41:25 -04:00
ff61a49dd1 Allow changing epsilon parameter in RMS norm kernel (#1112) 2023-10-02 20:40:28 -04:00
26986bbc60 Fix type typo in rmsnorm (#1119)
Initially the variable `h4` is `half4`, but its last two fields are not used. Based on the semantics and the context, I believe it should be `half2`.
2023-10-02 20:40:04 -04:00
7d8317a63e Support for Mixed Input TensorOp (#1084)
* Passing warp-level mixed input F16*(S8/U8) tests

* passing device-level mixed input F16*(S8/U8) tests

* add to profiler - I8 (111 TFLOPs), U (123 TFLOPs)

* fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs)

* Speedup reference compilation (REVERT THIS COMMIT)

* wider_add.u32_packed_sub.f16x2 (I8 = 132TFLOP/s, U8 = 170 TFLOP/s)

* Improve s8->f16 cvt and support bf16*u8 @158 TFLOPs

* BF16 * S8 (142 TFLOPs)

* Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16]

* rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast

* Add device-level test and profiler support for upcast on operand A

* Move shfl before the cvt and reduce #shfls by 1/2

* fix smem_usage calculation for mixed_input types

* uncomment the stuff (getting ready for merge)

* profiler changes and mixed-input reference

* mixed input reference are in a new file

* use platform instead of std

* comments and typo only

* Use CreateGemmOperator and delete CreateMixedInputGemmOperator

* copyright for new files

* rebase follow-up
2023-09-27 11:18:30 -04:00
5cd735c48e Fix Parallel Split-K on Gemm Operation Profiler (#1109)
* Debug and fix for parallel split-k in profiler

* restore debug files and remove prints
2023-09-26 17:28:00 -04:00
67ae8e0603 Change the position of minus sign in line1549 array.h (#1091)
when I use cutlass::epilogue:🧵:LinearCombinationSigmoid, I encounter the this error:
cutlass/include/cutlass/array.h(1549): error: no operator "-" matches these operands
Moving  operator "-" from line 1549 to 1548 can solve this error
2023-09-26 17:26:39 -04:00
14f69bddc8 [fix] fix comparison operator for integer_subbyte (#1090) 2023-09-26 17:26:12 -04:00
90d3b0fb18 CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release.

* Minor fix in gemm op profiler for raster order.

* Add scheduler mapping for raster order in the kernels.
2023-09-26 17:24:26 -04:00
e0aaa3c3b3 fix GmmaDescriptor print format string error (#1102) 2023-09-19 23:27:58 -04:00
8783c41851 Replace 0x1f with 0xffffffff in __shfl_sync (#1097)
This fixes compatibility with H100 and resolves #1094
2023-09-18 19:58:19 -04:00
6407bcdf0a fix matrix B indices (#1089) 2023-09-12 14:04:18 -04:00
a77b2c9cb8 style(examples): typo (#1080)
* Update ampere_tensorop_conv2dfprop.cu

learning cutlass, PR a typo.

* Update ampere_gemm_operand_reduction_fusion.cu
2023-09-11 10:13:22 -04:00
34bbadd3ff standarize fp8 generator (#1078) 2023-09-07 14:36:33 -04:00
88c0d7c726 make only visible on device (#1071) 2023-09-07 13:00:46 -04:00
e01b9b5029 Shard gemm reference templates into multiple TUs for parallel compilation (#1043)
* Split apart gemm reference templates into multiple TUs for parallel compilation

* remove old files

* better balancing of ref kernels across TUs

* remove 3 new added refcheck kernels and some un-necessary fp8 library instances to reduce lib size

* remove auto fp8 kernels

* remove some redundant kernels
2023-08-30 16:46:30 -04:00
34fd98056b fix cinttypes issue with STDC_FORMAT_MACROS (#1068)
* fix cinttypes issue with STDC_FORMAT_MACROS

* Update mma_sm90_desc.hpp

* Update mma_sm90_desc.hpp

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2023-08-29 14:59:33 -04:00
3a8f57a3c8 Add simple hash and eq methods for gemm_operations. (#1053) 2023-08-27 20:41:57 -04:00
6673df0e48 fix typos (#1059) 2023-08-27 00:49:26 -04:00
7618e9bfd8 Fix numeric conversion warning (#1021)
* fix numeric conversion unused var

* update

---------

Co-authored-by: Lufang CHEN 陈橹方 <lufang.chen@nio.com>
2023-08-27 00:42:44 -04:00
a88c41cf8d Updates for 3.2 release (#1065) 2023-08-25 23:05:46 -04:00
27de343535 Add one Publication which is inspired by cutlass (#1022) 2023-08-22 10:00:17 -04:00
2a9fa23e06 Avoid cute::print compiler warnings with -Wformat-security (#1041)
Fixes issue #1040.
2023-08-18 14:38:27 -04:00
2e56cfabee fix typo (#1047) 2023-08-18 14:08:26 -04:00
3930f709ce Fix typo in 0x_gemm_tutorial.md (#1035) 2023-08-17 10:52:20 -04:00
7e5ee8b7bf [doc] fix: fix typos in the comment (#1049) 2023-08-16 11:39:25 -04:00
2d9a557427 torch.bfloat16 support in cutlass python (#1037)
* torch.bfloat16 support in cutlass python

* Update datatypes.py
2023-08-16 11:38:53 -04:00
4575443d44 CUTLASS 3.2 (#1024)
* CUTLASS 3.2
2023-08-07 20:50:32 -04:00
a0d787b746 Fix one publication (#1019) 2023-07-28 11:40:17 -04:00
d20f3a9542 spelling (#1007)
logicial -> logical
2023-07-20 14:41:11 -04:00
8e85580859 fix layout bug (#1006) 2023-07-19 14:26:01 -04:00
146d314057 Update fMHA kernels (#992)
* Update fMHA kernels

Upstream recent changes to fMHA that we did in xFormers.
Previous version in CUTLASS: facebookresearch/xformers@b6be33a
Updating to: facebookresearch/xformers@55a4798

* minor changes

* make var work

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-07-12 22:30:46 -04:00
f679663224 Add RMS norm (#979) 2023-07-10 21:31:27 -04:00
e066ced33b fix epilogue iterator error (#995)
* fix epilogue iterator error

* fix epilogue iterator error

---------

Co-authored-by: maxiao <maxiao@cowarobot.com>
2023-07-10 21:30:31 -04:00
9b923dd4c4 fix minor typos (#984) 2023-07-05 09:23:01 -04:00
f6d42f2dd0 add library_dirs (#977) 2023-06-14 12:09:12 -04:00
473a67073e Fix Int8 and TF32 generator (#976) 2023-06-12 12:32:52 -04:00
87349d3496 Add grouped b2b GEMM (#970) 2023-06-05 17:16:57 -04:00
fde824af21 Update Hopper performance plot for CUTLASS 3.1 + CTK 12.1 (#967) 2023-06-01 14:52:40 -04:00
7dbf423763 Add conversion from ElementBias to ElementCompute (#961) 2023-05-26 23:08:36 -04:00
6f47420213 Update README.md 2023-05-24 12:40:31 -04:00
4638250469 Update CHANGELOG.md 2023-05-24 12:39:42 -04:00
7859fe322a Update PUBLICATIONS.md 2023-05-24 12:36:12 -04:00
d3e72719b4 Add support for sparse GEMM with row broadcasted bias vector (#951) 2023-05-24 10:25:05 -04:00
b4ab501767 Adds CUDA path for x86-64 (#957) 2023-05-24 10:21:25 -04:00
f079619f5e More updates for 3.1 (#958)
* Updates for 3.1

* Minor change

* doc link fix

* Minor updates
2023-05-24 10:17:16 -04:00
13f413493a Stream-K with broadcast (#892)
* [WIP] GEMM StreamK w/ Fused Epilogue

* Adds Gemm Streamk with Fused Epilogue kernel level struct.
  * Mostly based on Gemm with Fused Epilogue,
  * Requires a new epilogue
  * Work in progress

* [WIP] StreamK support for GemmUniversalWithBroadcast

* Just based off of how StreamK is allowed in GemmUniversal
  * Untested and a work in progress

* Minor fixes

* [WIP] It compiles!

It is almost certainly incorrect, but we're past getting the templates
to match, so checkpointing.

* Correction to reference kernel

* Fix typo

* Added MSE measurement

* Switch back to reference kernel + host for loop

Still WIP. Now we're getting even a larger MSE, but it's both on
basic Split-K and Stream-K.

* Fix typos

* Fix broadcast vector + requested changes

* Comment typo

* Small int option and more

* Fix incorrect condition on source needed

* Requested changes

* I think I got it?

* Bias vector should be stride 0

* Two source added!

* Typos

* Merge examples

* Bring back vector row offset

Just to ensure consistency with universal gemm with fused epilogue

* Base arguments and params structs for StreamK

* StreamK epilogue with broadcast now inherits the original

* undo params_streamk_base.h

---------

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-05-22 19:05:06 -04:00
6fbc0d3380 Update layout.md 2023-05-17 20:12:58 -04:00
b97404837e Adding 128x256 tile for 16b input datatype WGMMA gemm (#950) 2023-05-17 17:13:23 -04:00
e2953d47c5 Update gemm_api.md 2023-05-12 15:37:31 -04:00
wll
19c4a4815e replace division with multiplication in GELU (#942) 2023-05-12 10:57:18 -04:00
fcfbd23e26 Fix host compilation of cute::cast_smem_ptr_to_uint. (#940)
* Remove references to device-only intrinsics when compiling for host.

Currently, we attempt to use the `__device__`-only functions
`__cvta_generic_to_shared` and `__nvvm_get_smem_pointer` when compiling
`cute::cast_smem_ptr_to_uint` for the host on Clang. This results in a
compilation error, as expected. This commit changes the definition of
the `*_ACTIVATED` macros so that they are only true when `__CUDA_ARCH__`
is defined; that is, when compiling for the device.

Additionally, the declaration of `__nvvm_get_smem_pointer`
is currently only visible during the device compilation pass when
compiling with NVCC; this commit makes the declaration visible during
host compilation with the `__device__` annotation.

* Annotate cute::cast_smem_ptr_to_uint as device-only.

The implementation of `cute::cast_smem_ptr_to_uint` is currently an
unchecked failure on host code, and the only host implementation I can
think of -- casting a probably-64-bit pointer to 32 bits somehow --
doesn't make sense to implement. This commit marks this function as
device-only so that it can't be accidentally used on host code.

* small change

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-05-10 00:06:54 -04:00
b250faccd3 Make operator() const-correct and add missing static functions. (#936)
* Make operator() const-correct and add missing static functions.

Currently, `*Converter::operator()` requires a mutable object to invoke,
and there are missing `static result_type convert(source_type const &
source)` overloads for certain partial specializations of `*Converter`
objects. This commit makes `operator()` const-correct and adds missing
function overloads where appropriate.

* minor changes

* format

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-05-09 16:33:01 -04:00
24c8b7d8a2 Fix cuTE compilation with clang (#939)
- clang 1.14 complains about missing function from a host call:
  cutlass/include/cute/arch/util.hpp:106:32: error: no matching function for call to '__cvta_generic_to_shared'
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
- fixes this by defining CUTE_HOST_DEVICE for clang as well

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
2023-05-09 09:51:45 -04:00
7c04f95415 Updates for 3.1 (#932) 2023-04-29 09:34:27 -04:00
6f8596ce3f Add missing #include directive to get access to cutlass::epilogue:🧵:ScaleType. (#925)
Currently, the `LinearCombinationClamp` header file is not standalone,
and must have the definition of `cutlass::epilogue:🧵:ScaleType`
already available when it is `#include`d.
2023-04-28 20:02:41 -04:00
fe2f491dd7 Get SM count with cudaDeviceGetAttribute in KernelHardwareInfo (#927) 2023-04-28 13:23:23 -04:00
df02482f1d Add missing schedules argument in SM90 fp16 op generation (#920) 2023-04-26 16:44:49 -04:00
180c5629bf Add missing checks for NVRTC in CuTe (#921) 2023-04-25 12:52:43 -04:00
e36912f961 Fix for dangling references in the MHA example (#918) 2023-04-19 21:35:46 -04:00
9a83bd3381 CUTLASS 3.1 Python interface documentation (#917)
* Add 12.1 Dockerfile

* Add 3.1 docs
2023-04-18 15:11:35 -04:00
54bebe417d Fix some typos in CuTe tutorials (#912) 2023-04-17 16:00:51 -04:00
43cfbe0086 Allow L2 prefect for clang compiler (#914) 2023-04-15 01:23:22 -04:00
4a68cf748e added support of b2b bmm (#849)
* added support of b2b bmm

* fixed arguments and params structures

* added batch_count argument

* removed SplitKSerial and added new test case with b2b bmm

* fixed support of Kbatched and added new test case with batch stride

* added batch support for bias and scale

* make test

* small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-04-14 23:20:02 -04:00
d572cc1aab CUTLASS 3.1 (#915)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
2023-04-14 23:19:34 -04:00
9b8166e3f0 fMHA: Add backward pass (#844)
* fMHA: Add backward pass

* Better checks for strides/alignments

* Remove fb-internal URL

* torch.Tensor.untyped_storage requires pytorch 2.0+

* minor changes

* make test

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-04-06 20:44:58 -04:00
e2d439ee7e Add tile_n=32 and tile_k=32 kernels in generator.py (#858) 2023-04-06 10:00:52 -04:00
0435979f59 Remove const from 3.x GemmUniversalAdapter::operator() (#905) 2023-04-03 20:30:51 -04:00
2ba1ef10be Increase max dynamic SMEM size in GemmSoftmax (#903) 2023-04-03 10:01:12 -04:00
0964bdb64c update gemm and conv2d cmdline --help output (#878) 2023-04-01 11:38:13 -04:00
ecbd24566c Enable shared memory intrinsics and ldmatrix PTX on Clang. (#754)
* Enable shared memory intrinsics and ldmatrix PTX on Clang.

This commit adds preprocessor checks to enable the shared memory
intrinsics `__cvta_generic_to_shared` and `__nvvm_get_smem_pointer`, as
well as the `ldmatrix` PTX instructions, on Clang. Preventing these
intrinsics from being used is a significant latency regression on Clang.

* refine the macro

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-03-31 21:42:24 -04:00
660a05f581 fix split_k_mode and add reduction kernel for f16 input/accum/output (#896) 2023-03-30 15:31:08 -04:00
bc36122c3f [layout] Fix AffineRank2ColumnMajor::packed() (#879)
* [layout] Fix AffineRank2ColumnMajor::packed()

* correct affine2row::packed

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-03-29 11:59:48 -04:00
15d9d31f1f CUTLASS 3.0 Hopper GEMMs are GETTs in disguise (#897) 2023-03-29 10:42:40 -04:00
1eef5c3cf1 add guards for __CUDA_ARCH__ >= 530 (#891)
* add guards for sm>=70

* drop guard to 530
2023-03-28 17:47:10 -04:00
87070b6d51 add a CUTLASS publication (#893)
* add bytetransformer

* update arxiv link

* re-order
2023-03-28 17:06:57 -04:00
77549ae6c8 Update PUBLICATIONS.md
msft moe paper
2023-03-25 21:17:05 -04:00
42290f5d1c Fix for dangling pointers (#885) 2023-03-25 01:15:14 -04:00
209faf7b94 remove spurious comma (#871) 2023-03-20 17:25:27 -04:00
6116706c96 Set batch_strides on Params::update (#883) 2023-03-20 17:07:47 -04:00
2670b973dd Fix sign-compare warning in reorder_array (#869)
`std::vector<T>::size_type` is unsigned type, so let's iterate over unsigned type as well


Discovered, while trying to enable PyTorch building without `-Wno-sign-compare` warning suppression, see https://github.com/pytorch/pytorch/actions/runs/4418987999/jobs/7746850762#step:10:10532
2023-03-20 17:07:24 -04:00
af332d4aa9 Add missing comma in cutlass/arch/mma_sm90.h (#862) 2023-03-14 12:04:28 -04:00
2036 changed files with 167089 additions and 46016 deletions

5
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@ -0,0 +1,5 @@
blank_issues_enabled: true
contact_links:
- name: CUTLASS Discord
url: https://discord.gg/nvidiadeveloper
about: Come chat about using and contributing to CUTLASS!

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
# PyCache files
__pycache__/
cutlass_library.egg-info/

View File

@ -1,5 +1,78 @@
# NVIDIA CUTLASS Changelog
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMMs](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm)
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released.
* Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.
* [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
* Profiler support for lower-aligned Hopper GEMMs.
* Performance Improvements to [Scatter-Gather Hopper Example](/examples/52_hopper_gather_scatter_fusion).
* Sub-Byte type fixes and improvements.
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25)
* Minor patch for issue/1138
## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
* Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
* SM80 EVT support in C++ and Python.
* Other SM90 epilogue improvements.
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
* SM90 TF32 kernel improvements for all layouts.
* SM90 rasterization direction support in the CUTLASS profiler.
* Improvement for CUTLASS profiler build times.
* Remove Python-C++ bindings.
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
* New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
* [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
* New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues.
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
* An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
* Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix.
* [Permute + GEMM fusion](examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
* [Row Broadcast](include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
* The GitHub branch is renamed from `master` to `main` in this release.
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
@ -45,7 +118,7 @@
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
* [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
* Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
@ -57,13 +130,13 @@
* [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
* [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
* [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
* [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
* [Python-based instance emitter](/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
* [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
* [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/symm_operation.py)
* [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/trmm_operation.py)
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/python/cutlass_library/rank_k_operation.py)
* [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/python/cutlass_library/rank_k_operation.py)
* [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/python/cutlass_library/symm_operation.py)
* [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/python/cutlass_library/trmm_operation.py)
* [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h)
* [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
* [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters
@ -94,7 +167,7 @@
* **TF32x3:** emulated single-precision using Tensor Cores
* 45+ TFLOPs on NVIDIA A100
* [GEMM SDK example](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
* [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu) (complex)
* [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
* [Implicit GEMM Convolution SDK example](/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
* [Conv Fprop SDK example](/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
@ -146,7 +219,7 @@
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
* Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h)
* SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_gemm/quaternion_conv.cu)
* SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu)
* [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
* Many improvements to the epilogue.
* Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
@ -298,7 +371,7 @@
## Copyright
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -26,7 +26,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
cmake_minimum_required(VERSION 3.19 FATAL_ERROR)
cmake_policy(SET CMP0112 NEW)
if(cutlass_LOADED)
# If CUTLASS has been previously fetched and loaded, don't do it again.
@ -39,7 +40,25 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
project(CUTLASS VERSION 3.0.0 LANGUAGES CXX)
# To reduce duplicate version locations, parse the version out of the
# main versions.h file and reuse it here.
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/include/cutlass/version.h VERSION_FILE_CONTENTS)
string(REGEX MATCH "#define CUTLASS_MAJOR ([0-9]+)" _CUTLASS_VERSION_MAJOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MAJOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_MINOR ([0-9]+)" _CUTLASS_VERSION_MINOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MINOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_PATCH ([0-9]+)" _CUTLASS_VERSION_PATCH "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_PATCH ${CMAKE_MATCH_1})
message(STATUS "CUTLASS ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH}")
## CUTLASS PROJECT #############################################################
project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH} LANGUAGES CXX)
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 11.3)
@ -58,6 +77,8 @@ endif()
find_package(Doxygen QUIET)
################################################################################
#
# CUTLASS 3.x requires C++17
#
@ -79,16 +100,28 @@ endif()
message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
# 0 - Sanity, 1 - Release-Quality, 2 - Exhaustive
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
################################################################################
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
if(CUTLASS_ENABLE_HEADERS_ONLY)
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
set(CUTLASS_ENABLE_LIBRARY_INIT OFF)
set(CUTLASS_ENABLE_TESTS_INIT OFF)
else()
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
set(CUTLASS_ENABLE_LIBRARY_INIT ON)
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ON)
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
endif()
set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.")
@ -97,19 +130,11 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Proformance")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}})
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance")
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
if (CUTLASS_ENABLE_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
endif()
set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests")
################################################################################
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
@ -124,6 +149,17 @@ endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
# Find unsupported and deprecated compute capabilities
if (CUTLASS_NVCC_ARCHS_SUPPORTED)
set(CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS})
list(REMOVE_ITEM CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS_SUPPORTED})
if (CUTLASS_NVCC_ARCHS_UNSUPPORTED)
message(WARNING "Using unsupported or deprecated compute capabilities ${CUTLASS_NVCC_ARCHS_UNSUPPORTED}. Support may be removed in future versions.")
endif()
else()
message(WARNING "No supported compute capabilities for CUDA ${CUDA_VERSION}.")
endif()
# Special policy introduced in CMake 3.13
if (POLICY CMP0076)
cmake_policy(SET CMP0076 NEW)
@ -160,9 +196,12 @@ if(WIN32)
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
if (WIN32)
# Enable more warnings and treat as errors
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)
# Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors.
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3)
# Disable warning on Unicode characters
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)
@ -185,15 +224,42 @@ set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
################################################################################
#
# CUTLASS generator cmake configuration
#
# Kernel unified filter file
set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH")
if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS)
# If a kernel filter file is specified, we want to generate and then
# filter on the entire kernel set, not the default kernel
# (sub)set. The user may overried CUTLASS_LIBRRARY_KERNELS, in which
# case the resulting kernel set will be the intersection of the two
# options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS.
set(CUTLASS_LIBRARY_KERNELS_INIT "*")
else()
set(CUTLASS_LIBRARY_KERNELS_INIT "")
endif()
if (KERNEL_FILTER_FILE)
get_filename_component(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" ABSOLUTE)
set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE)
endif()
set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list")
if(KERNEL_FILTER_FILE)
message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}")
endif()
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
################################################################################
set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
@ -213,6 +279,8 @@ if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
endif()
################################################################################
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#
@ -262,6 +330,8 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
endif()
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
@ -287,9 +357,10 @@ if (CUTLASS_ENABLE_OPENMP_TESTS)
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
endif()
endif()
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
if(UNIX)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
endif()
# Don't leak lineinfo in release builds
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
@ -352,6 +423,28 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
cmake_policy(SET CMP0104 NEW)
endif()
if (MSVC)
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
# between GCC/Clang and MSVC.
#
# error : a constexpr function cannot have a nonliteral return type "dim3"
#
# See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
endif()
# Some tests require this build option in order to link.
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj")
endif()
function(cutlass_apply_cuda_gencode_flags TARGET)
set(options)
set(oneValueArgs)
@ -466,7 +559,8 @@ endfunction()
# GLOB for CUTLASS header files. Should we use a static list instead?
file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h)
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h)
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h include/cutlass/*.hpp include/cutlass/*.inl)
file(GLOB_RECURSE CUTLASS_CUTE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cute/*.h*)
file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h)
###################################################################################################
@ -516,8 +610,8 @@ if (NOT DEFINED CUTLASS_REVISION)
endif()
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h
@ONLY)
target_include_directories(
@ -526,11 +620,17 @@ target_include_directories(
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/examples>
)
# Mark CTK headers as system to supress warnings from them
target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
)
install(
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
@ -587,6 +687,11 @@ endif()
include(CTest)
enable_testing()
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
endif()
if (NOT TARGET test_all)
add_custom_target(test_all)
endif()
@ -623,7 +728,13 @@ endif()
################################################################################
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake)
set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default
activated test sets. In `make test` mode, this string determines the
active set of tests. In `ctest` mode, this value can be overriden
with CUTLASS_TEST_SETS environment variable when running the ctest
executable.")
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
function(cutlass_add_executable_tests NAME TARGET)
@ -637,21 +748,33 @@ function(cutlass_add_executable_tests NAME TARGET)
# DEPENDS: A list of targets or files on which this test is dependent.
# DEPENDEES: A list of targets which should depend on this test.
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
# to pass to the test executable. A unique test is generated for each set of
# options given. If this option is not used, a single test with no arguments is generated.
# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for
# generating the full variable name to be referenced.
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
# test results to speed up test runtime.
# TEST_SETS_SUPPORTED: A list of test set names these tests support.
#
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT DEFINED __DISABLE_TESTS)
set(__DISABLE_TESTS OFF)
endif()
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
if (NOT DEFINED __TEST_SETS_SUPPORTED)
set(__TEST_SETS_SUPPORTED ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})
endif()
set(TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED})
if (__RESULT_CACHE_FILE)
add_custom_command(
@ -688,7 +811,6 @@ function(cutlass_add_executable_tests NAME TARGET)
endif()
list(LENGTH __TEST_COMMAND_OPTIONS CMD_COUNT)
set(CMD_IDX 0)
if (CMD_COUNT GREATER 1)
add_custom_target(${NAME} DEPENDS ${TARGET} ${__DEPENDS})
@ -697,12 +819,22 @@ function(cutlass_add_executable_tests NAME TARGET)
endforeach()
endif()
foreach(CMD_OPTIONS ${__TEST_COMMAND_OPTIONS})
if (CUTLASS_INSTALL_TESTS)
set(_INLINE_PER_TEST_CODE)
file(READ "${PROJECT_SOURCE_DIR}/cmake/CTestTestfile.test.configure.cmake" _INLINE_PER_TEST_CODE_TEMPLATE)
endif()
set(TEST_GROUP_NAME ${NAME})
foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS)
if (CMD_COUNT GREATER 1)
set(TEST_NAME ${NAME}_${CMD_IDX})
string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TEST_NAME)
else()
set(TEST_NAME ${NAME})
string(TOLOWER "${NAME}" TEST_NAME)
endif()
# The following rigmarole is needed to deal with spaces and possible quotes in
@ -711,14 +843,14 @@ function(cutlass_add_executable_tests NAME TARGET)
# preserves any quotes. Note, they have to be in this order for it to work for
# all the use cases below.
set(CMD_OPTIONS ${${CMD_OPTIONS}})
list(JOIN CMD_OPTIONS " " TEST_COMMAND_OPTIONS)
separate_arguments(CMD_OPTIONS)
set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}})
list(JOIN TEST_COMMAND_OPTIONS " " TEST_COMMAND_OPTIONS)
separate_arguments(TEST_COMMAND_OPTIONS)
add_custom_target(
${TEST_NAME}
COMMAND
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${TEST_COMMAND_OPTIONS}
DEPENDS
${TARGET}
)
@ -731,41 +863,48 @@ function(cutlass_add_executable_tests NAME TARGET)
add_dependencies(${DEPENDEE} ${TEST_NAME})
endforeach()
add_test(
NAME c${TEST_NAME}
COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
)
set(TEST_NAME c${TEST_NAME})
string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY)
string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}")
set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS})
endforeach()
# To run the tests from an install package with tests enabled, we need to generate test files
# that don't rely on the current directory structure in build.
set(TEST_NAME c${NAME})
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})
set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
set(TEST_EXE_PATH $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format.
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY)
# The following line imports the tests for immediate run via `make test`.
include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake)
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
if (CUTLASS_INSTALL_TESTS)
# To run the tests from an install package with tests enabled, we need to generate test files
# that don't rely on the current directory structure in build.
file(GENERATE
OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in"
)
set(TEST_NAME c${TEST_NAME})
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY)
install(
FILES "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/${TEST_NAME}
RENAME CTestTestfile.${TEST_NAME}.cmake
)
file(GENERATE
OUTPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
INPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake"
)
install(
FILES "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/
)
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
endif()
math(EXPR CMD_IDX "${CMD_IDX} + 1")
endforeach()
endfunction()
if (CUTLASS_ENABLE_TOOLS)
@ -774,6 +913,7 @@ if (CUTLASS_ENABLE_TOOLS)
add_dependencies(test_all test_profiler)
endif()
endif()
if (CUTLASS_ENABLE_EXAMPLES)
add_subdirectory(examples)
add_dependencies(test_all test_examples)
@ -781,38 +921,32 @@ endif()
if (CUTLASS_ENABLE_TESTS)
add_subdirectory(test)
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
add_dependencies(test_all test_unit)
endif()
endif()
if (CUTLASS_INSTALL_TESTS)
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake")
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")
file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "cmake_policy(SET CMP0057 NEW) # Allow IN_LIST for if()\n\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SETS} ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n")
file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n")
foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
endforeach()
install(
FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
FILES "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake"
DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/"
)
endif()
#? install(
#? FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
#? )
#?
#? install(
#? DIRECTORY
#? ${CMAKE_BINARY_DIR}/tools
#? ${CMAKE_BINARY_DIR}/test
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
#? FILES_MATCHING PATTERN "CTestTestfile.cmake"
#? )
################################################################################
include(CMakePackageConfigHelpers)
@ -821,9 +955,15 @@ write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
COMPATIBILITY AnyNewerVersion)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
@ONLY
)
install(
FILES
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/
)
@ -838,3 +978,4 @@ install(
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake)

View File

@ -13,6 +13,7 @@ Cris Cecka<br />
Aniket Shivam<br />
Jack Kosaian<br />
Mark Hoemmen<br />
Richard Cai<br />
Honghao Lu<br />
Ethan Yan<br />
Haicheng Wu<br />
@ -21,6 +22,8 @@ Dustyn Blasig<br />
Fengqi Qiao<br />
Duane Merrill<br />
Yujia Zhai<br />
Rawn Henry<br />
Sergey Klevtsov<br />
Shang Zhang<br />
Piotr Majcher<br />
Paul Springer<br />
@ -55,6 +58,7 @@ Alan Kaatz<br />
Tina Li<br />
Timmy Liu<br />
Wei Liu<br />
Tim Martin<br />
Duane Merrill<br />
Kevin Siu<br />
Markus Tavenrath<br />

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -76,6 +76,7 @@ find_library(
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x86_64-linux-gnu
lib/x64
lib64
lib
@ -120,6 +121,7 @@ find_library(
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x86_64-linux-gnu
lib/x64
lib64
lib
@ -226,7 +228,14 @@ else()
endif()
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
if (MSVC)
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8)
else()
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16)
endif()
set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files")
function(cutlass_unify_source_files TARGET_ARGS_VAR)
@ -239,11 +248,15 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()
if (NOT DEFINED __BATCH_SOURCES)
set(__BATCH_SOURCES ON)
endif()
if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
endif()
if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1)
set(CUDA_FILE_ARGS)
set(TARGET_SOURCE_ARGS)
@ -296,10 +309,10 @@ function(cutlass_add_library NAME)
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
add_library(${NAME} ${TARGET_SOURCE_ARGS})
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
endif()
cutlass_apply_standard_compile_options(${NAME})

View File

@ -1,4 +1,4 @@
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
Redistribution and use in source and binary forms, with or without

View File

@ -2,12 +2,31 @@
## 2023
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
- ["MegaBlocks: Efficient Sparse Training with Mixture-of-Experts"](https://arxiv.org/abs/2211.15841). Trevor Gale, Deepak Narayanan, Cliff Young, Matei Zaharia. _Proceedings of the Sixth Machine Learning and Systems_, May 2023.
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023.
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
## 2022
- ["GPU Load Balancing"](https://arxiv.org/abs/2212.08964). Muhammad Osama. _Doctoral dissertation, University of California, Davis_, December 2022.
- ["Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"](https://arxiv.org/abs/2211.10017). Young Jin Kim, Rawn Henry, Raffy Fahim, Hany Hassan Awadalla. _Proceedings of the Third Workshop on Simple and Efficient Natural Language Processing_, December 2022.
- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.
- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.
@ -18,7 +37,7 @@
- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.
- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
- ["Real-time Neural Radiance Caching for Path Tracing"](https://dl.acm.org/doi/abs/10.1145/3450626.3459812). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
## 2020

View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 3.0
# CUTLASS 3.4
_CUTLASS 3.0 - January 2023_
_CUTLASS 3.4 - February 2024_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@ -31,33 +31,31 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data.
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
# What's New in CUTLASS 3.0
# What's New in CUTLASS 3.4
CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogenous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture.
CUTLASS 3.4.1 is an update to CUTLASS adding:
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md)
- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md).
- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe.
- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe.
- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs.
- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA.
- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality.
- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour a metaprogramming based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu).
- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format
for kernel procedural names.
- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release.
CUTLASS 3.4.0 is an update to CUTLASS adding:
## New architecture, compiler, and CUDA Toolkit requirements
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
Minimum requirements:
@ -65,7 +63,7 @@ Minimum requirements:
- Compiler: Must support at least C++17
- CUDA Toolkit version: 11.4
CUTLASS 3.0 *removes support* for the following:
Starting from CUTLASS 3.0, CUTLASS removed support for the following:
- Maxwell and Pascal GPU architectures
- Ubuntu 16.04
@ -76,7 +74,7 @@ CUTLASS 3.0 *removes support* for the following:
# Performance
<p align="center"><img src=media/images/cutlass-3.0-gemm-peak-performance.png></p>
<p align="center"><img src=media/images/cutlass-3.1-gemm-peak-performance.png></p>
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit peak performance comparable to cuBLAS for scalar GEMM
@ -87,20 +85,21 @@ an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
as shown in the above figure. Tensor Core operations are still implemented using CUDA's
as shown in the above figure. Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
# Compatibility
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8.
performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
## Operating Systems
We have tested the following environments.
@ -110,8 +109,12 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 22.04 | GCC 11.2.0 |
| Ubuntu 22.04 | Clang 10.0.0 |
| Ubuntu 22.04 | Clang 14.0.6 |
| Ubuntu 22.04 | Clang 17.0.6 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |
Note: We plan to add Windows (MSVC) & Clang compiler support soon.
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
## Hardware
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
@ -131,9 +134,9 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
## Target Architecture
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error.
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error.
```
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
@ -178,7 +181,8 @@ CUTLASS is a header-only template library and does not need to be built to be us
projects. Client applications should target CUTLASS's `include/` directory in their include
paths.
CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12.
CUTLASS unit tests, examples, and utilities can be build with CMake.
The minimum version of CMake is given in the [Quickstart guide](media/docs/quickstart.md).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.
@ -514,7 +518,7 @@ reference_device: Passed
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM conovlution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
- [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
@ -529,7 +533,7 @@ The official list of CUTLASS developers and contributors is available here: [CON
# Copyright
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
@ -558,4 +562,3 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -1,3 +1,31 @@
# Copyright (c) 2019 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# A small utility function which generates a C-header from an input file
function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
FILE(READ "${FILENAME}" HEX_INPUT HEX)
@ -6,7 +34,7 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
endif()
string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT})
set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")

View File

@ -1,21 +0,0 @@
# Generated file
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()
if (NOT "@TEST_EXE_DIR@" STREQUAL "")
set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@)
else()
set(TEST_EXE_PATH @TEST_EXE@)
endif()
add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
endif()
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)

View File

@ -0,0 +1,54 @@
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Generated file
set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)
if (NOT DEFINED ENV{CUTLASS_TEST_SETS})
set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@)
endif()
foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS})
if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED)
message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].")
return()
endif()
endforeach()
set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()
@_INLINE_PER_TEST_CODE@

View File

@ -0,0 +1,43 @@
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT)
# The longform/extended format allows generator expressions to be
# expanded property and is useful in contexts where the files need
# to be immediately included into being-processed cmake code.
add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
else()
add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
endif()
if (TEST_EXE_WORKING_DIRECTORY)
set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
endif()
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)

View File

@ -2,6 +2,8 @@ get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
include(CMakeFindDependencyMacro)
if(NOT TARGET nvidia::cutlass::CUTLASS)
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
if(TARGET nvidia::cutlass::CUTLASS)
return()
endif()
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")

View File

@ -1,3 +1,31 @@
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(CPACK_PACKAGE_NAME NvidiaCutlass)
set(CPACK_PACKAGE_VENDOR NVIDIA)
set(CPACK_PACKAGE_CONTACT info@nvidia.com)

View File

@ -1,3 +1,31 @@
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include(FetchContent)
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
@ -9,7 +37,7 @@ endif()
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG 0fe9660
GIT_TAG v1.13.0
)
FetchContent_GetProperties(googletest)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,38 +0,0 @@
#include <cstdint>
#include <string>
#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
namespace cutlass {
inline uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}
inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}
} // namespace cutlass

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -28,18 +28,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Bind gemm test to python
*/
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include "host.h"
namespace py = pybind11;
void bind_gemm_test(py::module &m) {
py::module_ host_submodule = m.def_submodule("host");
bind_gemm_host_reference(host_submodule);
}
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_REVISION "@CUTLASS_REVISION@"

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1677,7 +1677,7 @@ template&lt;typename Element , typename Layout &gt; </div>
</tr>
</table>
</div><div class="memdoc">
<p>Returns a pair containing a boolean of whether a value exists in a tensor and the location of of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined. </p>
<p>Returns a pair containing a boolean of whether a value exists in a tensor and the location of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined. </p>
</div>
</div>

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -31,4 +31,5 @@
cutlass_example_add_executable(
02_dump_reg_shmem
dump_reg_shmem.cu
DISABLE_TESTS ON
)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -140,8 +140,8 @@ using ElementInputA = int8_t; // <- data type of elements
using ElementInputB = int8_t; // <- data type of elements in input matrix B
using ElementOutput = int32_t; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices. Column Major for
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
// The code section below describes matrix layout of input and output matrices. Row Major for
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
@ -355,4 +355,3 @@ int main() {
return run();
}

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -143,7 +143,6 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = int32_t; // Data type of accumulator
@ -555,6 +554,7 @@ Result profile_convolution(Options const &options) {
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
ElementOutput,
cutlass::NumericConverterClamp<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
@ -674,7 +674,6 @@ Result profile_convolution(Options const &options) {
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
@ -761,11 +760,7 @@ int main(int argc, char const **args) {
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -27,7 +27,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This example depends on the CUTLASS Library
#
if (CUTLASS_ENABLE_LIBRARY)
# Planar Complex GEMM example
cutlass_example_add_executable(
@ -35,11 +38,6 @@ cutlass_example_add_executable(
planar_complex.cu
)
#
# This example depends on the CUTLASS Library
#
target_link_libraries(
10_planar_complex
PRIVATE
@ -48,3 +46,4 @@ target_link_libraries(
cuda
)
endif()

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -27,7 +27,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This example depends on the CUTLASS Library
#
if (CUTLASS_ENABLE_LIBRARY)
# Planar Complex Array GEMM example
cutlass_example_add_executable(
@ -35,11 +38,6 @@ cutlass_example_add_executable(
planar_complex_array.cu
)
#
# This example depends on the CUTLASS Library
#
target_link_libraries(
11_planar_complex_array
PRIVATE
@ -48,3 +46,4 @@ target_link_libraries(
cuda
)
endif()

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -81,7 +81,7 @@ using ShapeMMAThreadBlock =
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??

View File

@ -1,5 +1,5 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@ -64,6 +64,7 @@ endforeach()
foreach(FUSION_GEMM_EXAMPLE
fused_two_gemms_f16_sm75_rf
fused_two_gemms_f16_sm75_shmem
fused_two_gemms_grouped_f16_sm80_rf
fused_two_gemms_f16_sm80_rf
fused_two_gemms_f16_sm80_shmem
fused_two_gemms_s8_sm75_rf

View File

@ -1,11 +1,11 @@
# Introduction
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
<p align="center"><img src=/media/images/13_example_fusion.png></p>
When running two unfused GEMM/Conv operations, each operation loads one input
activation matrix, one weight matrix (or filter matrix) from the memory and then
When running two unfused GEMM/Conv operations, each operation loads one input
activation matrix, one weight matrix (or filter matrix) from the memory and then
stores the result activation matrix back to the memory.
When the two GEMM/Conv operations are fused together, the mainloops of the two
@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
2 GEMMs/Convs.
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
input activation, the example enforces the following two constraints:
- thread_block_tile_N = problem_N
- thread_block_tile_N = problem_N
<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>
@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
operation can be fully block-resident.
- warp_tile_N = thread_block_tile_N
- warp_tile_N = thread_block_tile_N
<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>
@ -82,11 +82,11 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
# Copyright
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -42,6 +42,7 @@
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
//
B2bNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
ElementCompute,
ElementCompute,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
ElementCompute,
ElementCompute,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
@ -399,9 +400,9 @@ struct B2bFusedGemmRun
//
B2bFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@ -412,7 +413,7 @@ struct B2bFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -420,11 +421,11 @@ struct B2bFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -453,70 +454,90 @@ struct B2bFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
ElementCompute,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
//
typename B2bGemm::Arguments arguments{
mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
tensor_B1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
batch_count,
};
B2bGemm b2b_gemm_op;
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator
>(
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
);
cutlass::reference::device::TensorScaleBiasGemm<
cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
alpha1, //intermediate alpha=1
reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(),
cutlass::ComplexTransform::kNone,
beta1, //beta = 0
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@ -680,7 +729,7 @@ struct B2bFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -694,7 +743,7 @@ struct B2bFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()

View File

@ -0,0 +1,450 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Containers for running grouped back-to-back GEMMs
*/
#pragma once
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
if((val1) <= (val2)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
if(!(val)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
////////////////////////////////////////////////////////////////////////////////
template <typename B2bGemm_>
struct B2bFusedGroupedGemmRun
{
using B2bGemm = B2bGemm_;
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;
//
// Methods
//
B2bFusedGroupedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
int problem_count = (int)problem_sizes_0.size();
std::vector<HostTensorA> host_tensor_A0(problem_count);
std::vector<HostTensorB> host_tensor_B0(problem_count);
std::vector<HostTensorC> host_tensor_C0(problem_count);
std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
std::vector<HostTensorB> host_tensor_B1(problem_count);
std::vector<HostTensorC> host_tensor_C1(problem_count);
std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
std::vector<HostTensorC> host_tensor_D1(problem_count);
std::vector<HostTensorZ> host_tensor_Z(problem_count);
std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
for (int i = 0; i < problem_count; ++i) {
//
// Allocate the GEMM workspace
//
auto problem_size_0 = problem_sizes_0[i];
auto problem_size_1 = problem_sizes_1[i];
host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
if (alpha0 == ElementCompute(0)) //per-channel scale
host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
if (alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
cutlass::reference::host::TensorFill(
host_tensor_D1.at(i).host_view());
cutlass::reference::host::TensorFill(
host_tensor_ref_D0.at(i).host_view());
cutlass::reference::host::TensorFill(
host_tensor_ref_D1.at(i).host_view());
host_tensor_A0.at(i).sync_device();
host_tensor_B0.at(i).sync_device();
host_tensor_C0.at(i).sync_device();
if (alpha0 == ElementCompute(0)) //per-channel scale
host_tensor_Scale0.at(i).sync_device();
host_tensor_Bias0.at(i).sync_device();
host_tensor_B1.at(i).sync_device();
host_tensor_C1.at(i).sync_device();
host_tensor_Bias1.at(i).sync_device();
host_tensor_D1.at(i).sync_device();
host_tensor_ref_D0.at(i).sync_device();
host_tensor_ref_D1.at(i).sync_device();
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());;
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
if (alpha0 == ElementCompute(0)) //per-channel scale
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
}
//
// Initialize the GEMM operator
//
cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
device_ref_A0.copy_from_host(ref_A0.data());
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
device_ref_B0.copy_from_host(ref_B0.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
device_ref_C0.copy_from_host(ref_C0.data());
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
device_ref_Scale0.copy_from_host(ref_Scale0.data());
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
device_ref_Bias0.copy_from_host(ref_Bias0.data());
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
device_ref_B1.copy_from_host(ref_B1.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
device_ref_C1.copy_from_host(ref_C1.data());
cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
device_ref_Bias1.copy_from_host(ref_Bias1.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
device_ref_D1.copy_from_host(ref_D1.data());
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
B2bGemm b2b_gemm_op;
int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
if (!threadblock_count) {
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
return false;
}
typename B2bGemm::Arguments arguments{
problem_count,
device_problem_sizes_0.get(),
device_problem_sizes_1.get(),
device_ref_A0.get(),
device_ref_B0.get(),
device_ref_C0.get(),
device_ref_Scale0.get(),
device_ref_Bias0.get(),
device_ref_B1.get(),
device_ref_C1.get(),
device_ref_D1.get(),
{alpha0, beta0},
{alpha1, beta1},
threadblock_count
};
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirments:\n"
<< " problem_size_0.M = problem_size_1.M\n"
<< " problem_size_0.N = problem_size_1.K\n"
<< " ThreadblockShape0::kN = problem_size_0.N\n"
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
}
status = b2b_gemm_op.initialize(arguments);
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
//
// Run the GEMM
//
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
for(int i = 0; i < runs; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop);
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
for (int i = 0; i < problem_count; ++i) {
host_tensor_D1.at(i).sync_host();;
//
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator>
reference_gemm_1;
auto problem_size_0 = problem_sizes_0[i];
auto problem_size_1 = problem_sizes_1[i];
reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
ref_A0.at(i),
ref_B0.at(i),
ElementAccumulator(0), //beta = 0
ref_Z.at(i),
ref_Z.at(i),
ElementAccumulator(0)
);
cutlass::reference::device::TensorScaleBiasGemm<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutC
> (
problem_size_0,
ref_Z.at(i),
ref_ref_D0.at(i),
alpha0,
ref_Scale0.at(i),
ref_Bias0.at(i)
);
if(relu) {
cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
ref_ref_D0.at(i),
ref_B1.at(i),
beta1,
{host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
ref_ref_D1.at(i)
);
if(relu) {
cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
}
cudaDeviceSynchronize();
host_tensor_ref_D0.at(i).sync_host();
host_tensor_ref_D1.at(i).sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
host_tensor_ref_D1.at(i).host_view(),
host_tensor_D1.at(i).host_view());
CHECK_TRUE(passed);
if (!passed)
{
std::stringstream fname;
fname << "error_B2bGemm_device_fused.txt";
std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "GEMM " << i << " in group\n"
<< "A0 =\n" << host_tensor_A0.at(i).host_view()
<< "\nB0 =\n" << host_tensor_B0.at(i).host_view()
<< "\nC0 =\n" << host_tensor_C0.at(i).host_view()
<< "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
<< "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
<< "\nB1 =\n" << host_tensor_B1.at(i).host_view()
<< "\nC1 =\n" << host_tensor_C1.at(i).host_view()
<< "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
<< "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
<< "\nComputed =\n" << host_tensor_D1.at(i).host_view();
return false;
}
}
return true;
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -43,6 +43,7 @@
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
//
B2bInterleavedNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
//
B2bInterleavedFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
//Reorder B0
cutlass::reorder_column<16>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
// tensor_Bias0_batched.sync_device();
//
// Initialize the GEMM operator
//
typename B2bGemm::Arguments arguments{
mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
tensor_B1_reordered.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
batch_count,
};
B2bGemm b2b_gemm_op;
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator
>(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
);
cutlass::reference::device::TensorScaleBiasGemm<
cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
alpha1, //intermediate alpha=1
reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(),
cutlass::ComplexTransform::kNone,
beta1, //beta = 0
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -119,8 +119,6 @@ template <
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
@ -154,7 +152,6 @@ class B2bGemm {
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount;
static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
@ -184,77 +181,11 @@ class B2bGemm {
EpilogueOutputOp1,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator,
SmemAccumulator
>::B2bGemmKernel;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size_0;
GemmCoord problem_size_1;
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
int split_k_slices_ = 1
):
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
split_k_slices(split_k_slices_) {
}
};
using Arguments = typename B2bGemmKernel::Arguments;
private:
@ -269,10 +200,6 @@ public:
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = B2bGemmKernel::can_implement(
args.problem_size_0,
args.problem_size_1,
@ -295,20 +222,14 @@ public:
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
args.batch_count);
return bytes;
}
@ -320,38 +241,17 @@ public:
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
args.batch_count);
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
// args.problem_size_1,
// args.problem_size_1,
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
// args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// args.batch_count);
// Initialize the Params structure
params_ = typename B2bGemmKernel::Params{
args.mode,
args.problem_size_0,
args.problem_size_1,
grid_shape,
@ -363,6 +263,13 @@ public:
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
args.batch_stride_A0,
args.batch_stride_B0,
args.batch_stride_B1,
args.batch_stride_C1,
args.batch_stride_D1,
args.batch_stride_Bias0,
args.batch_stride_Scale0,
args.epilogue0,
args.epilogue1,
static_cast<int *>(workspace),
@ -373,12 +280,6 @@ public:
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
@ -430,12 +331,12 @@ public:
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -220,7 +220,6 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
return pass;
}
int main() {
std::vector<bool (*)()>funcs = {
@ -229,10 +228,6 @@ int main() {
};
return testRun(75, funcs, "conv int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -39,7 +39,6 @@
#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
@ -219,20 +218,13 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
return pass;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
&run_fused_conv2d_fprop_optimized_s8_sm75_shmem
};
return testRun(75, funcs, "conv int8 shmem staging");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -0,0 +1,297 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
*/
#include <iostream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/base_grouped.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
#include "threadblock/grouped_threadblock_swizzle.h"
#include "b2b_grouped_gemm_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;
// Constraints:
// 1. Warp shape N must equal thread block shape N
// 2. Problem size N must equal thread block shape N
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
int alignment = 8;
std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
int problem_count;
bool verbose;
//
// Methods
//
Options():
help(false),
error(false),
reference_check(true),
problem_count(15),
verbose(false)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("problems", problem_count, 15);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("verbose", verbose, false);
randomize_problems(cmd);
}
void randomize_problems(cutlass::CommandLine &cmd) {
//
// For now, randomly choose the problem sizes.
//
int cmd_line_m = -1;
int cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes0.reserve(problem_count);
problem_sizes1.reserve(problem_count);
for (int i = 0; i < problem_count; ++i) {
int m = cmd_line_m;
int k = cmd_line_k;
if (m < 1) {
m = alignment * ((rand() % 256) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 256) + 1);
}
cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
problem_sizes0.push_back(problem0);
problem_sizes1.push_back(problem1);
}
if (verbose) {
print_problem_sizes();
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
<< " back-to-back GEMMs are subject to.s"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --problems=<int> Number of individual GEMM problems (default: --problems=15)\n"
<< " --m=<int> Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
<< " --k=<int> Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
<< " --verbose=<bool> If true, prints problem sizes.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
<< "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --groups=10\n\n";
return out;
}
void print_problem_sizes() {
std::cout << std::endl;
std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
for (int i = 0; i < problem_count; ++i) {
cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
std::cout << "Problem " << i
<< "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
<< "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
<< std::endl;
}
}
};
bool run_fused_grouped_gemm_f16_sm80_rf_res() {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
ThreadblockShape0,
cutlass::layout::RowMajor // LayoutC
>;
const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
const int kStages = 3;
using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
kAlignment,
cutlass::half_t,
cutlass::layout::ColumnMajor,
kAlignment,
cutlass::half_t,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
GroupedThreadblockSwizzle,
kStages,
cutlass::arch::OpMultiplyAdd
>::B2bGemmKernel;
using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;
B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;
std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main(int argc, char const **args) {
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;
std::vector<bool (*)()>funcs = {
&run_fused_grouped_gemm_f16_sm80_rf_res
};
return testRun(80, funcs, "grouped gemm f16 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -195,7 +195,6 @@ bool run_fused_gemm_s8_rf_res() {
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
@ -204,9 +203,6 @@ int main() {
};
return testRun(75, funcs, "gemm int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -43,7 +43,6 @@
#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
@ -197,18 +196,13 @@ bool run_fused_gemm_s8_shmem() {
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_s8,
&run_fused_gemm_s8_shmem
};
return testRun(75, funcs, "gemm int8 shmem staing");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
SmemAccumulator,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1
);
if(passed)
std::cout << "Pass\n";
else
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
return passed;
}
bool run_fused_gemm_s8_sm80_rf_res_batch() {
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = false;
using B2bGemm = cutlass::gemm::device::B2bGemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
SmemAccumulator,
16,
16,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
int batch_count = 2;
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_Scale0 = 0;
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1,
cutlass::gemm::GemmUniversalMode::kBatched,
batch_count,
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0
);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_s8_sm80,
&run_fused_gemm_s8_sm80_rf_res
&run_fused_gemm_s8_sm80_rf_res,
&run_fused_gemm_s8_sm80_rf_res_batch
};
return testRun(80, funcs, "gemm int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = true;
using B2bGemm = cutlass::gemm::device::B2bGemm<
@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
SmemAccumulator,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate
>;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -40,19 +40,66 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
#include "threadblock/grouped_threadblock_swizzle.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
namespace detail {
/// Utility struct for returning the type of the problem visitor used by the swizzling function,
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
/// the parameters of the problem visitor used in GroupedParams.
template <
typename B2bMma_,
typename ThreadblockSwizzle_,
typename Enable = void
>
struct ProblemVisitorOrDefault;
/// Return a generic problem visitor for GEMM problems
template <
typename B2bMma_,
typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
ThreadblockSwizzle_,
typename platform::enable_if<
! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
>::type> {
using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
GroupScheduleMode::kDeviceOnly,
128,
128,
platform::is_same<typename B2bMma_::LayoutC,
cutlass::layout::ColumnMajor>::value>;
};
/// Return the problem visitor specified by the swizzling function
template <
typename B2bMma_,
typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
ThreadblockSwizzle_,
typename platform::enable_if<
cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
>::type> {
using value = typename ThreadblockSwizzle_::ProblemVisitor;
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct B2bGemm {
@ -61,14 +108,184 @@ struct B2bGemm {
using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA0 = typename B2bMma::IteratorA0::Element;
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
using ElementB0 = typename B2bMma::IteratorB0::Element;
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
using ElementB1 = typename B2bMma::IteratorB1::Element;
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Data types needed for higher-level containers. In some cases, a single type must be exposed
/// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
/// the second GEMM (other than for ElementA/ElementB)
using ElementA = typename B2bMma::IteratorA0::Element;
using LayoutA = typename B2bMma::IteratorA0::Layout;
using ElementB = typename B2bMma::IteratorB0::Element;
using LayoutB = typename B2bMma::IteratorB0::Layout;
static ComplexTransform const kTransformA = B2bMma::kTransformA;
static ComplexTransform const kTransformB = B2bMma::kTransformB;
using Operator = typename B2bMma::Operator0;
using OperatorClass = typename Operator::OperatorClass;
using ThreadblockShape = typename B2bMma::Shape0;
using WarpShape = typename Operator::Shape;
using InstructionShape = typename Operator::InstructionShape;
using ArchTag = typename B2bMma::ArchTag;
static int const kStages = B2bMma::kStages;
static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
using Mma = B2bMma;
using EpilogueOutputOp = OutputOp1;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
static int const kThreadCount = 32 * WarpCount0::kCount;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int batch_count;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
typename B2bMma::IteratorA0::TensorRef ref_A0_,
typename B2bMma::IteratorB0::TensorRef ref_B0_,
typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
typename B2bMma::IteratorB1::TensorRef ref_B1_,
typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
int64_t batch_stride_A0_,
int64_t batch_stride_B0_,
int64_t batch_stride_B1_,
int64_t batch_stride_C1_,
int64_t batch_stride_D1_,
int64_t batch_stride_Bias0_,
int64_t batch_stride_Scale0_,
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
int batch_count_ = 1
):
mode(mode_),
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
batch_stride_A0(batch_stride_A0_),
batch_stride_B0(batch_stride_B0_),
batch_stride_B1(batch_stride_B1_),
batch_stride_C1(batch_stride_C1_),
batch_stride_D1(batch_stride_D1_),
batch_stride_Bias0(batch_stride_Bias0_),
batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
batch_count(batch_count_) {
}
};
// Arguments structure for grouped B2B problems
struct GroupedArguments {
GemmCoord* problem_size_0;
GemmCoord* problem_size_1;
typename B2bMma::IteratorA0::TensorRef* ref_A0;
typename B2bMma::IteratorB0::TensorRef* ref_B0;
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
typename B2bMma::IteratorB1::TensorRef* ref_B1;
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problmes in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int problem_count;
int threadblock_count;
GemmCoord* host_problem_sizes;
CUTLASS_HOST_DEVICE
GroupedArguments(
int problem_count,
GemmCoord* problem_size_0_,
GemmCoord* problem_size_1_,
typename B2bMma::IteratorA0::TensorRef* ref_A0_,
typename B2bMma::IteratorB0::TensorRef* ref_B0_,
typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
typename B2bMma::IteratorB1::TensorRef* ref_B1_,
typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
int threadblock_count = 0
) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
problem_count(problem_count),
threadblock_count(threadblock_count)
{}
};
/// Parameters structure
struct Params {
cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
@ -89,6 +306,13 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
@ -100,11 +324,12 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
@ -116,14 +341,22 @@ struct B2bGemm {
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
int64_t batch_stride_A0,
int64_t batch_stride_B0,
int64_t batch_stride_B1,
int64_t batch_stride_C1,
int64_t batch_stride_D1,
int64_t batch_stride_Bias0,
int64_t batch_stride_Scale0,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr
):
mode(mode),
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
params_A0(ref_A0.layout()),
ref_A0(ref_A0),
params_B0(ref_B0.layout()),
@ -138,6 +371,13 @@ struct B2bGemm {
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
batch_stride_A0(batch_stride_A0),
batch_stride_B0(batch_stride_B0),
batch_stride_B1(batch_stride_B1),
batch_stride_C1(batch_stride_C1),
batch_stride_D1(batch_stride_D1),
batch_stride_Bias0(batch_stride_Bias0),
batch_stride_Scale0(batch_stride_Scale0),
output_op_0(output_op_0),
output_op_1(output_op_1) {
@ -152,6 +392,81 @@ struct B2bGemm {
}
};
struct GroupedParams {
cutlass::gemm::GemmCoord* problem_size_0;
cutlass::gemm::GemmCoord* problem_size_1;
cutlass::gemm::GemmCoord* grid_tiled_shape;
typename B2bMma::IteratorA0::TensorRef* ref_A0;
typename B2bMma::IteratorB0::TensorRef* ref_B0;
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
typename B2bMma::IteratorB1::TensorRef* ref_B1;
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problmes in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int* workspace;
CUTLASS_HOST_DEVICE
GroupedParams() {}
CUTLASS_HOST_DEVICE
GroupedParams(
GroupedArguments const &args,
void *workspace = nullptr,
int tile_count = 0
) :
problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
output_op_0(args.epilogue0), output_op_1(args.epilogue1),
problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
workspace(reinterpret_cast<int*>(workspace)) {}
CUTLASS_HOST_DEVICE
void transpose() {
// Only row-major outputs are currently supported, so no transpose is performed
}
/// Returns non-grouped paramaters to be used as input to the kernel-level
/// operator for the problem indicated by problem_visitor.
CUTLASS_HOST_DEVICE
Params to_single_params(const ProblemVisitor& problem_visitor) const {
GemmCoord problem_size0 = problem_visitor.problem_size0();
GemmCoord problem_size1 = problem_visitor.problem_size1();
int32_t idx = problem_visitor.problem_index();
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
return Params(
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size0,
problem_size1,
grid_shape,
ref_A0[idx],
ref_B0[idx],
ref_C0[idx],
ref_Scale0[idx],
ref_Bias0[idx],
ref_B1[idx],
ref_C1[idx],
ref_D1[idx],
0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
output_op_0,
output_op_1,
workspace
);
}
};
/// Shared memory storage structure
union SharedStorage {
typename B2bMma::B2bMmaSharedStorage main_loop;
@ -163,7 +478,7 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
B2bGemm() { }
B2bGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
@ -223,7 +538,7 @@ struct B2bGemm {
if(problem_size_0.n() > B2bMma::Shape0::kN)
return Status::kErrorInvalidProblem;
if(problem_size_1.n() > B2bMma::Shape1::kN)
return Status::kErrorInvalidProblem;
@ -233,9 +548,13 @@ struct B2bGemm {
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
run_with_swizzle(params, shared_storage, threadblock_swizzle);
}
/// Executes one GEMM with an externally-provided swizzling function
CUTLASS_DEVICE
void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
@ -247,37 +566,64 @@ struct B2bGemm {
return;
}
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
int offset_k_0 = 0;
int offset_k_1 = 0;
int problem_size_k_0 = params.problem_size_0.k();
int problem_size_k_1 = params.problem_size_1.k();
if (params.mode == GemmUniversalMode::kGemm) {
// Problem size is a function of threadblock index in the K dimension
problem_size_k_0 = min(
problem_size_k_0,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Problem size is a function of threadblock index in the K dimension
problem_size_k_1 = min(
problem_size_k_1,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
};
cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN
};
cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size_1,
offset_k_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_0 = min(
params.problem_size_0.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_1 = min(
params.problem_size_1.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
// Compute threadblock-scoped matrix multiply-add
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock
@ -286,34 +632,33 @@ struct B2bGemm {
// Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0(
params.params_A0,
params.ref_A0.data(),
ptr_A0,
{params.problem_size_0.m(), problem_size_k_0},
thread_idx,
tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0(
params.params_B0,
params.ref_B0.data(),
ptr_B0,
{problem_size_k_0, params.problem_size_0.n()},
thread_idx,
tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1(
params.params_B1,
params.ref_B1.data(),
ptr_B1,
{problem_size_k_1, params.problem_size_1.n()},
thread_idx,
tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(),
ptr_Scale0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -323,7 +668,7 @@ struct B2bGemm {
);
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(),
ptr_Bias0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -332,14 +677,17 @@ struct B2bGemm {
)
);
//
// Main loop
//
OutputOp0 output_op_0(params.output_op_0);
if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
}
// Construct thread-scoped matrix multiply
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
@ -349,11 +697,9 @@ struct B2bGemm {
src_accum.clear();
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
}
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
//
// Epilogue
@ -376,23 +722,32 @@ struct B2bGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
params.ref_C1.data(),
ptr_C1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
@ -401,21 +756,21 @@ struct B2bGemm {
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1,
params.ref_D1.data(),
ptr_D1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D1;
@ -427,14 +782,14 @@ struct B2bGemm {
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@ -457,4 +812,3 @@ struct B2bGemm {
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,157 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Scheduler for grouped B2b GEMMs
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount,
bool Transposed = false>
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using BaseParams = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
static bool const kTransposed = Transposed;
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
struct Params {
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
int32_t problem_count;
void const *workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
problem_count(0), workspace(nullptr), tile_count(0) { }
/// Ctor
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const *problem_sizes0,
cutlass::gemm::GemmCoord const *problem_sizes1,
int32_t problem_count,
void const *workspace = nullptr,
int32_t tile_count = 0
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
workspace(workspace),
tile_count(tile_count)
{}
/// Convert the B2b-GEMM-specific parameters to those used by the base class
CUTLASS_HOST_DEVICE
BaseParams to_base() const {
return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
// shape of the grid used in the non-grouped B2b GEMM
problem_sizes0,
problem_count,
workspace,
tile_count);
}
};
//
// Methods
//
CUTLASS_DEVICE
B2bGemmGroupedProblemVisitor(
Params const &params_,
SharedStorage &shared_storage_,
int32_t block_idx
): Base (
params_.to_base(),
shared_storage_, block_idx),
problem_sizes0(params_.problem_sizes0),
problem_sizes1(params_.problem_sizes1)
{}
/// Returns the problem size 0 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size0() const {
GemmCoord problem = problem_sizes0[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
/// Returns the problem size 1 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size1() const {
GemmCoord problem = problem_sizes1[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -63,7 +63,9 @@
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "kernel/b2b_gemm.h"
#include "kernel/grouped.h"
#include "threadblock/default_b2b_mma.h"
#include "threadblock/grouped_threadblock_swizzle.h"
////////////////////////////////////////////////////////////////////////////////
@ -73,6 +75,9 @@ namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <typename T>
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
template <
/// Element type for A matrix operand
typename ElementA_,
@ -114,12 +119,12 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Stage accumulator in shared memory
bool SmemAccumulator = false
bool SmemAccumulator = false,
/// Whether or not the operation is grouped
typename Enable = void
>
struct DefaultB2bGemm;
@ -161,17 +166,77 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
/// Partial specialization for Ampere Architecture with grouped operation
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of A matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
@ -188,7 +253,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
};
@ -228,8 +295,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -249,8 +314,9 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator
Operator,
false,
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
> {
/// Define the threadblock-scoped matrix multiply-accumulate
@ -274,7 +340,7 @@ struct DefaultB2bGemm<
Operator,
EpilogueOutputOp0
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -287,7 +353,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -323,20 +389,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator> {
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -360,7 +423,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -396,19 +459,17 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
ThreadblockSwizzle, 2, Operator, false,
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -418,7 +479,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -430,7 +491,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -112,22 +112,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, true> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
@ -179,8 +175,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
true
> {
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
false,
true
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -277,20 +270,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator, true> {
Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -350,19 +340,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
ThreadblockSwizzle, 2, Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief High-level interface for running a grouped version of a CUTLASS kernel
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// High-level interface for running a grouped version of a CUTLASS kernel
template <
typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate
>
struct GroupedKernel {
public:
using BaseKernel = BaseKernel_;
using Epilogue = typename BaseKernel::Epilogue;
/// Types that need to be exported to work properly with device::BaseGrouped
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
using Mma = typename BaseKernel::Mma;
using Arguments = typename BaseKernel::GroupedArguments;
using Params = typename BaseKernel::GroupedParams;
using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;
static int const kThreadCount = BaseKernel::kThreadCount;
/// Shared memory storage structure
struct SharedStorage {
typename BaseKernel::SharedStorage kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GroupedKernel() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
/// Executes a kernel-level GEMM in a loop
CUTLASS_DEVICE
void operator()(Params &params, SharedStorage &shared_storage) {
ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
if (ProblemVisitor::kTransposed) {
params.transpose();
}
BaseKernel mma;
// Outer 'persistent' loop to iterate over tiles
while (swizzle.problem_visitor.next_tile()) {
typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);
// Next tile
swizzle.problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
MatrixCoord output_coord(
@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
}
}
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
typename ScalarType, ///< alpha Type
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
int kMblock = 4,
int kNblock = 4
>
__global__ void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRefIn tensor_in, ///< input tensor
TensorRefOut tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
ConvertOp convert_op;
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
int batch_idx = blockIdx.z;
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kNblock; j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kMblock; i++) {
int row = row_block + i;
int col = col_block + j;
MatrixCoord coord = MatrixCoord(row, col);
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, coord.column()});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
scale * ScalarType(tensor_in.at(coord)) + bias);
}
}
}
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
}
}
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
}
}
}
}
@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
);
}
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type
typename ElementOut, ///< Output Type
typename Layout, ///< Layout of input/output tensor
typename ScalarType, ///< alpha Type
typename LayoutScaleBias, ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
int const kMblock = 4;
int const kNblock = 4;
dim3 block(16, 8);
dim3 grid(
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
batch_count % std::numeric_limits<uint16_t>::max()
);
kernel::TensorScaleBiasGemmBatched<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
ConvertOp,
kMblock,
kNblock
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias,
batch_count,
batch_stride_tensor_in,
batch_stride_tensor_out,
batch_stride_tensor_scale,
batch_stride_tensor_bias
);
}
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -119,8 +119,10 @@ public:
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
using IteratorA = IteratorA0;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
using IteratorB = IteratorB0;
///< Policy describing tuning details
using Policy0 = Policy0_;
@ -139,6 +141,10 @@ public:
using IteratorB1 = IteratorB1_;
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Export Policy0 as the threadblock-level Mma's policy
using Policy = Policy0;
using Shape = Shape0;
using SmemIteratorB1 = SmemIteratorB1_;
@ -188,6 +194,10 @@ public:
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// Internal structure exposed for introspection.
struct Detail {
@ -641,6 +651,11 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
// 2nd Gemm
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@ -871,7 +886,10 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -121,8 +121,10 @@ public:
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
using IteratorA = IteratorA0;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
using IteratorB = IteratorB0;
///< Iterates over tiles of the scale and bias vectors in global memory
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
///< Policy describing tuning details
@ -141,6 +143,10 @@ public:
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Export Policy0 as the threadblock-level Mma's policy
using Policy = Policy0;
using Shape = Shape0;
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
@ -194,6 +200,10 @@ public:
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// Internal structure exposed for introspection.
struct Detail {
@ -664,6 +674,11 @@ public:
}
// Insert fence and wait for all outstanding cp.async operations to commit.
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
/// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0;
@ -855,7 +870,10 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -126,7 +126,9 @@ public:
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA0;
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB0;
using Policy0 = Policy0_; ///< Policy describing tuning details
using SmemIteratorA0 = SmemIteratorA0_;
@ -139,6 +141,8 @@ public:
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy describing tuning details
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
using Shape = Shape1;
using SmemIteratorB1 = SmemIteratorB1_;
@ -195,6 +199,10 @@ public:
/// Complex transform on B1 operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

Some files were not shown because too many files have changed in this diff Show More