Compare commits
596 Commits
| SHA1 |
|---|
| e51efbfe18 | |||
| fd6cfe1ed0 | |||
| 9baa06dd57 | |||
| ebe98c549a | |||
| 9892624b66 | |||
| a1aaf2300a | |||
| b995f93317 | |||
| 889ff20648 | |||
| dc4817921e | |||
| 5c6bca0441 | |||
| c2ad7c5b20 | |||
| cc23f6d1e9 | |||
| 5a287538c2 | |||
| 8bdbfca682 | |||
| 2e2af190bd | |||
| f12b1d75c9 | |||
| b244379d9b | |||
| 9d165a3b8e | |||
| b9b110a9ea | |||
| 8206e7a0f5 | |||
| 6316b6f867 | |||
| 9354bfd7c1 | |||
| 5e9b8e2a25 | |||
| 1ec230c4bf | |||
| f89cd95b16 | |||
| f115c3f854 | |||
| ad7b2f5e84 | |||
| 40f124ef27 | |||
| 89f6bf2739 | |||
| f535c33634 | |||
| e3cb8a773a | |||
| c4bdfe821c | |||
| b3ce7e12b7 | |||
| fe75ead92e | |||
| 35136f5564 | |||
| e5b810bed1 | |||
| 2b78c2fe31 | |||
| 697126019e | |||
| e94e888df3 | |||
| be73ad20a5 | |||
| f02a7c2976 | |||
| 331a1f5b3f | |||
| 8e345c5c5b | |||
| 81a43e6d92 | |||
| ade6376fa0 | |||
| bb4dd682dd | |||
| 5e497243f7 | |||
| b3f3c7758c | |||
| 9e1b649827 | |||
| 5120b21cc3 | |||
| dd76dec4ef | |||
| 19cc2a5feb | |||
| 09df6ac464 | |||
| df8a550d39 | |||
| 79fc51f4b8 | |||
| 6f4921858b | |||
| 62750a2b75 | |||
| 8c4d1dc47d | |||
| 3fe62887d8 | |||
| bd03b22f64 | |||
| 6c6b78550e | |||
| 06e560d98a | |||
| df18f5e4f5 | |||
| ca4fdbea70 | |||
| eefa171318 | |||
| afa1772203 | |||
| 9b3772dfa6 | |||
| b84e9802d8 | |||
| e9627ce55b | |||
| ad6e1ec19c | |||
| 0642d46dd4 | |||
| 833f6990e0 | |||
| affd1b693d | |||
| 6f55278121 | |||
| 3c28697b9f | |||
| bdd641790a | |||
| cc19d4d22b | |||
| 47daa33c61 | |||
| 389e493055 | |||
| 9eb01fa0b0 | |||
| b78588d163 | |||
| 902dff3663 | |||
| ef5620dd1d | |||
| 375e284e6a | |||
| 52b35e90ce | |||
| 24f991e879 | |||
| 51b25e7b58 | |||
| 7de6a59784 | |||
| c506e16788 | |||
| 7494a180a4 | |||
| cffd5d32b7 | |||
| bf9da7b76c | |||
| 3d261a5974 | |||
| e1cd8c7866 | |||
| 33c584364e | |||
| 2b6cfd34d1 | |||
| 4c42f73fda | |||
| 80243e0b8c | |||
| b0e09d7cd3 | |||
| 8aa95dbb88 | |||
| d656afbd2a | |||
| 32e3c38aef | |||
| 9004ed2d1b | |||
| 19f51596e8 | |||
| e8a8b69365 | |||
| 08a49953a0 | |||
| a424ca6cf9 | |||
| be692b48b0 | |||
| 12626bcfe4 | |||
| f02913c34e | |||
| 03e3bffaec | |||
| e5f3caf145 | |||
| 83ae20c740 | |||
| b0c09ed077 | |||
| ea69cc2849 | |||
| f3a3bfcbf2 | |||
| d65266a868 | |||
| 5b50a8faaf | |||
| 08101d9d0c | |||
| 755194a7bd | |||
| 53668799b2 | |||
| cc3c29a81a | |||
| 0837a2a00a | |||
| 477a677317 | |||
| b27c49e84a | |||
| e2b0789927 | |||
| 44dae8b90e | |||
| 2991ce18d3 | |||
| 1ebda1ccef | |||
| 9f68995de5 | |||
| 3a8c01a18b | |||
| dbdae514e0 | |||
| 21d0534167 | |||
| 323c8170bf | |||
| 82f5075946 | |||
| 06e337758d | |||
| 7369adcaca | |||
| 6c3044136b | |||
| e1976daacc | |||
| f7b19de32c | |||
| 4dbf5dbed2 | |||
| f93a69134e | |||
| 3f084f7f3c | |||
| b0296bf682 | |||
| 865be73a97 | |||
| 8d8cfdf375 | |||
| fb170439e8 | |||
| 4e5a8f6853 | |||
| 7192f4ab23 | |||
| 2049c6c5a2 | |||
| e22ba590cd | |||
| 19b4c5e065 | |||
| 06b21349bc | |||
| eee0cab26c | |||
| 36cbfcf483 | |||
| 1f2b590da6 | |||
| 8b2a0408bd | |||
| fbd116c0e5 | |||
| 5b283c872c | |||
| be60a0b272 | |||
| 56b46e2d13 | |||
| 52fb43f30f | |||
| 843adf0408 | |||
| e48c7618e4 | |||
| c5239d8312 | |||
| d6580c3dc0 | |||
| 81b06ee0e0 | |||
| dbfced05e7 | |||
| 2448bb56e6 | |||
| 637b159063 | |||
| 033d9efd2d | |||
| acc3ee18a1 | |||
| 5c447dd84f | |||
| 7d49e6c7e2 | |||
| a40e08e9d5 | |||
| 8e7d9f483d | |||
| 19f3cc33f1 | |||
| f9ece1b42c | |||
| 28cbacbf64 | |||
| 8f7d2789b8 | |||
| c4e3e122e2 | |||
| 629f4653c3 | |||
| ffa34e7075 | |||
| a8f2c80db0 | |||
| bbe579a9e3 | |||
| 47a3ebbea9 | |||
| 57e01e1a6b | |||
| 6e3df975a2 | |||
| 8825fbf1ef | |||
| 092f14db05 | |||
| 9385141f19 | |||
| b4b5b11070 | |||
| 139b93db61 | |||
| ca37d632c9 | |||
| 362abbf274 | |||
| 751eb9a885 | |||
| 2f589ffa76 | |||
| acba5beee5 | |||
| 74d1f3e63a | |||
| 8ac2edc810 | |||
| d4be5ab5d7 | |||
| c9591a694d | |||
| 5c756eb774 | |||
| 8236f30675 | |||
| b7508e3379 | |||
| f60786b536 | |||
| 30ec1a4649 | |||
| e1483d5fa0 | |||
| f4a0216601 | |||
| f188f9b709 | |||
| 9c9b51d35c | |||
| a75b4ac483 | |||
| e9e30c2304 | |||
| 4a1709e17e | |||
| bef1fbcbe6 | |||
| 2375a07d01 | |||
| 60c8251b72 | |||
| 10b850f9c7 | |||
| 99c4eebe3b | |||
| a759e85f5f | |||
| 56fc3df03b | |||
| eb01d5449d | |||
| 8098336d51 | |||
| b5d8a5d9cc | |||
| 6e60b9b17c | |||
| 1ab6cc7b68 | |||
| 5ae8133cfa | |||
| 39c6a83f23 | |||
| 1d7f2a207e | |||
| 557be3ab0e | |||
| c008b4aea8 | |||
| 922fb5108b | |||
| 7a7796afae | |||
| fb10fa5308 | |||
| 5e1a0a5adb | |||
| 757275f279 | |||
| fa8dfe631f | |||
| 112590114d | |||
| ff02da2667 | |||
| 4082fed85a | |||
| 5f13dcad78 | |||
| 61a38f83dc | |||
| ff61a49dd1 | |||
| 26986bbc60 | |||
| 7d8317a63e | |||
| 5cd735c48e | |||
| 67ae8e0603 | |||
| 14f69bddc8 | |||
| 90d3b0fb18 | |||
| e0aaa3c3b3 | |||
| 8783c41851 | |||
| 6407bcdf0a | |||
| a77b2c9cb8 | |||
| 34bbadd3ff | |||
| 88c0d7c726 | |||
| e01b9b5029 | |||
| 34fd98056b | |||
| 3a8f57a3c8 | |||
| 6673df0e48 | |||
| 7618e9bfd8 | |||
| a88c41cf8d | |||
| 27de343535 | |||
| 2a9fa23e06 | |||
| 2e56cfabee | |||
| 3930f709ce | |||
| 7e5ee8b7bf | |||
| 2d9a557427 | |||
| 4575443d44 | |||
| a0d787b746 | |||
| d20f3a9542 | |||
| 8e85580859 | |||
| 146d314057 | |||
| f679663224 | |||
| e066ced33b | |||
| 9b923dd4c4 | |||
| f6d42f2dd0 | |||
| 473a67073e | |||
| 87349d3496 | |||
| fde824af21 | |||
| 7dbf423763 | |||
| 6f47420213 | |||
| 4638250469 | |||
| 7859fe322a | |||
| d3e72719b4 | |||
| b4ab501767 | |||
| f079619f5e | |||
| 13f413493a | |||
| 6fbc0d3380 | |||
| b97404837e | |||
| e2953d47c5 | |||
| 19c4a4815e | |||
| fcfbd23e26 | |||
| b250faccd3 | |||
| 24c8b7d8a2 | |||
| 7c04f95415 | |||
| 6f8596ce3f | |||
| fe2f491dd7 | |||
| df02482f1d | |||
| 180c5629bf | |||
| e36912f961 | |||
| 9a83bd3381 | |||
| 54bebe417d | |||
| 43cfbe0086 | |||
| 4a68cf748e | |||
| d572cc1aab | |||
| 9b8166e3f0 | |||
| e2d439ee7e | |||
| 0435979f59 | |||
| 2ba1ef10be | |||
| 0964bdb64c | |||
| ecbd24566c | |||
| 660a05f581 | |||
| bc36122c3f | |||
| 15d9d31f1f | |||
| 1eef5c3cf1 | |||
| 87070b6d51 | |||
| 77549ae6c8 | |||
| 42290f5d1c | |||
| 209faf7b94 | |||
| 6116706c96 | |||
| 2670b973dd | |||
| af332d4aa9 | |||
| 86cae03cea | |||
| 29801e348a | |||
| 7e370c9637 | |||
| c4f6b8c6bc | |||
| a68e2f95f0 | |||
| a31b43b3f3 | |||
| f396cdd15c | |||
| 92ebbf1dc4 | |||
| 65688c2a87 | |||
| f303889ed9 | |||
| 9cdbe33570 | |||
| 95f673ecf7 | |||
| 91b8de8d32 | |||
| d8359c804b | |||
| 34bed24af3 | |||
| a101ac283f | |||
| 9fb38ac048 | |||
| 8f5c242426 | |||
| 3c995c7606 | |||
| ce8597dc14 | |||
| 2e10404d26 | |||
| 5ff5209ed5 | |||
| 5921043981 | |||
| add4ba622f | |||
| 277bd6e537 | |||
| 66d9cddc83 | |||
| d49bef88f9 | |||
| 8b42e751c6 | |||
| eb7f99d3dd | |||
| 764b840d6f | |||
| a1046d49c1 | |||
| 1cd994b4cf | |||
| 7bdba07310 | |||
| c54ede3a9e | |||
| ff6e733fe1 | |||
| 5989b7e1d7 | |||
| 1e64f153b3 | |||
| 78b30d3191 | |||
| 59de82688b | |||
| b85865d1ad | |||
| 3f2bb17722 | |||
| 38193d76e3 | |||
| 1d7772f218 | |||
| df81d847d7 | |||
| d6117ca362 | |||
| 9c0518608e | |||
| 9f1f37aa21 | |||
| 84213b0b8e | |||
| 8567b87d65 | |||
| c975e2ccbb | |||
| 3c90f6aea6 | |||
| 06eb90cc0d | |||
| 168ea8b0e1 | |||
| 012c62c748 | |||
| cc85b64cf6 | |||
| 1b4e24470a | |||
| 8c1bf9b784 | |||
| 7d0dd6706e | |||
| 9b47403b2d | |||
| 4db6a6140e | |||
| 3bf95e90c2 | |||
| 75fed7493e | |||
| 98b73fc95d | |||
| 4990e3686d | |||
| 4b7365388c | |||
| 0d8405588d | |||
| cb539dab78 | |||
| dadc881a96 | |||
| f3eea3a4d7 | |||
| cd37e82492 | |||
| 48a9ea223a | |||
| 7a458f00a6 | |||
| 97bff52e8c | |||
| 9f2e3faa69 | |||
| a821280dc7 | |||
| f73374a1eb | |||
| faab7536fc | |||
| fc9ebc645b | |||
| 2cc2c7ba1f | |||
| 50ceed7154 | |||
| e773429f7e | |||
| beae168f90 | |||
| f29d8f7ca9 | |||
| b1d3f9b2fd | |||
| b72cbf957d | |||
| ca23ff7924 | |||
| 1c3d400b14 | |||
| abafbf2afd | |||
| 536b20763e | |||
| 497b499d9d | |||
| e66bfcb1f8 | |||
| 1617685a77 | |||
| 25ebf15d02 | |||
| 5d05808072 | |||
| 0b8cacd6f1 | |||
| e7a61c761a | |||
| fb379eaa5b | |||
| 8a766804ad | |||
| 1eb6355182 | |||
| 04a9777b87 | |||
| e45e773436 | |||
| dae6b6893b | |||
| ba18ea9c32 | |||
| 9ab9110168 | |||
| e5d4669f16 | |||
| 94f01f19d5 | |||
| fa56763c25 | |||
| 25e26a6e51 | |||
| f248e9bdb4 | |||
| dceefe4f64 | |||
| c3881d097e | |||
| a29dfb1c63 | |||
| 0abaac84ea | |||
| 858c735856 | |||
| d6f58b2d14 | |||
| c4cf0dad82 | |||
| 57551902d0 | |||
| 1604ebaf10 | |||
| 6023038bae | |||
| ddd8f9cf41 | |||
| ec2b4fd85d | |||
| 86ce09aed1 | |||
| 21c1fa3849 | |||
| 8c339ac039 | |||
| e49f690fd7 | |||
| 96dad61a75 | |||
| cc2ea4c3fc | |||
| a0de301283 | |||
| 319a389f42 | |||
| 71def2f084 | |||
| 70f3ba57f5 | |||
| dd77fadc70 | |||
| be4578d517 | |||
| d7b499deff | |||
| 310ed81ac3 | |||
| 4c0d6e1eb4 | |||
| 167ac54c65 | |||
| 12f4108ac2 | |||
| dd571f0edb | |||
| 6d0d265047 | |||
| f11fa975a5 | |||
| 0e71d9b450 | |||
| eb0d4c9213 | |||
| bc45e2c023 | |||
| 095cbba57c | |||
| 8f1fe7a132 | |||
| 3ab1eacf09 | |||
| cd39c75e25 | |||
| b2e1e97cb1 | |||
| 96a11a1ef3 | |||
| e96f00586c | |||
| 3cfa5db2a2 | |||
| 1db6971a8d | |||
| b954127297 | |||
| d0d941efc7 | |||
| 8a951b2940 | |||
| 1e4703cbab | |||
| c3353add63 | |||
| ac8825b941 | |||
| 8fd94806e5 | |||
| d7c9cbf0b9 | |||
| c2ee13a0fe | |||
| f78994bb40 | |||
| dceabd4c5a | |||
| 86fa1dc30b | |||
| 288af365db | |||
| 0dc3ba60b3 | |||
| ec4f7e5194 | |||
| 4e666e1dfd | |||
| 3799e12f25 | |||
| fc3bc85db8 | |||
| 49c0a58d50 | |||
| 5fe09c2d67 | |||
| 6b69c79ac3 | |||
| 62e438f450 | |||
| 808c25337a | |||
| 6fc5008803 | |||
| a3bcc6981d | |||
| 3b28642801 | |||
| 538592dea4 | |||
| 2e07c4cc2f | |||
| 9ac255863f | |||
| 59e2aa505a | |||
| 4e8af93da1 | |||
| 6c2f8f2fb8 | |||
| 598e35401c | |||
| a01feb93d9 | |||
| d36f331b44 | |||
| 69abafb85a | |||
| 68a078fbbf | |||
| 10709dbb64 | |||
| 1227351079 | |||
| a77c658439 | |||
| 4516b833ce | |||
| 64dd1e1915 | |||
| 1ac4559d12 | |||
| e5d51840e8 | |||
| 6c29fe20ba | |||
| e3c56b0d6b | |||
| 4647c57243 | |||
| 856d4db3fb | |||
| 6a1064093f | |||
| c5f1ef4dff | |||
| 47ebfccbec | |||
| ad9486684f | |||
| 1d8372a8e2 | |||
| 9cb7d63424 | |||
| da2f110906 | |||
| b68113f5be | |||
| a68d7cd6f1 | |||
| 38e8b29f56 | |||
| ee7349c94f | |||
| 8cdd4293d4 | |||
| f58b843951 | |||
| 5fc142296f | |||
| 233d69aa6d | |||
| 9840d25269 | |||
| b878c96421 | |||
| 8f8a80cad5 | |||
| a8f6f8eb07 | |||
| f4b0a33633 | |||
| 7c783adf53 | |||
| 4000df9567 | |||
| bb35a3ba6f | |||
| 7ec3a87f22 | |||
| 0b74c8f473 | |||
| 83036ed646 | |||
| b7e43f5eb9 | |||
| 5c62d892fa | |||
| 41a31b404b | |||
| 7320aee17d | |||
| 2142a05d9d | |||
| c77a524459 | |||
| fac6680f31 | |||
| 08993707da | |||
| c805593ebe | |||
| 26556d7206 | |||
| 4839b6cb61 | |||
| d97214987a | |||
| b0bbc6d548 | |||
| 7074047a54 | |||
| 75a4737cfe | |||
| 8a3e4b8d02 | |||
| 6a6b4028bd | |||
| 92393b2676 | |||
| 50bf00e5f2 | |||
| 4cd004ead1 | |||
| 6c4539e372 | |||
| a3639ab1a0 | |||
| 169181f30f | |||
| 0f1056390d | |||
| 34a42e5620 | |||
| 8f09b82b12 | |||
| 200a5a5146 | |||
| 746b7b3247 | |||
| abdf16a4d9 | |||
| 0e13748649 | |||
| ccb697bac7 | |||
| e6bcdc60cf | |||
| 6615010cd0 | |||
| c2b80ad4e4 | |||
| 37a8f9e598 | |||
| c53f3339bb | |||
| 4dac7490e6 | |||
| fd7e058d0c | |||
| 1ab1027954 | |||
| 86931fef85 | |||
| e33d90b361 | |||
| 96dab34ad9 | |||
| 7c0cd26d13 | |||
| 45ecbc885b | |||
| 8aca98f9a7 | |||
| f4d9c8f755 | |||
| fb335f6a5f |
.github/ISSUE_TEMPLATE/bug_report.md (vendored, new file)
@@ -0,0 +1,23 @@
---
name: Bug report
about: Create a bug report to help us improve CUTLASS
title: "[BUG]"
labels: "? - Needs Triage, bug"
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Environment details (please complete the following information):**
 - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

**Additional context**
Add any other context about the problem here.
.github/ISSUE_TEMPLATE/config.yml (vendored, new file)
@@ -0,0 +1,5 @@
blank_issues_enabled: true
contact_links:
  - name: CUTLASS Discord
    url: https://discord.gg/nvidiadeveloper
    about: Come chat about using and contributing to CUTLASS!
.github/ISSUE_TEMPLATE/documentation_request.md (vendored, new file)
@@ -0,0 +1,35 @@
---
name: Documentation request
about: Report incorrect or needed documentation to improve CUTLASS
title: "[DOC]"
labels: "? - Needs Triage, documentation"
assignees: ''

---

## Report incorrect documentation

**Location of incorrect documentation**
Provide links and line numbers if applicable.

**Describe the problems or issues found in the documentation**
A clear and concise description of what you found to be incorrect.

**Steps taken to verify documentation is incorrect**
List any steps you have taken:

**Suggested fix for documentation**
Detail proposed changes to fix the documentation if you have any.

---

## Report needed documentation

**Report needed documentation**
A clear and concise description of what documentation you believe it is needed and why.

**Describe the documentation you'd like**
A clear and concise description of what you want to happen.

**Steps taken to search for needed documentation**
List any steps you have taken:
.github/ISSUE_TEMPLATE/feature_request.md (vendored, new file)
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for CUTLASS
title: "[FEA]"
labels: "? - Needs Triage, feature request"
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.
.github/ISSUE_TEMPLATE/submit_question.md (vendored, new file)
@@ -0,0 +1,10 @@
---
name: Submit question
about: Ask a general question about CUTLASS
title: "[QST]"
labels: "? - Needs Triage, question"
assignees: ''

---

**What is your question?**
.github/workflows/blossom-ci.yml (vendored, new file)
@@ -0,0 +1,112 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

# A workflow to trigger ci on hybrid infra (github + self hosted runner)
name: Blossom-CI
on:
  issue_comment:
    types: [created]
  workflow_dispatch:
    inputs:
      platform:
        description: 'runs-on argument'
        required: false
      args:
        description: 'argument'
        required: false

jobs:
  Authorization:
    name: Authorization
    runs-on: blossom
    outputs:
      args: ${{ env.args }}

    # This job only runs for pull request comments
    if: |
      (startsWith(github.event.comment.body, '/bot run') ||
      startsWith(github.event.comment.body, '/bot kill')) && contains(
      fromJson('["zekunf-nv"]'),
      github.actor)
    steps:
      - name: Check if comment is issued by authorized person
        run: blossom-ci
        env:
          OPERATION: 'AUTH'
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}

  Vulnerability-scan:
    name: Vulnerability scan
    needs: [Authorization]
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        with:
          repository: ${{ fromJson(needs.Authorization.outputs.args).repo }}
          ref: ${{ fromJson(needs.Authorization.outputs.args).ref }}
          lfs: 'true'

      - name: Run blossom action
        uses: NVIDIA/blossom-action@main
        env:
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
        with:
          args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }}
          args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }}
          args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }}

  Job-trigger:
    name: Start ci job
    needs: [Vulnerability-scan]
    runs-on: blossom
    steps:
      - name: Start ci job
        run: blossom-ci
        env:
          OPERATION: 'START-CI-JOB'
          CI_SERVER: ${{ secrets.CI_SERVER }}
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}

  Upload-Log:
    name: Upload log
    runs-on: blossom
    if: github.event_name == 'workflow_dispatch'
    steps:
      - name: Jenkins log for pull request ${{ fromJson(github.event.inputs.args).pr }} (click here)
        run: blossom-ci
        env:
          OPERATION: 'POST-PROCESSING'
          CI_SERVER: ${{ secrets.CI_SERVER }}
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/labeler.yml (vendored, new file)
@@ -0,0 +1,11 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
  triage:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/labeler@main
      with:
        repo-token: "${{ secrets.GITHUB_TOKEN }}"
.github/workflows/new-issues-to-triage-projects.yml (vendored, new file)
@@ -0,0 +1,35 @@
name: Auto Assign New Issues to Triage Project

on:
  issues:
    types: [opened]

env:
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

jobs:
  assign_one_project:
    runs-on: ubuntu-latest
    name: Assign to New Issues to Triage Project
    steps:
      - name: Process bug issues
        uses: docker://takanabe/github-actions-automate-projects:v0.0.1
        if: contains(github.event.issue.labels.*.name, 'bug') && contains(github.event.issue.labels.*.name, '? - Needs Triage')
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
          GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
      - name: Process feature issues
        uses: docker://takanabe/github-actions-automate-projects:v0.0.1
        if: contains(github.event.issue.labels.*.name, 'feature request') && contains(github.event.issue.labels.*.name, '? - Needs Triage')
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
          GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
      - name: Process other issues
        uses: docker://takanabe/github-actions-automate-projects:v0.0.1
        if: contains(github.event.issue.labels.*.name, '? - Needs Triage') && (!contains(github.event.issue.labels.*.name, 'bug') && !contains(github.event.issue.labels.*.name, 'feature request'))
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
          GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
.github/workflows/stale.yml (vendored, new file)
@@ -0,0 +1,57 @@
name: Mark inactive issues and pull requests

on:
  schedule:
    - cron: "0 * * * *"

jobs:
  mark-inactive-30d:
    runs-on: ubuntu-latest
    steps:
      - name: Mark 30 day inactive issues and pull requests
        uses: actions/stale@v3
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: >
            This issue has been labeled `inactive-30d` due to no recent activity in the past 30 days.
            Please close this issue if no further response or action is needed.
            Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
            This issue will be labeled `inactive-90d` if there is no activity in the next 60 days.
          stale-issue-label: "inactive-30d"
          exempt-issue-labels: "0 - Blocked,0 - Backlog,good first issue"
          days-before-issue-stale: 30
          days-before-issue-close: -1
          stale-pr-message: >
            This PR has been labeled `inactive-30d` due to no recent activity in the past 30 days.
            Please close this PR if it is no longer required.
            Otherwise, please respond with a comment indicating any updates.
            This PR will be labeled `inactive-90d` if there is no activity in the next 60 days.
          stale-pr-label: "inactive-30d"
          exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
          days-before-pr-stale: 30
          days-before-pr-close: -1
          operations-per-run: 50
  mark-inactive-90d:
    runs-on: ubuntu-latest
    steps:
      - name: Mark 90 day inactive issues and pull requests
        uses: actions/stale@v3
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: >
            This issue has been labeled `inactive-90d` due to no recent activity in the past 90 days.
            Please close this issue if no further response or action is needed.
            Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
          stale-issue-label: "inactive-90d"
          exempt-issue-labels: "0 - Blocked,0 - Backlog,good first issue"
          days-before-issue-stale: 90
          days-before-issue-close: -1
          stale-pr-message: >
            This PR has been labeled `inactive-90d` due to no recent activity in the past 90 days.
            Please close this PR if it is no longer required.
            Otherwise, please respond with a comment indicating any updates.
          stale-pr-label: "inactive-90d"
          exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
          days-before-pr-stale: 90
          days-before-pr-close: -1
          operations-per-run: 50
.gitignore (vendored, new file)
@@ -0,0 +1,4 @@
# PyCache files
__pycache__/
cutlass_library.egg-info/
/build*
.gitmodules (vendored, 3 lines removed)
@@ -1,3 +0,0 @@
[submodule "tools/external/googletest"]
  path = tools/external/googletest
  url = https://github.com/google/googletest.git
CHANGELOG.md (657 changed lines)
@@ -1,4 +1,616 @@
# NVIDIA CUTLASS Changelog
# Changelog

# CUTLASS 4.x

## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)

### CuTe DSL
* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems!
* More examples demonstrating how to use CuTe DSL to write peak-performance kernels
  - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
  - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* API updates
  - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details

### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Add variable sequence length support for FMHA Backward kernel.
  - Add varlen test support to Backward runner.
  - Codes support empty batch sequences.
* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays.
* CuTe changes:
  - Rewrite ArithTuple and ScaledBasis for robustness and clarity.
  - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX.
  - Factor out `print_latex` and friends and rewrite.
  - Factor out `print_svg` and friends and rewrite.
* Support Blackwell SM100 SIMT packed fp32x2 kernels.
* Support residual add for implicit gemm kernels.
* Various fixes for CUTLASS C++ Python interface's EVT tracer:
  - Add verifier for sm90 to report the invalid input.
  - When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges.
  - Register operations of tanh, sigmoid, exp, gelu to the python ast frontend.
  - Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback.
* Fix profiler bugs in exhaustive perf search.
  - Fix incorrect cluster shape output issue when doing exhaustive search.
  - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders.
* Fix some profiler issues.
  - Complete the reference for Blackwell blockwise gemm kernels.
  - Fix incorrect regex logic for L1 test.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.9.
## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)

### CuTe DSL
* CuTe DSL, a Python DSL centered around CuTe's abstractions
  - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
  - [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
  - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
  - [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
  - [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
  - [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
  - [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
  - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
  - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
  - [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
  - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
  - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details

### CUTLASS C++
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
  - 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
  - For example:
    + `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
    + `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
  - If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, filter kernels in the CUTLASS profiler with `--kernels`), please update your uses accordingly, this is a breaking change.
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
  - Added non-power-of-two tile sizes.
  - Improved performance for K-major scale factors.
  - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Support LSE output in FMHA Forward kernel.
  - Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
  - Enhance testing of variable sequence length.
  - Disable B2B mode in MLA to simplify the sample.
  - Clarify that `fmha_gen` sample only supports head dim 128.
  - Fixes for split-kv output in MLA.
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
  - Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
  - Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
* Fix profiler issues which cause no output or not supported error for some kernels.
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* CuTe changes:
  - Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors.
  - New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
  - Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
  - Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.9.

# CUTLASS 3.x

## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)
* Fixed [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM hang issue when problem size K is 128.
* Optimal code generation with CUDA toolkit versions 12.9.

## [3.9.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.1) (2025-04-30)
* Fixed Group Gemm hang issue in CUTLASS 3.x
* Improved Hopper [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM performance.

## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-04-24)

* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
  - Collective mainloops that target for:
    * [Blockscaled datatypes with support for dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
    * [Blockscaled datatypes with support for sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
  - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - [Blackwell SM120 epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
  - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
  - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
  - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
  - [Grouped GEMM with nvfp4 datatype](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
  - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
  - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* Set of unit tests that demonstrate the usage of both [sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
* Support for Blackwell SM100 Sparse kernels:
  - Collective mainloop that target for
    * [SM100 Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
* Set of example that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
  - [Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
  - [Blockscaled Sparse GEMM with NVFP4 input data type](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
  - [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
* Set of unit tests that demonstrate the usage of [sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
* A new [distributed GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
  - Enhancement of [blockwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
  - Enhancement of [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
  - Support for [grouped GEMM with blockwise and groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
  - Support for [grouped-wise GEMM](https://github.com/NVIDIA/cutlass/tree/main/tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler.
  - Support for [blockwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
  - Support for [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
  - Support for [grouped GEMM with blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
  - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
  - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
  - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
  - More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support `void` as the D element in sm100 kernel epilogues.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.8U1.

## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)

* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
  - [5th generation Blackwell Tensor Core instructions (TCGen05)](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
  - Extensions to [Tensor Memory Accelerator](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
  - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/pointer.hpp) across CuTe as a first class data locale.
  - Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
  - [`make_tmem_copy()`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
  - Support for [new variants of LDSM on Blackwell](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
  - Various narrow precision [FP4, FP6, and FP8](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/float_subbyte.h)
  - [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
  - [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
  - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
  - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
  - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
  - [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
    + Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
    + Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
    + Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
  - Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
    * [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
    * [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
    * [Block scaled data types without support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
    * [Block scaled data types with support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
  - Blackwell [collective mainloop for convolution kernels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
  - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp), [convolution](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/dispatch_policy.hpp), and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - [Blackwell epilogue that supports loading accumulators from `tmem`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and full set of EVT fusions.
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
  - Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
  - Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
  - Support for mixed input GEMM kernels on Hopper in the profiler.
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
* A new 3.x version of grouped GEMM to the CUTLASS library and generates kernels for Hopper and Blackwell. Now grouped GEMM support is enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
  - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
  - GEMM with [opt-in collective builder schedules showcasing available recipes](https://github.com/NVIDIA/cutlass/tree/main/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
  - Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
    + [NVFP4 inputs with BF16 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
    + [NVFP4 inputs with NVFP4 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
    + [Mixed MXFP8 and MXFP6 inputs with BF16 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
  - GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
  - [GEMM with CLC based StreamK scheduler for load balancing](https://github.com/NVIDIA/cutlass/tree/main/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
  - Grouped GEMM for [vanilla FP8 data inputs](https://github.com/NVIDIA/cutlass/tree/main/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](https://github.com/NVIDIA/cutlass/tree/main/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
  - Convolution kernels for [fprop](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
  - [Fused multi-head attention fprop kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
  - A new BF16x9 GEMM [kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
  - A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
  - A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
  - [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
  - Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
  - A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
  - Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
  - Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.

## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
|
||||
- [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
|
||||
- [Distributed GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
- Potential API breaking changes:
|
||||
+ Fix `cute::UniversalCopy` for type safety.
|
||||
+ No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom.
|
||||
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
|
||||
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
|
||||
+ A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
|
||||
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit version 12.6.
|
||||
|
||||
## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03)
|
||||
|
||||
- [Hopper structured sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
|
||||
+ [FP16](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
|
||||
+ [FP8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
|
||||
+ [INT8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor of the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. The 3.x convolution API is no longer considered a beta API.
|
||||
- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL), which leverages a new Hopper feature to speed up two back-to-back kernels, and its corresponding [documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html).
|
||||
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
|
||||
- New hardware support for comparisons and computations on [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h).
|
||||
- Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit version 12.6.
|
||||
|
||||
## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25)
|
||||
|
||||
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/wgmma_sm90.cu)
|
||||
- [Exposure of L2 `cache_hint`s in TMA copy atoms](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/copy_sm90_tma.hpp#L48)
|
||||
- Exposure of raster order and tile swizzle extent in the [CUTLASS library profiler](./media/docs/cpp/profiler.md#gemm) and [example 48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [TMA store based and EVT supported epilogues](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
|
||||
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores, and new tiny tile sizes to better support LLM inference:
|
||||
+ [FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
|
||||
+ [int8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [int4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [FP32 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
|
||||
- [CUDA host adapter](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
|
||||
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py).
|
||||
- Support for residual add (beta != 0) in convolution kernels (see the linear-combination form sketched below).
|
||||
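For context, the residual add corresponds to the usual linear-combination epilogue with a non-zero source operand; schematically (our notation, not from the release notes):

$$
D = \alpha \cdot \mathrm{conv}(A, B) + \beta \cdot C, \qquad \beta \neq 0
$$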
- A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
|
||||
- A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
|
||||
- Better support for MSVC as a host compiler.
|
||||
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
|
||||
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
|
||||
|
||||
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
|
||||
|
||||
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
|
||||
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
|
||||
+ Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
|
||||
+ [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
|
||||
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until the 3.7 release. Your feedback is welcome on the design!
|
||||
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](https://github.com/NVIDIA/cutlass/tree/main/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
|
||||
- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
|
||||
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
|
||||
+ Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
|
||||
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
|
||||
+ [Ampere FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
|
||||
+ [Turing FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial).
|
||||
- Extensions to CuTe to support [L2 prefetching](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/copy_sm90_tma.hpp#L1337).
|
||||
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
|
||||
- Fixes to greatly reduce build warnings.
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
|
||||
|
||||
- Statically available [CUTLASS version macros](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/version.h) that allow users to handle API changes between CUTLASS releases on their side (a usage sketch follows below).
|
||||
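A minimal usage sketch, assuming the macro names provided by `cutlass/version.h`:

```cpp
// Minimal usage sketch of the statically available version macros.
#include <cutlass/version.h>
#include <cstdio>

int main() {
  // Report the CUTLASS version this translation unit was built against.
  std::printf("CUTLASS %d.%d.%d\n", CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH);

#if CUTLASS_MAJOR > 3 || (CUTLASS_MAJOR == 3 && CUTLASS_MINOR >= 4)
  // Guard code that depends on 3.4+ APIs here.
#endif
  return 0;
}
```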
- Improvements for Hopper [Group-GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Updates and bugfixes from the community (thanks!).
|
||||
|
||||
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
|
||||
* Expanded [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
|
||||
* Performance improvements to [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm)
|
||||
* Beta release of [Pointer-Array Batched GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* Beta release of [Group-GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarrier usability improvements, and the list of [ReservedNamedBarriers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved CuTe documentation including improved clarity and depth of [Quickstart](./media/docs/cpp/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
|
||||
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
|
||||
* [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operand B {fp16, bf16} x {s8, u8} and upcast on operand A {s8, u8} x {fp16, bf16}.
|
||||
* [Copy Async based Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu), which support input tensors with less than 16B alignment.
|
||||
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
|
||||
* Profiler support for lower-aligned Hopper GEMMs.
|
||||
* Performance Improvements to [Scatter-Gather Hopper Example](https://github.com/NVIDIA/cutlass/tree/main/examples/52_hopper_gather_scatter_fusion).
|
||||
* Sub-Byte type fixes and improvements.
|
||||
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
|
||||
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
|
||||
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
|
||||
|
||||
## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25)
|
||||
* Minor patch for issue #1138.
|
||||
|
||||
## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
|
||||
* Python support for SM90 Epilogue Visitor Trees (EVT), on top of the C++ support released in 3.2.0.
|
||||
* SM80 EVT support in C++ and Python.
|
||||
* Other SM90 epilogue improvements.
|
||||
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) for details.
|
||||
* SM90 TF32 kernel improvements for all layouts.
|
||||
* SM90 rasterization direction support in the CUTLASS profiler.
|
||||
* Improvement for CUTLASS profiler build times.
|
||||
* Remove Python-C++ bindings.
|
||||
|
||||
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
|
||||
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allow user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* [Hopper GEMM+Permute](https://github.com/NVIDIA/cutlass/tree/main/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](https://github.com/NVIDIA/cutlass/tree/main/examples/python/03_basic_conv2d.ipynb) here.
|
||||
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
|
||||
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
|
||||
* New [efficient epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](https://github.com/NVIDIA/cutlass/tree/main/examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](https://github.com/NVIDIA/cutlass/tree/main/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](https://github.com/NVIDIA/cutlass/tree/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
|
||||
* [Permute + GEMM fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing GEMM with Permute in the epilogue was supported.
|
||||
* [Row Broadcast](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* The GitHub branch is renamed from `master` to `main` in this release.
|
||||
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
|
||||
* [CuTe](./media/docs/cpp/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#target-architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
|
||||
* [CUTLASS library integration](https://github.com/NVIDIA/cutlass/tree/main/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm), [49](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_schedules_with_collective_builder), and [50](https://github.com/NVIDIA/cutlass/tree/main/examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
|
||||
# CUTLASS 2.x
|
||||
|
||||
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
|
||||
* [Stream-K](https://github.com/NVIDIA/cutlass/tree/main/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported starting with CUDA 11.8.
|
||||
* [BLAS3](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. The B and output matrices are still dense. The block size can be arbitrary.
|
||||
* Optimized [Group Conv](https://github.com/NVIDIA/cutlass/tree/main/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](https://github.com/NVIDIA/cutlass/tree/main/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added:
|
||||
* [kOptimized](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* The restrictions are: 1) the input channel, output channel, and group counts should be multiples of (128 / sizeof(input element)), e.g. multiples of 64 for fp16 inputs; 2) the input filter size should be the same as the template parameter configuration.
|
||||
* [kFixedStrideDilation](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts the stride and dilation into template parameters to further improve performance. In this mode, the kernel keeps some inputs persistent in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended.
|
||||
* The restrictions are: 1) the input channel, output channel, and group counts should be multiples of (128 / sizeof(input element)); 2) the input filter size, stride, and dilation should be the same as the template parameter configuration.
|
||||
* [Scripts](https://github.com/NVIDIA/cutlass/tree/main/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/float8.h) and [conversion routines](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
|
||||
|
||||
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
* Maxwell and Pascal GPU architectures
|
||||
* Ubuntu 16.04
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
|
||||
* [CUTLASS Python](https://github.com/NVIDIA/cutlass/tree/main/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped/gemm_grouped.cu) kernel. The threadblock scheduling is improved, and some computation can be moved to the host side when applicable. [Grouped Syr2k](https://github.com/NVIDIA/cutlass/tree/main/examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](https://github.com/NVIDIA/cutlass/tree/main/examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/41_multi_head_attention). This general grouped-GEMM-based MHA does not require the sequence lengths of all GEMMs to be the same, which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](https://github.com/NVIDIA/cutlass/tree/main/examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts, and both of them can be fused into the GEMMs before and after, respectively. In addition to using the sum of squares to compute the variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided in case the sum of squares raises numerical issues (a reference sketch follows below).
|
||||
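As a hedged illustration of the shifted-data ("Shift-K") trick referenced above (our own host-side reference code, not the fused kernel):

```cpp
// Shifted-data variance: subtracting a constant K close to the mean before accumulating
// keeps the sum of squares small and avoids catastrophic cancellation.
#include <vector>

float shifted_variance(std::vector<float> const& x) {
  float const K = x.front();            // any value near the mean works; assumes x is non-empty
  double sum = 0.0, sum_sq = 0.0;
  for (float v : x) {
    double const d = static_cast<double>(v) - K;
    sum    += d;
    sum_sq += d * d;
  }
  double const n = static_cast<double>(x.size());
  return static_cast<float>((sum_sq - sum * sum / n) / n);   // population variance
}
```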
* [GEMM Epilogue Permutation Fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one. The restrictions are: 1) the input and output channel counts should be multiples of the group count; 2) split-K is not supported. The implementation has 2 modes:
|
||||
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
|
||||
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
|
||||
* [Depthwise separable convolution](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now. The restrictions are: 1) SIMT only; 2) no split-K; 3) the input channel count, output channel count, and group count must all be equal.
|
||||
* Standalone [Layernorm](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
|
||||
|
||||
* [First layer Convolution kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
|
||||
* [HERK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/rank_k_operation.py)
|
||||
* [SYRK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/rank_k_operation.py)
|
||||
* [SYMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/symm_operation.py)
|
||||
* [TRMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/trmm_operation.py)
|
||||
* [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](https://github.com/NVIDIA/cutlass/tree/main/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](https://github.com/NVIDIA/cutlass/tree/main/tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](https://github.com/NVIDIA/cutlass/tree/main/examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/36_gather_scatter_fusion) can gather inputs and scatter outputs based on index vectors within the same GEMM kernel.
|
||||
* It can select random rows in a row major matrix.
|
||||
* It can select random columns in a column major matrix.
|
||||
* [Back-to-back GEMM/CONV](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* Supported kernels: GEMM and CONV.
|
||||
* Supported types: fp16 and int8.
|
||||
* Supported architectures: Turing and Ampere.
|
||||
* [Transposed Convolution](https://github.com/NVIDIA/cutlass/tree/main/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels.
|
||||
* Epilogue enhancement:
|
||||
* Eliminate bank conflicts in int8 tensor core kernels.
|
||||
* Half2 usage if epilogue compute type is fp16.
|
||||
* More activation functions: Silu, Hardswish, Leaky Relu.
|
||||
* New elementwise fusion pattern for [residual block](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
|
||||
## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)
|
||||
|
||||
* **TF32x3:** emulated single-precision using Tensor Cores
|
||||
* 45+ TFLOPs on NVIDIA A100
|
||||
* [GEMM SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
|
||||
* [Conv Fprop SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
|
||||
* [SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/) supports staging the 1st convolution's output accumulator in shared memory on Turing. This allows more flexible warp tile sizes and less register pressure.
|
||||
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates from the community (thanks!)
|
||||
|
||||
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
* Maxwell and Pascal GPU architectures
|
||||
* Ubuntu 16.04
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](https://github.com/NVIDIA/cutlass/tree/main/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Use these when the accumulation and epilogue compute types are all `cutlass::half_t` (the approximation used is sketched below).
|
||||
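For reference, the tanh-based Taylor approximation commonly used for half-precision GELU (and, to our understanding, the form `GELU_taylor` follows) is:

$$
\mathrm{GELU}(x) \approx \tfrac{1}{2}\,x \left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
$$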
* Tuning and bug fixes to [fused GEMM + GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/)
|
||||
* Support for smaller than 128b aligned Convolutions: [see examples](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
|
||||
* Caching of results to accelerate Convolution [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/cache_testbed_output.h)
|
||||
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03)
|
||||
* Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators)
|
||||
* Tuning for GEMMs fused with partial reductions
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
|
||||
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/memory.h) and [global load](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/memory_sm80.h)
|
||||
* Fused operators with GEMM and Convolution
|
||||
* [Fused broadcast in epilogue](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* 64b tensor strides and leading dimensions support for GEMMs
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](https://github.com/NVIDIA/cutlass/tree/main/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* [New strided Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
|
||||
* Accelerates over previous implementation by cutting down redundant math by 4x
|
||||
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
|
||||
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
|
||||
* Updates to [quaternion.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/quaternion.h) and [functional.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](https://github.com/NVIDIA/cutlass/tree/main/examples/22_quaternion_conv/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Many improvements to the epilogue.
|
||||
* Provide an [option](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Performance improvement for FP16 tensor core kernels
|
||||
* Bug fixes
|
||||
* Enhanced Clang support; the combination of Clang 13 and CUDA 11.4 can build and run kernels for Pascal and Ampere.
|
||||
* Updated minimum CUDA Toolkit requirement to 10.2
|
||||
* [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
|
||||
* Tensor reductions
|
||||
* _m_-to-_n_ reductions of tensors with affine layout
|
||||
* [Specializations](https://github.com/NVIDIA/cutlass/tree/main/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](https://github.com/NVIDIA/cutlass/tree/main/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* Custom reduction functors such as `cutlass::logical_and`
|
||||
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
|
||||
* Optimizations for 3-D convolution
|
||||
* [Optimized tile iterators](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* Full coverage of [forward](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
|
||||
* [Fused Convolution+Convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/README.md)
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
|
||||
## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19)
|
||||
* Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs
|
||||
* Operators: forward (Fprop), backward data gradient (Dgrad), and backward weight gradient (Wgrad) convolution
|
||||
* Data type: FP32, complex<FP32>, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32
|
||||
* Spatial dimensions: 1-D, 2-D, and 3-D
|
||||
* Layout: NHWC, NCxHWx
|
||||
* Implicit GEMM convolution components:
|
||||
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
|
||||
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
|
||||
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
|
||||
* [Documentation](./media/docs/cpp/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
|
||||
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* [Sparse Tensor Core GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
|
||||
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
|
||||
* Minor Features:
|
||||
* [Activation functions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/activation.h) such as [GeLU](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/matrix.h) and [quaternion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/constants.h)
|
||||
* NVIDIA Ampere GPU Architecture examples and documentation:
|
||||
* [Tensor Float 32](https://github.com/NVIDIA/cutlass/tree/main/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](https://github.com/NVIDIA/cutlass/tree/main/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue)
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* Fast Tensor Core operations:
|
||||
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Tensor Float 32, BFloat16, and double-precision data types
|
||||
* Mixed integer data types (int8, int4, bin1)
|
||||
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
|
||||
* Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required)
|
||||
* Features:
|
||||
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
|
||||
* Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
|
||||
* Gaussian complex GEMMs using 3m complex multiply algorithm
|
||||
* Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
|
||||
* Policy updates:
|
||||
* [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
|
||||
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
|
||||
|
||||
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/cpp/quickstart.md#cutlass-library)
|
||||
* API to launch compiled kernel instances for GEMM and planar complex GEMM
|
||||
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
|
||||
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
|
||||
* [SDK Examples of Planar Complex GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/10_planar_complex/planar_complex.cu)
|
||||
* Minor enhancements and bug fixes
|
||||
|
||||
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
|
||||
* Substantially refactored for
|
||||
* Better performance, particularly for native Turing Tensor Cores
|
||||
* Robust and durable templates spanning the design space
|
||||
* Encapsulated functionality embodying modern C++11 programming techniques
|
||||
* Optimized containers and data types for efficient, generic, portable device code
|
||||
* Updates to:
|
||||
* [Quick start guide](./media/docs/cpp/quickstart.md)
|
||||
* [Documentation](./README.md#documentation)
|
||||
* [Utilities](./media/docs/cpp/utilities.md)
|
||||
* [CUTLASS Profiler](./media/docs/cpp/profiler.md)
|
||||
* Native Turing Tensor Cores
|
||||
* Efficient GEMM kernels targeting Turing Tensor Cores
|
||||
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
|
||||
* Coverage of existing CUTLASS functionality
|
||||
* GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs
|
||||
* Volta Tensor Cores through native mma.sync and through WMMA API
|
||||
* Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
|
||||
* Batched GEMM operations
|
||||
* Complex-valued GEMMs
|
||||
* **Note: a host compiler supporting C++11 or greater is required.**
|
||||
|
||||
# CUTLASS 1.x
|
||||
|
||||
## [1.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.2) (2019-07-09)
|
||||
* Performance improvement for Volta Tensor Cores TN and TT layouts.
|
||||
[...]
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
|
||||
|
||||
|
||||
112
CITATION.cff
Normal file
@ -0,0 +1,112 @@
|
||||
cff-version: 1.2.0
|
||||
title: CUTLASS
|
||||
message: >-
|
||||
If you use this software, please cite using the
|
||||
following metadata.
|
||||
type: software
|
||||
authors:
|
||||
- given-names: Vijay
|
||||
family-names: Thakkar
|
||||
email: vithakkar@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Pradeep
|
||||
family-names: Ramani
|
||||
email: prramani@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Cris
|
||||
family-names: Cecka
|
||||
email: ccecka@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Aniket
|
||||
family-names: Shivam
|
||||
email: ashivam@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Honghao
|
||||
family-names: Lu
|
||||
email: honghaol@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Ethan
|
||||
family-names: Yan
|
||||
email: etyan@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Jack
|
||||
family-names: Kosaian
|
||||
email: jkosaian@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Mark
|
||||
family-names: Hoemmen
|
||||
email: mhoemmen@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Haicheng
|
||||
family-names: Wu
|
||||
email: haichengw@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Andrew
|
||||
family-names: Kerr
|
||||
email: akerr@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Matt
|
||||
family-names: Nicely
|
||||
email: mnicely@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Duane
|
||||
family-names: Merrill
|
||||
email: dumerrill@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Dustyn
|
||||
family-names: Blasig
|
||||
email: dblasig@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Fengqi
|
||||
family-names: Qiao
|
||||
email: fqiao@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Piotr
|
||||
family-names: Majcher
|
||||
email: pmajcher@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Paul
|
||||
family-names: Springer
|
||||
email: pspringer@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Markus
|
||||
family-names: Hohnerbach
|
||||
affiliation: NVIDIA
|
||||
email: mhohnerbach@nvidia.com
|
||||
- given-names: Jin
|
||||
family-names: Wang
|
||||
email: jinw@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Manish
|
||||
family-names: Gupta
|
||||
affiliation: Google
|
||||
email: manigupta@google.com
|
||||
|
||||
|
||||
repository-code: 'https://github.com/NVIDIA/cutlass'
|
||||
abstract: >-
|
||||
CUTLASS is a collection of CUDA C++ template
|
||||
abstractions for implementing high-performance
|
||||
matrix-multiplication (GEMM) and related
|
||||
computations at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical
|
||||
decomposition and data movement similar to those
|
||||
used to implement cuBLAS and cuDNN. CUTLASS
|
||||
decomposes these "moving parts" into reusable,
|
||||
modular software components abstracted by C++
|
||||
template classes. These thread-wide, warp-wide,
|
||||
block-wide, and device-wide primitives can be
|
||||
specialized and tuned via custom tiling sizes, data
|
||||
types, and other algorithmic policy. The resulting
|
||||
flexibility simplifies their use as building blocks
|
||||
within custom kernels and applications.
|
||||
keywords:
|
||||
- 'cutlass, tensor cores, cuda, cute, nvidia, gpu, linear algebra, matrix computations'
|
||||
license: BSD-3-Clause
|
||||
license-url: https://github.com/NVIDIA/cutlass/blob/v3.0.0/LICENSE.txt
|
||||
version: '3.0.0'
|
||||
date-released: '2023-01-23'
|
||||
identifiers:
|
||||
- type: url
|
||||
value: "https://github.com/NVIDIA/cutlass/tree/v3.0.0"
|
||||
description: The GitHub release URL of tag 3.0.0
|
||||
@ -1,26 +0,0 @@
|
||||
# A small utility function which generates a C-header from an input file
|
||||
function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
|
||||
FILE(READ "${FILENAME}" HEX_INPUT HEX)
|
||||
if (${ZERO_TERMINATED})
|
||||
string(APPEND HEX_INPUT "00")
|
||||
endif()
|
||||
|
||||
string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
|
||||
|
||||
set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")
|
||||
|
||||
set(${OUTPUT_STRING} "${HEX_OUTPUT}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
message("Create header file for ${FILE_IN}")
|
||||
message("Create header file for ${FILE_OUT}")
|
||||
file_to_c_string(${FILE_IN} ${VARIABLE_NAME} OUTPUT_STRING ZERO_TERMINATED)
|
||||
|
||||
set(RESULT "#pragma once\n")
|
||||
string(APPEND RESULT "namespace cutlass {\n")
|
||||
string(APPEND RESULT "namespace nvrtc {\n")
|
||||
string(APPEND RESULT "${OUTPUT_STRING}")
|
||||
string(APPEND RESULT "} // namespace nvrtc\n")
|
||||
string(APPEND RESULT "} // namespace cutlass\n")
|
||||
file(WRITE "${FILE_OUT}" "${RESULT}")
|
||||
1227
CMakeLists.txt
Normal file → Executable file
File diff suppressed because it is too large
203
CONTRIBUTORS.md
Normal file
@ -0,0 +1,203 @@
|
||||

|
||||
|
||||
[README](./README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS C++ Developers **
|
||||
|
||||
Andrew Kerr<br />
|
||||
Paul Springer<br />
|
||||
Dustyn Blasig<br />
|
||||
Albert Xu<br />
|
||||
Junkai Wu<br />
|
||||
Xiuxia Zhang<br />
|
||||
Haicheng Wu<br />
|
||||
Jack Yang<br />
|
||||
Pradeep Ramani<br />
|
||||
Aditya Atluri<br />
|
||||
Han Li<br />
|
||||
Nick Zhao<br />
|
||||
Ivan Yin<br />
|
||||
Yu-Jung Chen<br />
|
||||
Markus Hoehnerbach<br />
|
||||
Honghao Lu<br />
|
||||
Mihir Awatramani<br />
|
||||
Hao Sheng<br />
|
||||
Zekun Fan<br />
|
||||
Aniket Shivam<br />
|
||||
Siyu Liu<br />
|
||||
Richard Cai<br />
|
||||
Vikas Gupta<br />
|
||||
Ethan Yan<br />
|
||||
Vijay Thakkar<br />
|
||||
Cris Cecka<br />
|
||||
Lawrence Ryan<br />
|
||||
Qun Song<br />
|
||||
Daniel Ricketts<br />
|
||||
dePaul Miller<br />
|
||||
Yuhan Li<br />
|
||||
Saman Ashkiani<br />
|
||||
Jack Chen<br />
|
||||
Shang Zhang<br />
|
||||
Petrick Liu<br />
|
||||
Questa Wang<br />
|
||||
Pramod Shenoy<br />
|
||||
Jack Kosaian<br />
|
||||
Yujia Zhai<br />
|
||||
Zhaodong Chen<br />
|
||||
Manas Sahni<br />
|
||||
Shunfan Shao<br />
|
||||
Fengqi Qiao<br />
|
||||
Serif Yesil<br />
|
||||
Aragorn Guan<br />
|
||||
Heidi He<br />
|
||||
Xiao Song<br />
|
||||
Sergey Klevtsov<br />
|
||||
Jiang Shao<br />
|
||||
Ruqing Xu<br />
|
||||
Mengyu Guo<br />
|
||||
Tao Xie<br />
|
||||
Linfeng Zheng<br />
|
||||
Harrison Barclay<br />
|
||||
Wenfei Tang<br />
|
||||
Diksha Gohlyan<br />
|
||||
Alexander Zhurkevich<br />
|
||||
Siyuan Fu<br />
|
||||
Hua Huang<br />
|
||||
Xiufan Liang<br />
|
||||
Ian Tramble<br />
|
||||
Ali Hassani<br />
|
||||
Shreya Gaur<br />
|
||||
|
||||
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
|
||||
|
||||
# CUTLASS DSL Developers ***
|
||||
|
||||
Albert Di<br />
|
||||
Albert Xu<br />
|
||||
Anakin Zheng<br />
|
||||
Arvin Jou<br />
|
||||
Brandon Sun<br />
|
||||
Chenyang Xu<br />
|
||||
Chunyu Wang<br />
|
||||
Cris Cecka<br />
|
||||
dePaul Miller<br />
|
||||
Edward Cao<br />
|
||||
Fung Xie<br />
|
||||
Guray Ozen<br />
|
||||
Hao Hu<br />
|
||||
Hong Wang<br />
|
||||
Jeremy Furtek<br />
|
||||
Jie Fang <br />
|
||||
JingZe Cui<br />
|
||||
Kihiro Bando<br />
|
||||
Linfeng Zheng<br />
|
||||
Longsheng Du<br />
|
||||
Mina Sun<br />
|
||||
Mindy Li<br />
|
||||
Pradeep Ramani<br />
|
||||
Questa Wang<br />
|
||||
Serif Yesil<br />
|
||||
Tao Xie<br />
|
||||
Tina Li<br />
|
||||
Vicki Wang<br />
|
||||
Vincent Zhang<br />
|
||||
Vijay Thakkar<br />
|
||||
Xiao Dong<br />
|
||||
Xiaolei Shi<br />
|
||||
Xinyu Wang<br />
|
||||
Yihan Chen<br />
|
||||
Yuhan Li<br />
|
||||
Zekun Fan<br />
|
||||
|
||||
*** _Sorted in alphabetical order._
|
||||
|
||||
|
||||
# CuTe Developers
|
||||
|
||||
Cris Cecka<br />
|
||||
Vijay Thakkar<br />
|
||||
|
||||
|
||||
# CUTLASS Product Manager
|
||||
|
||||
Matthew Nicely<br />
|
||||
|
||||
|
||||
# Former CUTLASS Developers
|
||||
|
||||
Manish Gupta<br />
|
||||
Duane Merrill<br />
|
||||
Piotr Majcher<br />
|
||||
Naila Farooqui<br />
|
||||
Mark Hoemmen<br />
|
||||
Rawn Henry<br />
|
||||
Jin Wang<br />
|
||||
Timmy Liu<br />
|
||||
Manikandan Ananth<br />
|
||||
David Tanner<br />
|
||||
|
||||
|
||||
# Acknowledgements
|
||||
|
||||
Tri Dao<br />
|
||||
Jay Shah<br />
|
||||
Mehdi Amini<br />
|
||||
Larry Wu<br />
|
||||
Justin Holewinski<br />
|
||||
Timothy Costa<br />
|
||||
Julien Demouth<br />
|
||||
Brian Fahs<br />
|
||||
Michael Garland<br />
|
||||
Michael Goldfarb<br />
|
||||
Mostafa Hagog<br />
|
||||
Fei Hu<br />
|
||||
Alan Kaatz<br />
|
||||
Wei Liu<br />
|
||||
Tim Martin<br />
|
||||
Kevin Siu<br />
|
||||
Markus Tavenrath<br />
|
||||
John Tran<br />
|
||||
Yang Xu<br />
|
||||
Scott Yokim<br />
|
||||
Girish Bharambe<br />
|
||||
Luke Durant<br />
|
||||
Carter Edwards<br />
|
||||
Olivier Giroux<br />
|
||||
Stephen Jones<br />
|
||||
Rishkul Kulkarni<br />
|
||||
Bryce Lelbach<br />
|
||||
Joel McCormack<br />
|
||||
Kyrylo Perelygin<br />
|
||||
Sean Treichler<br />
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
369
CUDA.cmake
Normal file
@ -0,0 +1,369 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
message(WARNING "CUDA_COMPILER flag is deprecated, set CMAKE_CUDA_COMPILER to desired compiler executable.")
|
||||
set(__CLANG_DEVICE_COMPILATION_REQUESTED ON)
|
||||
elseif(CUDA_COMPILER)
|
||||
message(WARNING "Deprecated flag CUDA_COMPILER used with unknown argument ${CUDA_COMPILER}, ignoring.")
|
||||
endif()
|
||||
|
||||
if (__CLANG_DEVICE_COMPILATION_REQUESTED AND NOT DEFINED CMAKE_CUDA_COMPILER)
|
||||
set(CMAKE_CUDA_COMPILER clang++) # We will let the system find Clang or error out
|
||||
endif()
|
||||
|
||||
enable_language(CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
if(NOT CUDA_VERSION)
|
||||
# For backward compatibility with older CMake code.
|
||||
set(CUDA_VERSION ${CUDAToolkit_VERSION})
|
||||
set(CUDA_VERSION_MAJOR ${CUDAToolkit_VERSION_MAJOR})
|
||||
set(CUDA_VERSION_MINOR ${CUDAToolkit_VERSION_MINOR})
|
||||
endif()
|
||||
if(NOT CUDA_TOOLKIT_ROOT_DIR)
|
||||
# In some scenarios, such as clang device compilation, the toolkit root may not be set, so we
|
||||
# force it here to the nvcc we found via the CUDAToolkit package.
|
||||
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_NVCC_EXECUTABLE}/../.." ABSOLUTE)
|
||||
endif()
|
||||
|
||||
if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])")
|
||||
set(CUTLASS_NVCC_DEVICE_COMPILE ON CACHE BOOL "Using nvcc tools for device compilation")
|
||||
elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang")
|
||||
set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation")
|
||||
else()
|
||||
message(FATAL_ERROR "Unknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
|
||||
endif()
|
||||
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30")
|
||||
message(FATAL_ERROR "Clang device compilation for CUTLASS requires CMake 3.30 or higher.")
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 9.2)
|
||||
message(FATAL_ERROR "CUDA 9.2+ required, found ${CUDA_VERSION}.")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
CUDART_LIBRARY cudart
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x86_64-linux-gnu
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
if(NOT TARGET cudart AND CUDART_LIBRARY)
|
||||
|
||||
message(STATUS "CUDART: ${CUDART_LIBRARY}")
|
||||
|
||||
if(WIN32)
|
||||
add_library(cudart STATIC IMPORTED GLOBAL)
|
||||
# Even though we're linking against a .dll, in Windows you statically link against
|
||||
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cudart SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cudart ALIAS cudart)
|
||||
|
||||
set_property(
|
||||
TARGET cudart
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CUDART_LIBRARY}
|
||||
)
|
||||
|
||||
elseif(TARGET cudart)
|
||||
|
||||
message(STATUS "CUDART: Already Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "CUDART: Not Found")
|
||||
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
CUDA_DRIVER_LIBRARY cuda
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x86_64-linux-gnu
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
lib64/stubs
|
||||
lib/stubs
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)
|
||||
|
||||
message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}")
|
||||
|
||||
if(WIN32)
|
||||
add_library(cuda_driver STATIC IMPORTED GLOBAL)
|
||||
# Even though we're linking against a .dll, in Windows you statically link against
|
||||
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cuda_driver SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cuda_driver ALIAS cuda_driver)
|
||||
|
||||
set_property(
|
||||
TARGET cuda_driver
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CUDA_DRIVER_LIBRARY}
|
||||
)
|
||||
|
||||
elseif(TARGET cuda_driver)
|
||||
|
||||
message(STATUS "CUDA Driver: Already Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "CUDA Driver: Not Found")
|
||||
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
NVRTC_LIBRARY nvrtc
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
if(NOT TARGET nvrtc AND NVRTC_LIBRARY)
|
||||
|
||||
message(STATUS "NVRTC: ${NVRTC_LIBRARY}")
|
||||
|
||||
if(WIN32)
|
||||
add_library(nvrtc STATIC IMPORTED GLOBAL)
|
||||
# Even though we're linking against a .dll, in Windows you statically link against
|
||||
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(nvrtc SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::nvrtc ALIAS nvrtc)
|
||||
|
||||
set_property(
|
||||
TARGET nvrtc
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${NVRTC_LIBRARY}
|
||||
)
|
||||
|
||||
elseif(TARGET nvrtc)
|
||||
|
||||
message(STATUS "NVRTC: Already Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "NVRTC: Not Found")
|
||||
|
||||
endif()
|
||||
|
||||
include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
|
||||
# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
|
||||
# paths by default, so we add it explicitly here.
|
||||
|
||||
if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all")
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
|
||||
else()
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
|
||||
|
||||
if (MSVC)
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8)
|
||||
else()
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files")
|
||||
|
||||
function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs BATCH_SOURCES BATCH_SIZE)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT DEFINED TARGET_ARGS_VAR)
|
||||
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED __BATCH_SOURCES)
|
||||
set(__BATCH_SOURCES ON)
|
||||
endif()
|
||||
|
||||
if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
|
||||
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
|
||||
endif()
|
||||
|
||||
if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1)
|
||||
|
||||
set(CUDA_FILE_ARGS)
|
||||
set(TARGET_SOURCE_ARGS)
|
||||
|
||||
foreach(ARG ${__UNPARSED_ARGUMENTS})
|
||||
if(${ARG} MATCHES ".*\\.cu$")
|
||||
list(APPEND CUDA_FILE_ARGS ${ARG})
|
||||
else()
|
||||
list(APPEND TARGET_SOURCE_ARGS ${ARG})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
|
||||
while(NUM_CUDA_FILE_ARGS GREATER 0)
|
||||
list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
|
||||
string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}")
|
||||
string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH)
|
||||
set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu)
|
||||
message(STATUS "Generating ${BATCH_FILE}")
|
||||
file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n")
|
||||
foreach(CUDA_FILE ${CUDA_FILE_BATCH})
|
||||
get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE)
|
||||
file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n")
|
||||
endforeach()
|
||||
list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE})
|
||||
if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE)
|
||||
break()
|
||||
endif()
|
||||
list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS)
|
||||
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
|
||||
endwhile()
|
||||
|
||||
else()
|
||||
|
||||
set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
endif()
|
||||
|
||||
set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)
|
||||
|
||||
endfunction()
|
||||
function(cutlass_add_library NAME)
|
||||
|
||||
set(options SKIP_GENCODE_FLAGS)
|
||||
set(oneValueArgs EXPORT_NAME)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
|
||||
if (NOT __SKIP_GENCODE_FLAGS)
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
endif()
|
||||
|
||||
target_compile_features(
|
||||
${NAME}
|
||||
INTERFACE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
get_target_property(TARGET_TYPE ${NAME} TYPE)
|
||||
|
||||
if (TARGET_TYPE MATCHES "SHARED")
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Shared)
|
||||
elseif(TARGET_TYPE MATCHES "STATIC")
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Static)
|
||||
endif()
|
||||
|
||||
if(__EXPORT_NAME)
|
||||
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
|
||||
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
|
||||
endif()
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_add_executable NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs CUDA_RUNTIME_LIBRARY)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT DEFINED __CUDA_RUNTIME_LIBRARY)
|
||||
set(__CUDA_RUNTIME_LIBRARY Shared)
|
||||
endif()
|
||||
|
||||
set(__CUDA_RUNTIME_LIBRARY_ALLOWED None Shared Static)
|
||||
if (NOT __CUDA_RUNTIME_LIBRARY IN_LIST __CUDA_RUNTIME_LIBRARY_ALLOWED)
|
||||
message(FATAL_ERROR "CUDA_RUNTIME_LIBRARY value '${__CUDA_RUNTIME_LIBRARY}' is not in allowed list of '${__CUDA_RUNTIME_LIBRARY_ALLOWED}'")
|
||||
endif()
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
|
||||
target_compile_features(
|
||||
${NAME}
|
||||
INTERFACE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY ${__CUDA_RUNTIME_LIBRARY})
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_target_sources NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
target_sources(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
endfunction()
|
||||
378
CUTLASS.md
@ -1,378 +0,0 @@
|
||||

|
||||
|
||||
# CUTLASS
|
||||
|
||||
This document is intended to accompany the CUTLASS source code, to describe the interaction between
|
||||
CUTLASS core components, and to identify their role in implementing GEMM computations efficiently in CUDA.
|
||||
|
||||
1. [Design Patterns](#S-design-patterns)
|
||||
2. [General Matrix Multiply](#S-general-matrix-multiply)
|
||||
3. [Core Components](#S-core-components)
|
||||
4. [Utilities](#S-utilities)
|
||||
5. [Optimization Strategies](#S-optimization-strategies)
|
||||
|
||||
# <a name="S-design-patterns"></a> 1. Design Patterns
|
||||
|
||||
CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
|
||||
flexible composition that can be easily applied to solve new problems related to Deep Learning and
|
||||
linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
|
||||
a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
|
||||
design patterns are necessary to yield a composable structure while also satisfying these performance
|
||||
objectives. This section is intended to provide more detail.
|
||||
|
||||
* [Sequencing and Nesting](#S-patterns-sequencing-nesting)
|
||||
* [Tiles and Iterators](#S-patterns-tiles-iterators)
|
||||
* [Host-side Params](#S-patterns-host-side-params)
|
||||
* [Composable Shared Memory](#S-patterns-composable-shared-memory)
|
||||
|
||||
## <a name="S-patterns-sequencing-nesting"></a> Sequencing and Nesting of Collective Primitives
|
||||
|
||||
CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
|
||||
|
||||
## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators
|
||||
|
||||
Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
|
||||
specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.
|
||||
|
||||
_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
|
||||
elements in memory as well as traversing over a collection. GEMM kernels in CUTLASS depend on accessing
|
||||
a sequence of tiles from global memory, from shared memory, and in registers. Consequently, _tile iterators_
|
||||
are prevalent throughout the CUTLASS implementation.
|
||||
|
||||
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
|
||||
|
||||
## <a name="S-patterns-host-side-params"></a> Host-side Params structure
|
||||
|
||||
Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several offsets based on the strides of the input tensor that are added to an internal pointer when loading the elements of a tile. These offsets are computed from the tensor strides and never updated; the per-thread internal state consists only of the internal global memory pointer.
|
||||
|
||||
CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant memory, and (2.) there is no overhead to compute the initial state by each thread.
|
||||
|
||||
The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument.
|
||||
|
||||
For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
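A minimal sketch of this pattern follows; the `ToyTileIterator` name and its members are illustrative and are not actual CUTLASS classes.

```
// Sketch of the host-side Params pattern (illustrative names only).
struct ToyTileIterator {

  // Grid-invariant state, constructed on the host and passed to the kernel.
  struct Params {
    float const *pointer;
    int stride;      // elements between consecutive rows
    int increment;   // offset added after each tile is loaded

    Params() : pointer(nullptr), stride(0), increment(0) {}

    // Host-side initialization from problem-specific values.
    int initialize(float const *ptr, int ld, int tile_rows) {
      pointer = ptr;
      stride = ld;
      increment = tile_rows * ld;
      return 0;
    }
  };

  float const *current;   // per-thread state: just a pointer

  // The parent class accepts Params const & as its first constructor argument.
  __device__ ToyTileIterator(Params const &params, int thread_offset)
      : current(params.pointer + thread_offset) {}
};
```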
|
||||
|
||||
|
||||
## <a name="S-patterns-composable-shared-memory"></a> Composable shared memory allocation
|
||||
|
||||
Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held in shared memory. Any object requiring shared memory storage for itself or its data members should define a child structure called SharedStorage. This holds data needed by the class and also instantiates SharedStorage objects for each data member.
|
||||
|
||||
To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes of child objects are known to be non-overlapping, unions may be used to alias multiple SharedStorage objects to the same shared memory region and reduce overall SMEM capacity.
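A minimal sketch of this convention, with illustrative names rather than actual CUTLASS classes, is shown below. The union reflects the non-overlapping lifetimes of the mainloop and epilogue storage.

```
// Illustrative sketch of composed SharedStorage structures.
struct ToyGlobalLoadStream {
  struct SharedStorage {
    float tile[128 * 8];          // staging buffer for one threadblock tile
  };
};

struct ToyEpilogue {
  struct SharedStorage {
    float accumulators[128 * 4];
  };
};

struct ToyGemm {
  // The kernel's SharedStorage composes the storage of its members. Because the
  // mainloop and the epilogue do not overlap in time, a union aliases both to
  // the same shared memory region and reduces the total SMEM footprint.
  union SharedStorage {
    ToyGlobalLoadStream::SharedStorage mainloop;
    ToyEpilogue::SharedStorage epilogue;
  };
};
```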
|
||||
|
||||
## <a name="S-patterns-loop-unrolling"></a> Loop Unrolling
|
||||
|
||||
CUTLASS requires tiles of data to be stored in registers for high-bandwidth access. Simultaneously, high-throughput math instructions
|
||||
must be issued concurrently with memory instructions to hide latency with relatively few concurrent threads. These objectives are
|
||||
achieved by unrolling loops whose iteration counts are known at compile time.
|
||||
|
||||
Consequently, most loops within the CUTLASS GEMM implementation are specified by constant values and template arguments. The CUDA compiler
|
||||
is able to unroll the loop bodies, map array elements to registers, and construct an efficient instruction schedule.
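For example, a fully unrollable accumulation loop might look like the following sketch; the function and parameter names are illustrative, and the trip count is a template parameter known at compile time.

```
// Sketch: kElements is a compile-time constant, so the compiler can unroll the
// loop, keep the arrays in registers, and interleave the multiply-adds with
// other instructions.
template <int kElements>
__device__ void multiply_accumulate(float const (&a)[kElements],
                                    float const (&b)[kElements],
                                    float (&frag)[kElements]) {
  #pragma unroll
  for (int i = 0; i < kElements; ++i) {
    frag[i] += a[i] * b[i];
  }
}
```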
|
||||
|
||||
## <a name="S-patterns-loop-unrolling"></a> Templates
|
||||
|
||||
CUDA C++ templates and modern generic programming techniques enable CUTLASS device code to span a large design space.
|
||||
|
||||
This design space includes:
|
||||
* Mixed precision arithmetic and data storage
|
||||
* Kernels specialized for layout and problem size
|
||||
* Support for kernel fusion
|
||||
|
||||
Moreover, templates provide a structured approach to collecting compile-time constants such as tile dimensions. These
|
||||
must be template arguments to target static array allocation and take advantage of loop unrolling, constant folding,
|
||||
and function inlining.
|
||||
|
||||
# <a name="S-general-matrix-multiply"></a> 2. General Matrix Multiply
|
||||
|
||||
The following figure illustrates the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right.
|
||||
|
||||

|
||||
|
||||
## Threadblock-level GEMM
|
||||
|
||||
The CUTLASS GEMM kernel partitions the _C_ matrix into a 2D tiling of threadblocks.
|
||||
Each threadblock computes a matrix product whose outer dimensions _M_ and _N_ are compile-time constants. The
|
||||
GEMM's _K_ dimension is partitioned into tiles and iterated over by the GEMM _mainloop_. The shape of the matrix
|
||||
multiply operation performed by each iteration of the mainloop is referred to as _OutputTile_.
|
||||
|
||||
The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative
|
||||
access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular
|
||||
buffer in shared memory is performed by a _GlobalLoadIterator_.
|
||||
|
||||
**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded, which may be of use for type conversion or, in the case of int8 GEMM, for transposing 4x4 tiles held in registers.
|
||||
|
||||
The Global Load Stream template contains members defined by the following templates:
|
||||
|
||||
* [GemmGlobalIteratorAb](cutlass/gemm/gemm_global_tile.h)
|
||||
* [Transformer](cutlass/convert.h)
|
||||
* [GemmSharedStoreTileAb](cutlass/gemm/gemm_shared_tile.h)
|
||||
|
||||
## Warp-level GEMM
|
||||
|
||||
The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product.
|
||||
Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.
|
||||
|
||||
[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
|
||||
|
||||
* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h)
|
||||
|
||||
**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
|
||||
|
||||
* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h)
|
||||
|
||||
## Thread-level GEMM
|
||||
|
||||
SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply
|
||||
procedures.
|
||||
|
||||
* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h)
|
||||
* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h)
|
||||
* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h)
|
||||
|
||||
## Epilogue
|
||||
|
||||
The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them back with a different thread mapping such that a threadblock-scoped tile store operation makes contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:
|
||||
|
||||
1. [Transformer](cutlass/convert.h) for converting the data types of accumulator elements
|
||||
2. [GemmSharedStoreTileD](cutlass/gemm/gemm_shared_tile.h) to store to shared memory specialized to the accumulator layout.
|
||||
3. [GemmSharedLoadTileD](cutlass/gemm/gemm_shared_tile.h) to load the data from shared memory.
|
||||
4. [GemmGlobalIteratorC](cutlass/gemm/gemm_global_tile.h) to load a tile from global memory.
|
||||
5. A [functor](cutlass/gemm/linear_scaling.h) to compute an element-wise operation on the matrix product and source data (such as alpha*AB+beta*C); a minimal sketch of such a functor appears after this list.
|
||||
6. [GemmGlobalIteratorD](cutlass/gemm/gemm_global_tile.h) to write the output to global memory.
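The element-wise functor in step 5 can be sketched as follows. This is an illustrative stand-in, not the actual cutlass/gemm/linear_scaling.h implementation.

```
// Sketch of an element-wise epilogue functor in the spirit of linear scaling.
struct ToyLinearScaling {
  float alpha;
  float beta;

  __device__ ToyLinearScaling(float alpha_, float beta_) : alpha(alpha_), beta(beta_) {}

  // accumulator: element of the matrix product AB; source: element loaded from C.
  __device__ float operator()(float accumulator, float source) const {
    return alpha * accumulator + beta * source;
  }
};
```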
|
||||
|
||||
## GEMM Traits
|
||||
|
||||
[**cutlass::gemm::GemmTraits**](cutlass/gemm/gemm_traits.h) collects the structural properties of a complete GEMM computation into a single template class. As a result, the Traits classes encapsulate the iterators and transformers for all supported GEMM operands and layouts. Low-level details needed by Traits (such as scalar types for operands, thread-block tile size, number of scalar elements per memory access within each phase, number of stages in shared memory, as well as other implementation-specific properties of the GEMM computation) are specified in class [**cutlass::gemm::GemmConfig**](cutlass/gemm/gemm_config.h).
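Putting the pieces together, host code composes a Traits class and instantiates the GEMM with it. The sketch below approximately follows the CUTLASS 1.x basic GEMM example; the exact template arguments and the `initialize()` signature may differ between releases, so treat it as an outline rather than a verbatim reference.

```
#include <cuda_runtime.h>
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/sgemm_traits.h"

// Single-precision GEMM traits: column-major A and B, 128x128x8 threadblock tile.
typedef cutlass::gemm::SgemmTraits<
  cutlass::MatrixLayout::kColumnMajor,
  cutlass::MatrixLayout::kColumnMajor,
  cutlass::Shape<8, 128, 128>
> GemmTraits;

typedef cutlass::gemm::Gemm<GemmTraits> Gemm;

// Launch a GEMM given device pointers A, B, C and leading dimensions lda, ldb, ldc.
cudaError_t run_sgemm(int m, int n, int k, float alpha, float const *A, int lda,
                      float const *B, int ldb, float beta, float *C, int ldc) {
  Gemm::Params params;
  int result = params.initialize(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, C, ldc);
  if (result) {
    return cudaErrorInvalidValue;
  }
  Gemm::launch(params);
  return cudaGetLastError();
}
```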
|
||||
|
||||
|
||||
# <a name="S-core-components"></a> 3. Core Components
|
||||
|
||||
CUTLASS GEMM kernels are implemented by a set of Core components for interacting with mathematical tensor and matrix
|
||||
objects as well as constructing efficient CUDA kernels.
|
||||
|
||||
* [Tensor views](#S-core-tensor-views)
|
||||
* [Shape](#S-core-shape)
|
||||
* [Tile structure](#S-core-tile-structure)
|
||||
* [Fragment](#S-core-fragment)
|
||||
* [Predicate vector](#S-core-predicate-vector)
|
||||
|
||||
## <a name="S-core-tensor-views"></a> Tensor View
|
||||
|
||||
Matrices and tensors are typically represented as n-D arrays held in linear memory with a single base pointer and a stride vector. Element _i_ of the stride vector indicates the offset in linear memory between consecutive elements in dimension i. Consequently, the linear offset for an arbitrary element specified as an n-tuple may be computed as the dot product of the coordinate and the stride vector.
|
||||
|
||||
CUTLASS provides abstractions for interacting with multidimensional tensors in device memory.
|
||||
Consequently, we define a hierarchy of pointer-like types for referencing tensors.
|
||||
|
||||
`T *` - raw pointer to elements of type T
|
||||
|
||||
`cutlass::TensorRef<T, Rank>` - reference to a tensor of elements of type T and given rank. Includes a mapping function and associated stride vector for accessing elements in linear memory.
|
||||
|
||||
`cutlass::TensorView<T, Rank>` - extends `TensorRef<>` by adding bounds information. This is a complete mathematical object which may be used as the argument to CUTLASS functions.
|
||||
|
||||
The above provide an identity mapping of a logical index space to linear memory. An element
|
||||
at logical coordinate X has an offset computed as follows:
|
||||
```
|
||||
offset = dot(X, stride)
|
||||
```
|
||||
where `dot()` computes the inner product of X and a vector of "strides."
|
||||
|
||||
CUTLASS 1.1 introduces a mapping function and an additional "storage rank" to offer a flexible way to
|
||||
map the logical index space of the tensor to memory. The mapping function maps a coordinate
|
||||
of rank _R_ to an index space of rank _S_. The linear offset is computed as:
|
||||
```
|
||||
offset = dot( MapFunc(X), stride )
|
||||
```
|
||||
where stride is a vector of rank _S_.
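A small host-side sketch of this offset computation, assuming a column-major rank-2 mapping function, is shown below; the function names are illustrative.

```
#include <array>

// Logical coordinate (row, column) -> storage coordinate (column, row) for column-major.
std::array<int, 2> map_column_major(std::array<int, 2> const &coord) {
  return {coord[1], coord[0]};
}

// offset = dot(MapFunc(X), stride); the fastest-changing stride is an implicit 1,
// so only the leading dimension is explicitly stored.
int offset(std::array<int, 2> const &coord, int leading_dimension) {
  std::array<int, 2> storage = map_column_major(coord);
  return storage[0] * leading_dimension + storage[1] * 1;
}
```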
|
||||
|
||||
CUTLASS kernels make extensive use of vectorization of memory accesses for efficiency and
|
||||
correctness. Consequently, we enforce a constraint on the strides used by mapping functions
|
||||
such that:
|
||||
|
||||
1. The "fastest-changing" stride is always 1 thereby mandating that consecutive elements in
|
||||
that rank are consecutive in linear memory.
|
||||
|
||||
2. The fastest changing rank is always last in the stride vector and not explicitly stored.
|
||||
|
||||
Thus, the stride vector used by mapping functions has length of one fewer than the rank of the
|
||||
storage tensor. These constraints are consistent with the BLAS interface of passing matrices as
|
||||
a tuple consisting of a pointer and a "leading dimension." In fact, these are rank=2 tensors
|
||||
whose fastest changing dimension is 1, and only the strided dimension is explicitly represented.
|
||||
|
||||
A typical mapping function might simply map the rows and columns of a matrix, a rank=2 tensor,
|
||||
to linear memory such that (1.) elements in the same column are consecutive in memory
|
||||
(column-major), or (2.) elements in the same row are consecutive (row-major). These can be
|
||||
accomplished by two different mapping functions whose stride vector is length=2. The first
|
||||
element is the "leading dimension."
|
||||
|
||||
The requirement that the fastest-changing stride always be of unit size need not be a limitation.
|
||||
To implement "sparse" computations or matrix operations in which matrix elements have arbitrary
|
||||
stride along both row and column, define a mapping function whose storage rank is 3. This permits
|
||||
two elements of the stride vector to have a non-unit value.
|
||||
|
||||
`cutlass::TensorView<>` extends this concept by including a size vector to specify the bounds of
|
||||
the index space. The value of each coordinate in the size vector defines the half-open range of
|
||||
indices whose smallest value is zero.
|
||||
|
||||
## <a name="S-core-shape"></a> Shape
|
||||
|
||||
To avoid complicated template metaprogramming, CUTLASS targets fixed compile-time tile sizes specified
|
||||
by a four-dimensional template `cutlass::Shape<>`. This defines the following dimensions, mirroring
|
||||
the NHWC tensor format used for convolution in Deep Learning frameworks.
|
||||
|
||||
- `D`: depth of tensor
|
||||
- `H`: first strided dimension
|
||||
- `W`: contiguous sequence of tensor elements
|
||||
- `C`: number of channels, usually used for vectorized access
|
||||
|
||||
Template specializations of `Shape` appear as arguments to numerous dependent template classes which
|
||||
must specify compile-time constant tile sizes.
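The following sketch illustrates the spirit of `Shape<>`; it is a simplified stand-in rather than the actual CUTLASS header.

```
// Illustrative definition in the spirit of cutlass::Shape<>.
template <int D_, int H_, int W_, int C_ = 1>
struct Shape {
  static int const kD = D_;  // depth
  static int const kH = H_;  // first strided dimension
  static int const kW = W_;  // contiguous dimension
  static int const kC = C_;  // channels / vector width
};

// A threadblock tile expressed as compile-time constants.
typedef Shape<1, 16, 16, 1> ThreadblockTile;
static_assert(ThreadblockTile::kH * ThreadblockTile::kW == 256, "16x16 tile");
```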
|
||||
|
||||
## <a name="S-core-tile-structure"></a> Tile Structure
|
||||
|
||||
Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA
|
||||
threads to the problem space. For example, the CUTLASS GEMM partitions the output matrix into tiles that are mapped onto threadblocks, warps, and threads, as described in the previous section.
|
||||
|
||||
Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following
|
||||
members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer
|
||||
lattice, partition its elements among a collection of threads, and map each unique thread ID to a unique
|
||||
offset.
|
||||
|
||||
- _Tile_ (concept `Shape<>`) - describes the dimensions of the tile in terms of scalar elements
|
||||
- _Delta_ (concept `Shape<>`) - describes the distance along each logical dimension between items
|
||||
- _Iterations_ (concept `Shape<>`) - describes the number of items along each logical dimension
|
||||
- _ThreadOffset_ (concept _functor_) - implements `Coord<4> operator()() const` to determine a thread's
|
||||
initial offset in the logical 4-D coordinate space
|
||||
|
||||
The following figure illustrates the CUTLASS tile structure. The overall shape, 16-by-16, is partitioned into
|
||||
vectors of length two among 32 threads. The elements stored by thread 9 are highlighted.
|
||||
|
||||
<img src="/media/images/cutlass-tile-structure.png" alt="CUTLASS tile structure" width="30%" />
|
||||
|
||||
The `cutlass::TileTraits<>` definition that describes this arrangement may be defined as follows:
|
||||
|
||||
```
|
||||
struct ExampleTileTraits {
|
||||
|
||||
/// Overall shape of tile
|
||||
typedef Shape<1, 16, 16, 1> Tile;
|
||||
|
||||
/// Distance along each dimension of accesses
|
||||
typedef Shape<1, 4, 1, 1> Delta;
|
||||
|
||||
/// Number of memory accesses performed by each thread
|
||||
typedef Shape<1, 4, 1, 1> Iterations;
|
||||
|
||||
/// Offset function - maps each thread to a unique starting offset within the 4D tile
|
||||
struct ThreadOffset {
|
||||
|
||||
CUTLASS_DEVICE Coord<4> operator()() const {
|
||||
|
||||
typedef Shape<1, 16, 8, 2> Vectorized;
|
||||
|
||||
return make_Coord(
|
||||
0, // depth "D" dimension
|
||||
threadIdx.x / Vectorized::kW, // horizontal "H" dimension - first strided dimension
|
||||
threadIdx.x % Vectorized::kW, // vertical "W" dimension - contiguous dimension
|
||||
0
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
```
|
||||
|
||||
## <a name="S-core-tile-iterator"></a> Tile Iterator
|
||||
|
||||
The iterator design pattern provides an abstraction for accessing the items in a collection in sequence. Basic
|
||||
operators defined by iterators consist of accessing an item - either a load or store - followed by traversal to
|
||||
the next item in sequence.
|
||||
|
||||
<img src="/media/images/cutlass-tile-iteration.png" alt="CUTLASS tile access and traversal" width="50%" />
|
||||
|
||||
To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept.
|
||||
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
|
||||
|
||||
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
|
||||
|
||||
## <a name="S-core-fragment"></a> Fragment
|
||||
|
||||
A fragment is analogous to `std::array<>` in that it is a constant-sized array of elements. Typically backed by storage in the SM's register file, CUTLASS `Fragment<>` objects are used to store tiles. For threadblock- and warp-scope operations, the contents of these tiles are distributed across the participating threads. In such cases, a thread's `Fragment<>` contains the part of the tile held by that thread.
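A simplified sketch of such a container follows; it is illustrative and not the actual `cutlass::Fragment<>` definition.

```
// Sketch of a Fragment-like container: a fixed-size, register-backed array.
template <typename T, int kElements>
struct ToyFragment {
  T storage[kElements];

  __device__ T &operator[](int i) { return storage[i]; }
  __device__ T const &operator[](int i) const { return storage[i]; }

  static int const kCount = kElements;
};

// Each thread holds only its part of a tile: e.g. a 64x64 accumulator tile
// spread over 32 threads gives 128 elements per thread.
typedef ToyFragment<float, 128> AccumulatorFragment;
```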
|
||||
|
||||
## <a name="S-core-predicate-vector"></a> Predicate Vector
|
||||
|
||||
SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
|
||||
|
||||
CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
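The sketch below shows the idea of a packed predicate mask guarding a sequence of loads; it is illustrative only and does not reproduce the actual `PredicateVector` implementation.

```
// Sketch: one guard bit per access, packed into a register.
template <int kAccesses>
struct ToyPredicateVector {
  unsigned bits;

  // Enable only the accesses that fall inside the valid extent.
  __device__ void initialize(int remaining_rows, int rows_per_access) {
    bits = 0;
    for (int i = 0; i < kAccesses; ++i) {
      if (i * rows_per_access < remaining_rows) {
        bits |= (1u << i);
      }
    }
  }

  __device__ bool operator[](int i) const { return (bits >> i) & 1u; }
};

// Usage inside a load loop; the branch compiles to predicated instructions:
//   if (predicates[i]) { fragment[i] = pointer[i * stride]; }
```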
|
||||
|
||||
|
||||
# <a name="S-utilities"></a> 4. Utilities
|
||||
|
||||
CUTLASS implements efficient matrix multiply computations on GPUs. It is accompanied by an extensive utility
|
||||
framework offering features such as:
|
||||
|
||||
* [cutlass::half_t](tools/util/half.h) - a host-side half-precision type
|
||||
* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
|
||||
* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)
|
||||
|
||||
|
||||
# <a name="S-optimization-strategies"></a>5. Optimization Strategies
|
||||
|
||||
This section describes several strategies taken to increase performance beyond what is achievable with
|
||||
a basic implementation of the hierarchical GEMM structure.
|
||||
|
||||
|
||||
## Threadblock Rasterization
|
||||
|
||||
To maximize reuse of data held in the last level cache, CUTLASS defines several functions to
|
||||
affect the mapping of threadblocks to logical partitions of the GEMM problem. These map
|
||||
consecutively launched threadblocks to packed two-dimensional regions of the partitioned GEMM
|
||||
problem to increase the probability that these will access the same tiles of global memory at
|
||||
approximately the same time.
|
||||
|
||||
Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](cutlass/gemm/threadblock_swizzle.h).
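The sketch below shows one simple grouped mapping in this spirit; it is illustrative and does not reproduce the functions in threadblock_swizzle.h.

```
// Sketch of a grouped threadblock mapping: consecutive linear block indices are
// packed into strips of kGroupRows rows so that neighboring threadblocks reuse
// the same tiles of A and B. Assumes the grid's M extent is a multiple of kGroupRows.
struct ToyGroupedSwizzle {
  static int const kGroupRows = 8;

  __device__ static void map(int linear_idx, int grid_n, int &block_m, int &block_n) {
    int strip_size = kGroupRows * grid_n;   // threadblocks per strip of rows
    int strip      = linear_idx / strip_size;
    int within     = linear_idx % strip_size;
    block_m = strip * kGroupRows + within % kGroupRows;
    block_n = within / kGroupRows;
  }
};
```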
|
||||
|
||||
|
||||
## Parallel Reductions across GEMM _K_
|
||||
|
||||
Matrix product computations expose parallelism among _O(MN)_ independent inner product
|
||||
computations. For sufficiently large problem sizes, a GEMM kernel in CUTLASS may approach
|
||||
the theoretical maximum computational throughput. For small problems, however, there are
|
||||
too few threadblocks to efficiently occupy the entire GPU.
|
||||
|
||||
As a recourse, parallelizing the reduction performed during the inner product computation
|
||||
enables more threadblocks to execute concurrently while still taking advantage of the throughput
|
||||
benefits of large threadblock-level GEMM tiles.
|
||||
|
||||
CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension
|
||||
and launching an additional set of threadblocks for each partition. Consequently, we refer to
|
||||
this strategy within CUTLASS as "parallel reduction splitK." Parallel reduction splitK in CUTLASS requires the execution of two kernels: the first is called partitionedK GEMM, and the second is called batched reduction.
|
||||
|
||||
The partitionedK GEMM is very similar to one flavor of batched strided GEMM. Instead of requiring users to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the number of partitions to apply along the K dimension for operands A and B. For example, parameters of m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs, each with m=128, n=128, k=256. PartitionedK also allows the scenario where K is not divisible by the partition count. For example, parameters of m=128, n=128, k=4096 and partition=20 will result in 20 batched strided GEMMs, with the first 19 batches of m=128, n=128, k=204 (4096/20 rounded down) and the last batch of m=128, n=128, k=220.
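The partition arithmetic from the example above can be checked with a few lines of host code:

```
#include <cstdio>

int main() {
  int k = 4096;
  int partitions = 20;

  int k_per_partition = k / partitions;                   // 204 for the first partitions
  int k_last = k - (partitions - 1) * k_per_partition;    // 220 for the last partition

  std::printf("first %d partitions: k=%d, last partition: k=%d\n",
              partitions - 1, k_per_partition, k_last);
  return 0;
}
```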
|
||||
|
||||
The batched reduction kernel then performs the reduction along the K dimension. Thus, the input of the batched reduction kernel is the output (C) of the partitionedK GEMM. A workspace allocation, managed by the user, stores these intermediate results.
|
||||
|
||||
An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_gemm.cu).
|
||||
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
22
Doxyfile
@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
|
||||
# title of most generated pages and in a few other places.
|
||||
# The default value is: My Project.
|
||||
|
||||
PROJECT_NAME = "Cutlass"
|
||||
PROJECT_NAME = "CUTLASS"
|
||||
|
||||
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
|
||||
# could be handy for archiving the generated documentation or if some version
|
||||
@ -51,7 +51,7 @@ PROJECT_BRIEF = "CUDA Templates for Linear Algebra Subroutines and Solv
|
||||
# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo
|
||||
# to the output directory.
|
||||
|
||||
PROJECT_LOGO =
|
||||
PROJECT_LOGO = media/images/cutlass-logo-small.png
|
||||
|
||||
# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
|
||||
# into which the generated documentation will be written. If a relative path is
|
||||
@ -206,7 +206,7 @@ SEPARATE_MEMBER_PAGES = NO
|
||||
# uses this value to replace tabs by spaces in code fragments.
|
||||
# Minimum value: 1, maximum value: 16, default value: 4.
|
||||
|
||||
TAB_SIZE = 4
|
||||
TAB_SIZE = 2
|
||||
|
||||
# This tag can be used to specify a number of aliases that act as commands in
|
||||
# the documentation. An alias has the form:
|
||||
@ -297,7 +297,7 @@ AUTOLINK_SUPPORT = YES
|
||||
# diagrams that involve STL classes more complete and accurate.
|
||||
# The default value is: NO.
|
||||
|
||||
BUILTIN_STL_SUPPORT = NO
|
||||
BUILTIN_STL_SUPPORT = YES
|
||||
|
||||
# If you use Microsoft's C++/CLI language, you should set this option to YES to
|
||||
# enable parsing support.
|
||||
@ -734,7 +734,9 @@ WARN_LOGFILE =
|
||||
# spaces.
|
||||
# Note: If this tag is empty the current directory is searched.
|
||||
|
||||
INPUT = cutlass
|
||||
INPUT = include/cutlass tools/util/include/cutlass/ tools/library/include/cutlass/
|
||||
|
||||
INPUT += media/docs/doxygen_mainpage.md
|
||||
|
||||
# This tag can be used to specify the character encoding of the source files
|
||||
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
|
||||
@ -870,7 +872,7 @@ FILTER_SOURCE_PATTERNS =
|
||||
# (index.html). This can be useful if you have a project on for instance GitHub
|
||||
# and want to reuse the introduction page also for the doxygen output.
|
||||
|
||||
USE_MDFILE_AS_MAINPAGE =
|
||||
USE_MDFILE_AS_MAINPAGE = media/docs/doxygen_mainpage.md
|
||||
|
||||
#---------------------------------------------------------------------------
|
||||
# Configuration options related to source browsing
|
||||
@ -999,7 +1001,7 @@ GENERATE_HTML = YES
|
||||
# The default directory is: html.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_OUTPUT = generated-html
|
||||
HTML_OUTPUT =
|
||||
|
||||
# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
|
||||
# generated HTML page (for example: .htm, .php, .asp).
|
||||
@ -1080,7 +1082,7 @@ HTML_EXTRA_FILES =
|
||||
# Minimum value: 0, maximum value: 359, default value: 220.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_COLORSTYLE_HUE = 82
|
||||
HTML_COLORSTYLE_HUE = 100
|
||||
|
||||
# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
|
||||
# in the HTML output. For a value of 0 the output will use grayscales only. A
|
||||
@ -1088,7 +1090,7 @@ HTML_COLORSTYLE_HUE = 82
|
||||
# Minimum value: 0, maximum value: 255, default value: 100.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_COLORSTYLE_SAT = 100
|
||||
HTML_COLORSTYLE_SAT = 50
|
||||
|
||||
# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
|
||||
# luminance component of the colors in the HTML output. Values below 100
|
||||
@ -1107,7 +1109,7 @@ HTML_COLORSTYLE_GAMMA = 80
|
||||
# The default value is: YES.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_TIMESTAMP = YES
|
||||
HTML_TIMESTAMP = NO
|
||||
|
||||
# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
|
||||
# documentation will contain sections that can be hidden and shown after the
|
||||
|
||||
188
EULA.txt
Normal file
188
EULA.txt
Normal file
@ -0,0 +1,188 @@
|
||||
NVIDIA Software License Agreement
|
||||
|
||||
IMPORTANT NOTICE – PLEASE READ AND AGREE BEFORE USING THE SOFTWARE
|
||||
This software license agreement (“Agreement”) is a legal agreement between you, whether an individual or entity, (“you”) and NVIDIA Corporation (“NVIDIA”) and governs the use of the NVIDIA CUTLASS DSLs software and materials that NVIDIA delivers to you under this Agreement (“Software”).
|
||||
NVIDIA and you are each a “party” and collectively the “parties.”
|
||||
This Agreement can be accepted only by an adult of legal age of majority in the country in which the Software is used.
|
||||
If you don’t have the required age or authority to accept this Agreement, or if you don’t accept all the terms and conditions of this Agreement, do not use the Software.
|
||||
|
||||
1. License Grants
|
||||
|
||||
1.1. License Grant to You. The Software made available by NVIDIA to you is licensed, not sold.
|
||||
Subject to the terms of this Agreement, NVIDIA grants you a limited, non-exclusive, revocable, non-transferable, and non-sublicensable (except as expressly granted in this Agreement), license to:
|
||||
|
||||
a. install and use copies of the Software,
|
||||
b. configure the Software using configuration files provided (if applicable),
|
||||
c. modify and create derivative works of any sample or example source code NVIDIA delivers to you as part of the Software (“Derivatives”) (if applicable), and
|
||||
d. distribute python files in the Software package in source format as incorporated into a software application subject to the following distribution requirements:
|
||||
|
||||
i. Your application must have material additional functionality, beyond the included portions of the Software.
|
||||
ii. The distributable portions of the Software shall only be accessed by your application.
|
||||
iii. The following notice shall be included in modifications and derivative works of sample source code distributed: “This software contains source code provided by NVIDIA Corporation.”
|
||||
iv. Unless a developer tool is identified in this Agreement as distributable, it is delivered for your internal use only.
|
||||
v. The terms under which you distribute your application must be consistent with the terms of this Agreement, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA’s intellectual property rights.
|
||||
vi. Additionally, you agree that you will protect the privacy, security and legal rights of your application users.
|
||||
|
||||
The foregoing (a) through (d) are, collectively, the “Purpose”, and the developed applications are only for use in systems with NVIDIA GPUs.
|
||||
|
||||
1.2. License Grant to NVIDIA. Subject to the terms of this Agreement, you grant NVIDIA and its affiliates a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit at NVIDIA’s discretion any Derivatives created by or for you.
|
||||
You may, but are not required to, deliver any Derivatives to NVIDIA.
|
||||
|
||||
2. License Restrictions
|
||||
|
||||
Your license to use the Software and Derivatives is restricted as stated in this Section 2 (“License Restrictions”).
|
||||
You will cooperate with NVIDIA and, upon NVIDIA’s written request, you will confirm in writing and provide reasonably requested information to verify your compliance with the terms of this Agreement.
|
||||
You may not:
|
||||
|
||||
2.1. Use the Software or Derivatives for any purpose other than the Purpose;
|
||||
|
||||
2.2. Sell, rent, sublicense, transfer, distribute or otherwise make available to others (except authorized users as stated in Section 3 (“Authorized Users”)) any portion of the Software or Derivatives, except as expressly granted in Section 1.1 (“License Grant to You”);
|
||||
|
||||
2.3. Reverse engineer, decompile, or disassemble the Software components provided in binary form, nor attempt in any other manner to obtain source code of such Software;
|
||||
|
||||
2.4. Modify or create derivative works of the Software, except as expressly granted in Section 1.1 (“License Grant to You”);
|
||||
|
||||
2.5. Change or remove copyright or other proprietary notices in the Software;
|
||||
|
||||
2.6. Bypass, disable, or circumvent any technical limitation, encryption, security, digital rights management or authentication mechanism in the Software;
|
||||
|
||||
2.7. Use the Software or Derivatives in any manner that would cause them to become subject to an open source software license, subject to the terms in Section 6 (“Components Under Other Licenses”);
|
||||
|
||||
2.8. Use the Software or Derivatives in violation of any applicable law or regulation in relevant jurisdictions;
|
||||
|
||||
2.9. Indicate that a product or service developed with the Software or Derivatives is sponsored or endorsed by NVIDIA;
|
||||
|
||||
2.10. Replace any NVIDIA software components in the Software that are governed by this Agreement with other software that implements NVIDIA APIs;
|
||||
|
||||
2.11. Reverse engineer, decompile or disassemble any portion of the output generated using Software elements for the purpose of translating such output artifacts to target a non-NVIDIA platform; or
|
||||
|
||||
3. Authorized Users
|
||||
|
||||
You may allow employees and contractors of your entity or of your subsidiary(ies), and for educational institutions also enrolled students, to internally access and use the Software as authorized by this Agreement from your secure network to perform the work authorized by this Agreement on your behalf.
|
||||
You are responsible for the compliance with the terms of this Agreement by your authorized users.
|
||||
Any act or omission that if committed by you would constitute a breach of this Agreement will be deemed to constitute a breach of this Agreement if committed by your authorized users.
|
||||
|
||||
4. Pre-Release
|
||||
|
||||
Software versions identified as alpha, beta, preview, early access or otherwise as pre-release (“Pre-Release”) may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability and reliability standards relative to NVIDIA commercial offerings.
|
||||
You use Pre-Release Software at your own risk. NVIDIA did not design or test the Software for use in production or business-critical systems.
|
||||
NVIDIA may choose not to make available a commercial version of Pre-Release Software.
|
||||
NVIDIA may also choose to abandon development and terminate the availability of Pre-Release Software at any time without liability.
|
||||
|
||||
5. Updates
|
||||
|
||||
NVIDIA may at any time and at its option, change, discontinue, or deprecate any part, or all, of the Software, or change or remove features or functionality, or make available patches, workarounds or other updates to the Software.
|
||||
Unless the updates are provided with their separate governing terms, they are deemed part of the Software licensed to you under this Agreement, and your continued use of the Software is deemed acceptance of such changes.
|
||||
|
||||
6. Components Under Other Licenses
|
||||
|
||||
The Software may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms (“Other Licenses”).
|
||||
The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights;
|
||||
except that this Agreement will prevail regarding the use of third-party open source software, unless a third-party open source software license requires its license terms to prevail.
|
||||
Open source software license means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (http://opensource.org), Free Software Foundation (http://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (http://www.spdx.org).
|
||||
|
||||
7. Ownership
|
||||
|
||||
7.1. NVIDIA Ownership. The Software, including all intellectual property rights, is and will remain the sole and exclusive property of NVIDIA or its licensors.
|
||||
Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Software, and (b) no other license or right is granted to you by implication, estoppel or otherwise.
|
||||
|
||||
7.2. Your Ownership. Subject to the rights of NVIDIA and its suppliers in the Software, which continue to be licensed as stated in this Agreement, even when incorporated in your products or services, and the extent permitted by applicable law, as between you and NVIDIA, you hold all rights, title and interest in and to your products, services and Derivatives you develop as permitted in this Agreement including their respective intellectual property rights.
|
||||
|
||||
8. Feedback
|
||||
|
||||
You may, but you are not obligated to, provide suggestions, requests, fixes, modifications, enhancements, or other feedback regarding the Software (collectively, “Feedback”).
|
||||
Feedback, even if designated as confidential by you, will not create any confidentiality obligation for NVIDIA or its affiliates.
|
||||
If you provide Feedback, you grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit the Feedback at NVIDIA’s discretion.
|
||||
|
||||
9. Termination
|
||||
|
||||
9.1. Termination. This Agreement will automatically terminate without notice from NVIDIA if you fail to comply with any of the terms in this Agreement or if you commence or participate in any legal proceeding against NVIDIA with respect to the Software.
|
||||
Additionally, either party may terminate this Agreement at any time with thirty (30) days’ advance written notice to the other party.
|
||||
|
||||
9.2. Effect of Termination. Upon any expiration or termination of this Agreement, you will promptly (a) stop using and return, delete or destroy NVIDIA confidential information and all Software received under this Agreement, and (b) delete or destroy Derivatives created under this Agreement, unless an authorized NVIDIA representative provides prior written approval that you may keep a copy of the Derivatives solely for archival purposes.
|
||||
Upon written request, you will certify in writing that you have complied with your obligations under this Section 9.2 (“Effect of Termination”).
|
||||
|
||||
9.3. Survival. Section 1.2 (“License Grant to NVIDIA”), Section 5 (“Updates”), Section 6 (“Components Under Other Licenses”), Section 7 (“Ownership”), Section 8 (“Feedback”), Section 9.2 (“Effect of Termination”), Section 9.3 (“Survival”), Section 10 (“Disclaimer of Warranties”), Section 11 (“Limitation of Liability”), Section 12 (“Use in Mission Critical Applications”), Section 13 (“Governing Law and Jurisdiction”), Section 14 (“Indemnity”) and Section 15 (“General”) will survive any expiration or termination of this Agreement.
|
||||
|
||||
10. Disclaimer of Warranties
|
||||
|
||||
THE SOFTWARE IS PROVIDED BY NVIDIA AS-IS AND WITH ALL FAULTS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA DISCLAIMS ALL WARRANTIES AND REPRESENTATIONS OF ANY KIND, WHETHER
|
||||
EXPRESS, IMPLIED OR STATUTORY, RELATING TO OR ARISING UNDER THIS AGREEMENT, INCLUDING, WITHOUT LIMITATION, THE WARRANTIES OF TITLE, NONINFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. NVIDIA DOES NOT WARRANT OR ASSUME RESPONSIBILITY FOR THE ACCURACY OR COMPLETENESS OF ANY THIRD-PARTY INFORMATION, TEXT, GRAPHICS, LINKS CONTAINED IN THE SOFTWARE.
|
||||
WITHOUT LIMITING THE FOREGOING, NVIDIA DOES NOT WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS, ANY DEFECTS OR ERRORS WILL BE CORRECTED, ANY CERTAIN CONTENT WILL BE AVAILABLE; OR THAT THE SOFTWARE IS FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS. NO INFORMATION OR ADVICE GIVEN BY NVIDIA WILL IN ANY WAY INCREASE THE SCOPE OF ANY WARRANTY EXPRESSLY PROVIDED IN THIS AGREEMENT.
|
||||
NVIDIA does not warrant or assume responsibility for the accuracy or completeness of any third-party information, text, graphics or links contained in the Software.
|
||||
|
||||
11. Limitations of Liability
|
||||
|
||||
11.1. EXCLUSIONS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL NVIDIA BE LIABLE FOR ANY (I) INDIRECT, PUNITIVE, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES, OR (ii) DAMAGES FOR (a) THE COST OF PROCURING SUBSTITUTE GOODS, OR (b) LOSS OF PROFITS, REVENUES, USE, DATA OR GOODWILL ARISING OUT OF OR RELATED TO THIS AGREEMENT, WHETHER BASED ON BREACH OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY, OR OTHERWISE, AND EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES AND EVEN IF A PARTY’S REMEDIES FAIL THEIR ESSENTIAL PURPOSE.
|
||||
|
||||
11.2. DAMAGES CAP. ADDITIONALLY, TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA’S TOTAL CUMULATIVE AGGREGATE LIABILITY FOR ANY AND ALL LIABILITIES, OBLIGATIONS OR CLAIMS ARISING OUT OF OR RELATED TO THIS AGREEMENT WILL NOT EXCEED FIVE U.S. DOLLARS (US$5).
|
||||
|
||||
12. Use in Mission Critical Applications
|
||||
|
||||
You acknowledge that the Software provided under this Agreement is not designed or tested by NVIDIA for use in any system or application where the use or failure of such system or application developed with NVIDIA’s Software could result in injury, death or catastrophic damage (each, a “Mission Critical Application”).
|
||||
Examples of Mission Critical Applications include use in avionics, navigation, autonomous vehicle applications, AI solutions for automotive products, military, medical, life support or other mission-critical or life-critical applications.
|
||||
NVIDIA will not be liable to you or any third party, in whole or in part, for any claims or damages arising from these uses.
|
||||
You are solely responsible for ensuring that systems and applications developed with the Software include sufficient safety and redundancy features and comply with all applicable legal and regulatory standards and requirements.
|
||||
|
||||
13. Governing Law and Jurisdiction
|
||||
|
||||
This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods.
|
||||
The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts;
|
||||
except that either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.
|
||||
|
||||
14. Indemnity
|
||||
|
||||
By using the Software you agree to defend, indemnify and hold harmless NVIDIA and its affiliates and their respective officers, directors, employees and agents from and against any claims, disputes, demands, liabilities, damages, losses, costs and expenses arising out of or in any way connected with (i) products or services that have been developed or deployed with or use the Software, or claims that they violate laws, or infringe, violate, or misappropriate any third party right;
|
||||
or (ii) use of the Software in breach of the terms of this Agreement.
|
||||
|
||||
15. General
|
||||
|
||||
15.1. Independent Contractors.
|
||||
The parties are independent contractors, and this Agreement does not create a joint venture, partnership, agency, or other form of business association between the parties.
|
||||
Neither party will have the power to bind the other party or incur any obligation on its behalf without the other party’s prior written consent.
|
||||
Nothing in this Agreement prevents either party from participating in similar arrangements with third parties.
|
||||
|
||||
15.2. No Assignment.
|
||||
NVIDIA may assign, delegate or transfer its rights or obligations under this Agreement by any means or operation of law.
|
||||
You may not, without NVIDIA’s prior written consent, assign, delegate or transfer any of your rights or obligations under this Agreement by any means or operation of law, and any attempt to do so is null and void.
|
||||
|
||||
15.3. No Waiver.
|
||||
No failure or delay by a party to enforce any term or obligation of this Agreement will operate as a waiver by that party, or prevent the enforcement of such term or obligation later.
|
||||
|
||||
15.4. Trade Compliance.
|
||||
You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations.
|
||||
You confirm (a) your understanding that export or reexport of certain NVIDIA products or technologies may require a license or other approval from appropriate authorities and (b) that you will not export or reexport any products or technology, directly or indirectly, without first obtaining any required license or other approval from appropriate authorities, (i) to any countries that are subject to any U.S. or local export restrictions (currently including, but not necessarily limited to, Belarus, Cuba, Iran, North Korea, Russia, Syria, the Region of Crimea, Donetsk People’s Republic Region and Luhansk People’s Republic Region);
|
||||
(ii) to any end-user who you know or have reason to know will utilize them in the design, development or production of nuclear, chemical or biological weapons, missiles, rocket systems, unmanned air vehicles capable of a maximum range of at least 300 kilometers, regardless of payload, or intended for military end-use, or any weapons of mass destruction;
|
||||
(iii) to any end-user who has been prohibited from participating in the U.S. or local export transactions by any governing authority;
|
||||
or (iv) to any known military or military-intelligence end-user or for any known military or military-intelligence end-use in accordance with U.S. trade compliance laws and regulations.
|
||||
|
||||
15.5. Government Rights.
|
||||
The Software, documentation and technology (“Protected Items”) are “Commercial products” as this term is defined at 48 C.F.R.
|
||||
2.101, consisting of “commercial computer software” and “commercial computer software documentation” as such terms are used in, respectively, 48 C.F.R.
|
||||
12.212 and 48 C.F.R. 227.7202 & 252.227-7014(a)(1). Before any Protected Items are supplied to the U.S. Government, you will (i) inform the U.S. Government in writing that the Protected Items are and must be treated as commercial computer software and commercial computer software documentation developed at private expense;
|
||||
(ii) inform the U.S. Government that the Protected Items are provided subject to the terms of the Agreement;
|
||||
and (iii) mark the Protected Items as commercial computer software and commercial computer software documentation developed at private expense.
|
||||
In no event will you permit the U.S. Government to acquire rights in Protected Items beyond those specified in 48 C.F.R.
|
||||
52.227-19(b)(1)-(2) or 252.227-7013(c) except as expressly approved by NVIDIA in writing.
|
||||
|
||||
15.6. Notices.
|
||||
Please direct your legal notices or other correspondence to legalnotices@nvidia.com with a copy mailed to NVIDIA Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, United States of America, Attention: Legal Department.
|
||||
If NVIDIA needs to contact you, you consent to receive the notices by email and agree that such notices will satisfy any legal communication requirements.
|
||||
|
||||
15.7. Severability.
|
||||
If a court of competent jurisdiction rules that a provision of this Agreement is unenforceable, that provision will be deemed modified to the extent necessary to make it enforceable and the remainder of this Agreement will continue in full force and effect.
|
||||
|
||||
15.8. Amendment.
|
||||
Any amendment to this Agreement must be in writing and signed by authorized representatives of both parties.
|
||||
|
||||
15.9. Construction.
|
||||
The headings in the Agreement are included solely for convenience and are not intended to affect the meaning or interpretation of the Agreement.
|
||||
As required by the context of the Agreement, the singular of a term includes the plural and vice versa.
|
||||
|
||||
15.10. Force Majeure.
|
||||
Neither party will be liable during any period where an event or circumstance prevents or delays that party from performing its obligations under this Agreement and that event or circumstance: (i) is not within the reasonable control of that party and is not the result of that party’s negligence, and (ii) cannot be overcome or avoided by that party using reasonably diligent efforts.
|
||||
|
||||
15.11. Entire Agreement.
|
||||
Regarding the subject matter of this Agreement, the parties agree that (a) this Agreement constitutes the entire and exclusive agreement between the parties and supersedes all prior and contemporaneous communications and (b) any additional or different terms or conditions, whether contained in purchase orders, order acknowledgments, invoices or otherwise, will not be binding and are null and void.
|
||||
|
||||
(v. May 8, 2025)
|
||||
30
FUNCTIONALITY.md
Normal file
30
FUNCTIONALITY.md
Normal file
@ -0,0 +1,30 @@
|
||||
# Changelog for CuTe DSL API changes
|
||||
|
||||
## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)
|
||||
|
||||
* for loop
|
||||
- Python built-in ``range`` now always generates IR and executes at runtime
|
||||
- ``cutlass.range`` is an advanced ``range`` with IR-level unrolling and pipelining control (a short sketch follows this list)
|
||||
- Deprecated ``cutlass.range_dynamic``; please replace it with ``range`` or ``cutlass.range``
|
||||
- **Experimental** Added ``pipelining`` control for compiler generated software pipeline code
|
||||
* while/if
|
||||
- ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate
|
||||
- Deprecated ``cutlass.dynamic_expr``; please remove it
|
||||
* Rename mbarrier functions to reduce ambiguity
|
||||
* Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier`
|
||||
* Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional.
|
||||
* Introduce `cutlass.cute.arch.get_dyn_smem_size` api to get runtime dynamic shared memory size.
|
||||
* Various API Support for SM100 BlockScaled Gemm
|
||||
- Introduce BlockScaled MmaOps in [tcgen05/mma.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py), and provide a `make_blockscaled_trivial_tiled_mma` function in [blackwell_helpers.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blackwell_helpers.py) to help construct a BlockScaled TiledMma.
|
||||
- Introduce S2T CopyOps in [tcgen05/copy.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py).
|
||||
- Introduce BlockScaled layout utilities in [blockscaled_layout.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blockscaled_layout.py) for creating the required scale factor layouts in global memory, shared memory and tensor memory.
|
||||
* `cutlass.cute.compile` now supports compilation options. Refer to [JIT compilation options](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.html) for more details.
|
||||
* `cutlass.cute.testing.assert_` now works for device JIT functions. Specify `--enable-device-assertions` as a compilation option to enable it.
|
||||
* `cutlass.cute.make_tiled_copy` is now deprecated. Please use `cutlass.cute.make_tiled_copy_tv` instead.
|
||||
* Shared memory capacity query
|
||||
- Introduce `cutlass.utils.get_smem_capacity_in_bytes` for querying the shared memory capacity.
|
||||
- `<arch>_utils.SMEM_CAPACITY["<arch_str>"]` is now deprecated.
|
||||
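As a rough illustration of the control-flow changes above (built-in ``range`` vs ``cutlass.range`` vs ``cutlass.const_expr``), here is a minimal, hedged sketch. The function name is illustrative; the ``@cute.jit`` decorator and the ``cutlass.Constexpr``/``cutlass.Int32`` annotations follow the CuTe DSL examples in this repository, and the ``unroll`` keyword name is an assumption based on the description above rather than a documented signature.

```
# Minimal, hedged sketch of the 4.1.0 control-flow semantics listed above.
# The `unroll` keyword name is an assumption based on the description of
# ``cutlass.range``; other helpers follow the CuTe DSL examples in this repo.
import cutlass
import cutlass.cute as cute


@cute.jit
def control_flow_demo(n: cutlass.Int32, add_one: cutlass.Constexpr):
    acc = cutlass.Int32(0)

    # Built-in range: always generates IR and executes at runtime (4.1.0 behavior).
    for i in range(n):
        acc += i

    # cutlass.range: like range, but with IR-level unrolling/pipelining control.
    for i in cutlass.range(n, unroll=2):
        acc += i

    # With cutlass.const_expr the predicate is resolved at compile time,
    # so only the taken branch is traced into the IR.
    if cutlass.const_expr(add_one):
        acc += 1
    # (a real kernel would store or otherwise use `acc`; omitted in this sketch)
```

Calling `control_flow_demo(8, True)` should trace, JIT-compile, and execute the function; since `add_one` is marked `Constexpr`, changing it to `False` would produce a separate specialization.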
|
||||
## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)
|
||||
|
||||
* Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
|
||||
23
LICENSE.TXT
23
LICENSE.TXT
@ -1,23 +0,0 @@
|
||||
Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
34
LICENSE.txt
Normal file
34
LICENSE.txt
Normal file
@ -0,0 +1,34 @@
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
Certain files within this repository are subject to separate licensing terms:
|
||||
|
||||
- The files located in the `python/CuTeDSL` directory are licensed under the
|
||||
NVIDIA End User License Agreement (EULA). Please refer to
|
||||
https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
for the full terms.
|
||||
104
PUBLICATIONS.md
Normal file
104
PUBLICATIONS.md
Normal file
@ -0,0 +1,104 @@
|
||||
# Publications Using Cutlass
|
||||
|
||||
## 2025
|
||||
|
||||
- ["Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts"](https://arxiv.org/abs/2502.19811). Shulai Zhang, Ningxin Zheng, Haibin Lin, Ziheng Jiang, Wenlei Bao, Chengquan Jiang, Qi Hou, Weihao Cui, Size Zheng, Li-Wen Chang, Quan Chen, Xin Liu. _arXiv_, February 2025.
|
||||
|
||||
- ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025.
|
||||
|
||||
- ["Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light"](https://arxiv.org/abs/2504.16922). Ali Hassani, Fengzhe Zhou, Aditya Kane, Jiannan Huang, Chieh-Yun Chen, Min Shi, Steven Walton, Markus Hoehnerbach, Vijay Thakkar, Michael Isaev, Qinsheng Zhang, Bing Xu, Haicheng Wu, Wen-mei Hwu, Ming-Yu Liu, Humphrey Shi. _arXiv_, April 2025.
|
||||
|
||||
## 2024
|
||||
|
||||
- ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024.
|
||||
|
||||
- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024.
|
||||
|
||||
- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024.
|
||||
|
||||
- ["EVT: Accelerating Deep Learning Training with Epilogue Visitor Tree"](https://dl.acm.org/doi/10.1145/3620666.3651369). Zhaodong Chen, Andrew Kerr, Richard Cai, Jack Kosaian, Haicheng Wu, Yufei Ding, and Yuan Xie. _Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, April 2024.
|
||||
|
||||
- ["Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level"](https://arxiv.org/abs/2403.04690). Ali Hassani, Wen-Mei Hwu, Humphrey Shi. _arXiv_, March 2024.
|
||||
|
||||
## 2023
|
||||
|
||||
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
|
||||
|
||||
- ["Benchmarking GPU Tensor Cores on General Matrix Multiplication Kernels through CUTLASS"](https://www.mdpi.com/2076-3417/13/24/13022). Xuanteng Huang, Xianwei Zhang, Panfei Yang, Nong Xiao. _Journal of Applied Sciences_, December 2023.
|
||||
|
||||
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.
|
||||
|
||||
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
|
||||
|
||||
- ["MegaBlocks: Efficient Sparse Training with Mixture-of-Experts"](https://arxiv.org/abs/2211.15841). Trevor Gale, Deepak Narayanan, Cliff Young, Matei Zaharia. _Proceedings of the Sixth Machine Learning and Systems_, May 2023.
|
||||
|
||||
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
|
||||
|
||||
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
|
||||
|
||||
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
|
||||
|
||||
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, February 2023.
|
||||
|
||||
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, February 2023.
|
||||
|
||||
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
|
||||
|
||||
## 2022
|
||||
|
||||
- ["GPU Load Balancing"](https://arxiv.org/abs/2212.08964). Muhammad Osama. _Doctoral dissertation, University of California, Davis_, December 2022.
|
||||
|
||||
- ["Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"](https://arxiv.org/abs/2211.10017). Young Jin Kim, Rawn Henry, Raffy Fahim, Hany Hassan Awadalla. _Proceedings of the Third Workshop on Simple and Efficient Natural Language Processing_, December 2022.
|
||||
|
||||
- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.
|
||||
|
||||
- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.
|
||||
|
||||
- ["Breaking the Computation and Communication Abstraction Barrier in Distributed Machine Learning Workloads"](https://arxiv.org/abs/2105.05720). Abhinav Jangda, Jun Huang, Guodong Liu, Amir Hossein Nodehi Sabet, Saeed Maleki, Youshan Miao, Madanlal Musuvathi, Todd Mytkowicz, Olli Sarikivi. _Proceedings of the 27th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, February 2022.
|
||||
|
||||
## 2021
|
||||
|
||||
- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.
|
||||
|
||||
- ["Real-time Neural Radiance Caching for Path Tracing"](https://dl.acm.org/doi/abs/10.1145/3450626.3459812). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
|
||||
|
||||
## 2020
|
||||
|
||||
- ["Scalable Knowledge Graph Analytics at 136 Petaflop/s"](https://www.computer.org/csdl/proceedings-article/sc/2020/999800a061/1oeORDgCM0g). Ramakrishnan Kannan, Piyush Sao, Hao Lu, Drahomira Herrmannova, Vijay Thakkar, Robert Patton, Richard Vuduc, Thomas Potok. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
|
||||
|
||||
- ["Accelerating Sparse DNN Models without Hardware-Support via Tile-Wise Sparsity
|
||||
"](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
|
||||
|
||||
- ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. _ACM Transactions on Mathematical Software_, March 2020.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
725
README.md
725
README.md
@ -1,139 +1,281 @@
|
||||

|
||||

|
||||
# Overview
|
||||
|
||||
# CUTLASS 1.3
|
||||
# CUTLASS 4.1.0
|
||||
|
||||
_CUTLASS 1.3.2 - July 2019_
|
||||
_CUTLASS 4.1.0 - July 2025_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
reusable, modular software components abstracted by C++ template classes. These
|
||||
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
|
||||
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
|
||||
resulting flexibility simplifies their use as building blocks within custom kernels
|
||||
and applications.
|
||||
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
|
||||
and related computations at all levels and scales within CUDA. It incorporates strategies for
|
||||
hierarchical decomposition and data movement. CUTLASS decomposes these "moving parts" into reusable, modular
|
||||
software components and abstractions.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for 8-bit integer, half-precision floating
|
||||
point (FP16), single-precision floating point (FP32), and double-precision floating
|
||||
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
|
||||
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
|
||||
and beyond. Even faster performance on Volta is possible via direct access to
|
||||
Volta Tensor Cores via `mma.sync` (added in CUDA 10.1).
|
||||
Primitives for different levels of a conceptual parallelization hierarchy can be specialized and tuned
|
||||
via custom tiling sizes, data types, and other algorithmic policy. The resulting flexibility simplifies
|
||||
their use as building blocks within custom kernels and applications.
|
||||
|
||||
CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
We describe the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
CUTLASS has been providing CUDA C++ template abstractions for high-performance linear algebra since 2017 and
|
||||
these abstractions provide extensive support for a wide range of computations including
|
||||
mixed-precision computations, specialized data-movement (async copy) and
|
||||
multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16,
|
||||
[FP32 emulation via tensor core instruction](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
8b floating point types (e5m2 and e4m3),
|
||||
block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8),
|
||||
narrow integer types (4 and 8b signed and unsigned integers),
|
||||
and binary 1b data types (where architectures allow for the
|
||||
native support of such data types) across NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
|
||||
|
||||
To this rich ecosystem of C++ based kernel programming abstractions, CUTLASS 4 adds CUTLASS DSLs. These are Python native interfaces for writing high-performance CUDA kernels based on core CUTLASS and CuTe concepts without any performance compromises. This allows for a much smoother learning curve, orders of magnitude faster compile times, native integration with DL frameworks without writing glue code, and much more intuitive metaprogramming that does not require deep C++ expertise.
|
||||
|
||||
# What's New in CUTLASS 1.3
|
||||
_March 2019_
|
||||
* CUTLASS 1.3 includes an efficient GEMM implementation with the `mma.sync` instruction added in CUDA 10.1.
|
||||
Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions — exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.
|
||||
|
||||
# What's New in CUTLASS 1.2
|
||||
_October 2018_
|
||||
* [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
|
||||
* Batched strided WMMA GEMM
|
||||
CuTe DSL demonstrates optimal matrix multiply and other linear algebra operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Ampere, Hopper, and Blackwell architectures.
|
||||
|
||||
We believe it will become an indispensable tool for students, researchers, and performance
|
||||
engineers alike — flattening the learning curve of GPU programming, rapidly prototyping kernel
|
||||
designs, and bringing optimized solutions into production.
|
||||
|
||||
# What's New in CUTLASS 1.1
|
||||
_September 2018_
|
||||
CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
|
||||
|
||||
* [CUTLASS Documentation](CUTLASS.md)
|
||||
* [Examples](examples/)
|
||||
* Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
|
||||
* Turing Features
|
||||
* [WMMA GEMM targeting TensorCores](tools/test/unit/gemm/wmma_integer_gemm.cu) - INT8, INT4, 1-bit
|
||||
* [Batched Strided GEMM](tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu)
|
||||
* [Threadblock rasterization strategies](tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu)
|
||||
* Improved performance for adverse problem sizes and data layouts
|
||||
* Extended CUTLASS Core components
|
||||
* Tensor views support arbitrary matrix and tensor layouts
|
||||
* Zip iterators for structuring multiple data streams
|
||||
* Enhanced CUTLASS utilities
|
||||
* [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code
|
||||
* Added `HostMatrix<>` for simplified matrix creation
|
||||
To get started quickly, please refer to:
|
||||
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
|
||||
|
||||
For all updates, see the [CUTLASS changelog](CHANGELOG.md).
|
||||
# What's New in CUTLASS 4.1
|
||||
|
||||
## CuTe DSL
|
||||
* Add aarch64 support: you can now pip install `nvidia-cutlass-dsl` on GB200 systems!
|
||||
* More examples demonstrating how to use CuTe DSL to write peak-performance kernels
|
||||
- [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
|
||||
- [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
|
||||
* API updates
|
||||
- Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
|
||||
|
||||
## CUTLASS C++
|
||||
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Add variable sequence length support for FMHA Backward kernel.
|
||||
- Add varlen test support to Backward runner.
|
||||
- Support empty batch sequences.
|
||||
* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays.
|
||||
* CuTe changes:
|
||||
- Rewrite ArithTuple and ScaledBasis for robustness and clarity.
|
||||
- Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX.
|
||||
- Factor out `print_latex` and friends and rewrite.
|
||||
- Factor out `print_svg` and friends and rewrite.
|
||||
* Support Blackwell SM100 SIMT packed fp32x2 kernels.
|
||||
* Support residual add for implicit gemm kernels.
|
||||
* Various fixes for CUTLASS C++ Python interface's EVT tracer:
|
||||
- Add verifier for sm90 to report the invalid input.
|
||||
- When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges.
|
||||
- Register operations of tanh, sigmoid, exp, gelu to the python ast frontend.
|
||||
- Replace the NotImplementedError by packing all nodes into a single topological visitor node as a fallback.
|
||||
* Fix profiler bugs in exhaustive perf search.
|
||||
- Fix incorrect cluster shape output issue when doing exhaustive search.
|
||||
- Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders.
|
||||
* Fix some profiler issues.
|
||||
- Complete the reference for Blackwell blockwise gemm kernels.
|
||||
- Fix incorrect regex logic for L1 test.
|
||||
|
||||
Note: CUTLASS 4.x builds are known to fail on Windows platforms for all CUDA toolkits.
|
||||
The CUTLASS team is working on a fix.
|
||||
|
||||
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
|
||||
when compiled with CUDA 10.0.
|
||||
they exhibit nearly optimal utilization of peak theoretical throughput. The figure below
|
||||
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
|
||||
on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU.
|
||||
|
||||

|
||||
|
||||
The two figures below show the continual CUTLASS performance improvements
|
||||
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
|
||||
CUTLASS 3.1.
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
Tensor Core operations are implemented using CUDA's
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
|
||||

|
||||

|
||||
|
||||
# CuTe
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for
|
||||
defining and operating on hierarchically multidimensional layouts of threads and data.
|
||||
CuTe provides `Layout` and `Tensor` objects that compactly package the type,
|
||||
shape, memory space, and layout of data, while performing the complicated indexing for the user.
|
||||
This lets programmers focus on the logical descriptions of their algorithms while
|
||||
CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design,
|
||||
implement, and modify all dense linear algebra operations.
|
||||
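The CuTe DSL exposes these same `Layout` and `Tensor` concepts from Python. The snippet below is a minimal, hedged sketch rather than a definitive API reference: the function name is illustrative, and `cute.make_layout`, `cute.size`, and the `@cute.jit` decorator follow the CuTe DSL examples in this repository, so any detail that differs should be treated as an assumption.

```
# Hedged sketch of the Layout concept above, expressed through the CuTe DSL.
# Helper names follow the DSL examples in this repository; treat exact
# signatures as assumptions rather than a definitive API reference.
import cutlass.cute as cute


@cute.jit
def layout_demo():
    # An 8x16 column-major layout: shape (8, 16) with strides (1, 8).
    # The layout maps a logical (row, column) coordinate to a linear offset.
    layout = cute.make_layout((8, 16), stride=(1, 8))

    # cute.size reports the number of elements the layout covers (128 here);
    # a tensor would pair a layout like this with a pointer via cute.make_tensor.
    total = cute.size(layout)
    # (nothing is launched here; this only illustrates layout construction)
```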
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts
|
||||
which can be composed with data arrays to represent tensors.
|
||||
The representation of layouts is powerful enough to represent nearly
|
||||
everything we need to implement efficient dense linear algebra.
|
||||
Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
|
||||
This greatly simplifies the design and improves code composability and readability.
|
||||
More documentation specific to CuTe can be found in its
|
||||
[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS performs best when compiled with the [CUDA 10.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 9.0, 9.1, 9.2, and 10.0.
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta (compute capability 7.0)
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.
|
||||
|
||||
## Operating Systems
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
|**Operating System** | **Compiler** |
|
||||
|-----------------|----------|
|
||||
| Windows 10 | Microsoft Visual Studio 2015|
|
||||
| | Microsoft Visual Studio 2017|
|
||||
| Ubuntu 14.04 | GCC 4.8.2 |
|
||||
| Ubuntu 16.04 | GCC 5.4.0 |
|
||||
| Ubuntu 18.04 | GCC 7.3.0 |
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
| Ubuntu 22.04 | GCC 11.2.0 |
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
|
||||
any Maxwell-, Pascal-, Volta-, and Turing-architecture NVIDIA GPUs.
|
||||
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
|
||||
|
||||
|**GPU**|
|
||||
|---|
|
||||
|NVIDIA GeForce 1080|
|
||||
|NVIDIA TitanXP|
|
||||
|NVIDIA Tesla P100|
|
||||
|NVIDIA Tesla V100|
|
||||
|NVIDIA TitanV|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|
|
||||
Note: CUTLASS 3.x builds are known to fail on Windows platforms for all CUDA toolkits.
|
||||
The CUTLASS team is working on a fix.
|
||||
|
||||
## Hardware
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
|
||||
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
|
||||
|---|---|---|
|
||||
|NVIDIA V100 Tensor Core GPU |7.0|11.4|
|
||||
|NVIDIA TitanV |7.0|11.4|
|
||||
|NVIDIA GeForce RTX 20x0 series |7.5|11.4|
|
||||
|NVIDIA T4 |7.5|11.4|
|
||||
|NVIDIA A100 Tensor Core GPU |8.0|11.4|
|
||||
|NVIDIA A10 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 30x0 series |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 40x0 series |8.9|11.8|
|
||||
|NVIDIA L40 |8.9|11.8|
|
||||
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|
||||
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|
|
||||
|
||||
## Target Architecture
|
||||
|
||||
In general, PTX code generated for one target architecture can be run on future architectures
|
||||
(i.e., it is forward compatible).
|
||||
However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose
|
||||
PTX does not have forward compatibility guarantees.
|
||||
Several Hopper and Blackwell PTX instructions fall under this category of
|
||||
architecture-accelerated features, and thus require an `sm_90a` or `sm_100a` target architecture
|
||||
(note the "a" appended). For more details on this and other architecture-accelerated instructions,
|
||||
please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag
|
||||
`CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100,
|
||||
users are required to build CUTLASS with `90a` as the target architecture.
|
||||
If a user accidentally builds a kernel which uses SM90a features
|
||||
(e.g. Hopper Tensor Core Instructions), using the SM90 target
|
||||
(note the lack of "a"), with either CUDA Toolkit 12 or 11.8,
|
||||
the kernel is expected to fail with a runtime error.
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
```
|
||||
Or
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
```
|
||||
|

Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
products has a different compute capability than the one underpinning the
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
compiled for the Blackwell SM100 architecture with arch-conditional features
(using `sm_100a`) are not compatible with RTX 50 series GPUs.

Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
for details on which kernels require which target architectures.
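
For example, a single build tree can target both the datacenter and GeForce Blackwell parts by listing the corresponding architecture-conditional targets together. This is only a sketch: it assumes `120a` is the arch-conditional target for the RTX 50 series and that the CUDA Toolkit in use (12.8 or newer) supports all listed architectures.

```bash
# Hypothetical multi-target configuration: Hopper (90a), datacenter Blackwell (100a),
# and GeForce Blackwell (120a) kernels compiled into one build.
cmake .. -DCUTLASS_NVCC_ARCHS="90a;100a;120a"
```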

# Documentation

CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).

- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent kernels in the same stream, and how it is used in CUTLASS

# Resources

We have also described the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).

- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)

# Building CUTLASS

CUTLASS is a header-only template library and does not need to be built to be used by other
projects. Client applications should target CUTLASS's `include/` directory in their include
paths. However, we distribute extensive unit tests, examples, and utility programs to demonstrate
CUTLASS; the instructions below are for building those test programs.
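
As a rough sketch of header-only usage (not an official recipe; `my_gemm.cu`, the install path, and the `sm_80` target are placeholders), a client translation unit only needs the CUTLASS include directories on its include path:

```bash
# Hypothetical client compile: no CUTLASS build step is required, only the headers.
nvcc -std=c++17 -arch=sm_80 \
     -I/path/to/cutlass/include \
     -I/path/to/cutlass/tools/util/include \
     my_gemm.cu -o my_gemm
```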

CUTLASS's unit tests depend on Google Test, which exists as a git submodule. You can fetch
the submodules as follows:

```bash
$ git submodule update --init --recursive
```

CUTLASS unit tests, examples, and utilities can be built with CMake.
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.

```bash
$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```

CUTLASS can be built with CMake (version 3.10 or newer).
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, and 9.0.
To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.

```bash
$ mkdir build && cd build
$ cmake ..                              # compiles for the default set of architectures
$ cmake .. -DCUTLASS_NVCC_ARCHS=80      # compiles for NVIDIA's Ampere Architecture
```

Compile the CUTLASS project by running Make. Include the `-j` argument to compile sources in
parallel and speed up the build process.

```bash
$ make -j12
```

From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make.
The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS,
and they may be executed in parallel via make's `-j` command line argument.

```bash
$ make test_unit -j
...
[ PASSED ] 946 tests.
```

All tests should pass on supported platforms, though the exact number of tests may vary over time.

# Project Structure

CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests.
The [Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes,
and template concepts defined in the CUTLASS project.

The CUTLASS library is defined in the `include/cutlass/` directory and consists of CUDA C++ template
classes and other definitions for implementing efficient GPU GEMM kernels. A set of core
classes and templates define basic primitives that are then applied to compute GEMM via
templates in the `include/cutlass/gemm/` directory.
A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.

## CUTLASS Template Library

```
include/                     # client applications should target this directory in their build's include paths

  cutlass/                   # CUDA Templates for Linear Algebra Subroutines and Solvers - headers only

    arch/                    # direct exposure of architecture features (including instruction-level GEMMs)

    conv/                    # code specialized for convolution

    epilogue/                # code specialized for the epilogue of gemm/convolution

    gemm/                    # code specialized for general matrix product computations

    layout/                  # layout definitions for matrices, tensors, and other mathematical objects in memory

    platform/                # CUDA-capable Standard Library components

    reduction/               # bandwidth-limited reduction kernels that do not fit the "gemm" model

    thread/                  # simt code that can be performed within a CUDA thread

    transform/               # code specialized for layout, type, and domain transformations

    *                        # core vocabulary types, containers, and basic numeric operations

  cute/                      # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy

    algorithm/               # Definitions of core operations such as copy, gemm, and operations on cute::tuples

    arch/                    # Bare bones PTX wrapper structs for copy and math instructions

    atom/                    # Meta-information either link to or built from arch/ operators

      mma_atom.hpp           # cute::Mma_Atom and cute::TiledMma

      copy_atom.hpp          # cute::Copy_Atom and cute::TiledCopy

      *sm*.hpp               # Arch specific meta-information for copy and math operations

    *                        # Core library types such as Shape, Stride, Layout, Tensor, and associated operations
```

### CUTLASS SDK Examples

[CUTLASS SDK examples](https://github.com/NVIDIA/cutlass/tree/main/examples) apply CUTLASS templates to implement basic computations.

```
examples/
  00_basic_gemm/
  01_tensor_view/
  02_cutlass_utilities/
  03_batched_gemm/
  04_tile_iterator/
  05_wmma_gemm/
```

### Tools

Several tools and test programs are also distributed with the CUTLASS library. They are
contained in the following directories.

```
tools/
  library/                   # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates
    include/
      cutlass/
        library/

  profiler/                  # CUTLASS Profiler - command-line utility for executing operations in the
                             # CUTLASS Library

  util/                      # CUTLASS Utilities - contains numerous helper classes for
    include/                 #   managing tensors in device memory, reference
      cutlass/               #   implementations for GEMM, random initialization
        util/                #   of tensors, and I/O.
```

### Test

The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.

The `tools/util/` directory contains CUTLASS utilities, including reference implementations of GEMM and
several element-wise tensor operations.

Instructions for building and running the unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).

# Performance Profiling

The `tools/profiler/` directory contains a command-line utility for launching each of the GEMM kernels.
It can be built as follows:

```bash
$ make cutlass_profiler -j16
```

## Building all GEMM and Convolution kernels (_long_ build times)

By default, only one tile size is instantiated for each data type, math instruction, and layout.
To instantiate all, set the following CMake option when running CMake from an empty `build/` directory.
Beware, this results in *tens of thousands* of kernels and long build times.
It also results in a large binary size, and on some platforms the linker will fail to build the library.
Therefore, it is highly recommended to generate only a subset of kernels, as demonstrated in the sub-section below.

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j16
```

## Building a subset of GEMM and Convolution kernels (_reduced_ build times)

To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
wildcard characters may be used to reduce the set of kernels. The following examples show building exactly one
or a subset of kernels for the NVIDIA Ampere and Turing architectures:
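
As an illustration of combining several filters, the comma-delimited list can mix exact kernel names and wildcard patterns in a single configuration (a sketch; the two names below are simply the ones used in the subsections that follow):

```bash
# Hypothetical build selecting one SIMT SGEMM kernel plus a family of Tensor Core GEMM kernels.
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' \
    -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1,cutlass_tensorop_s*gemm_f16_*_nt_align8
...
$ make cutlass_profiler -j16
```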

### Building a subset of Tensor Core GEMM kernels

To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting the NVIDIA Ampere and Turing architectures,
use the below cmake command line:

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*gemm_f16_*_nt_align8
...
$ make cutlass_profiler -j16
```

An example command line for profiling a subset of Tensor Core GEMM kernels is as follows:
```bash
$ ./tools/profiler/cutlass_profiler --kernels=cutlass_tensorop_s*gemm_f16_*_nt_align8 --m=3456 --n=4096 --k=4096

...
=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: gemm
       Operation: cutlass_tensorop_s1688gemm_f16_256x128_32x2_nt_align8

          Status: Success
    Verification: ON
     Disposition: Passed

reference_device: Passed
          cuBLAS: Passed

       Arguments: --gemm_kind=universal --m=3456 --n=4096 --k=4096 --A=f16:column --B=f16:row --C=f32:column --alpha=1 \
                  --beta=0 --split_k_slices=1 --batch_count=1 --op_class=tensorop --accum=f32 --cta_m=256 --cta_n=128 \
                  --cta_k=32 --stages=2 --warps_m=4 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=8 --min_cc=75 \
                  --max_cc=1024

           Bytes: 118489088 bytes
           FLOPs: 115992428544 flops

         Runtime: 1.55948 ms
          Memory: 70.7616 GiB/s

            Math: 74378.8 GFLOP/s

=============================
...
```

### Building one CUDA Core GEMM kernel

To compile one SGEMM kernel targeting the NVIDIA Ampere and Turing architectures, use the below cmake command line:

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1
...
$ make cutlass_profiler -j16
```

An example command line for profiling a single SGEMM CUDA kernel is as follows:

```bash
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096

=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: gemm
       Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1

          Status: Success
    Verification: ON
     Disposition: Passed

          cuBLAS: Passed

       Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
                  --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
                  --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024

           Bytes: 180355072 bytes
           FLOPs: 115992428544 flops

         Runtime: 6.73655 ms
          Memory: 24.934 GiB/s

            Math: 17218.4 GFLOP/s

=============================
```

### Building a subset of Tensor Core Convolution kernels

To compile a subset of Tensor Core convolution kernels implementing forward propagation (fprop) with FP32 accumulation
and FP16 input targeting the NVIDIA Ampere and Turing architectures, use the below cmake command line:

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*fprop_optimized_f16
...
$ make cutlass_profiler -j16
```

An example command line for profiling a subset of Tensor Core convolution kernels is as follows:

```bash
$ ./tools/profiler/cutlass_profiler --kernels=cutlass_tensorop_s*fprop_optimized_f16 --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3

...
=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: conv2d
       Operation: cutlass_tensorop_s16816fprop_optimized_f16_128x128_32x5_nhwc

          Status: Success
    Verification: ON
     Disposition: Passed

reference_device: Passed

       Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \
                  --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f16:nhwc --Filter=f16:nhwc --Output=f32:nhwc \
                  --conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \
                  --eq_gemm_provider=none --op_class=tensorop --accum=f32 --cta_m=128 --cta_n=128 --cta_k=32 --stages=5 \
                  --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=16 --min_cc=80 --max_cc=1024

           Bytes: 1130659840 bytes
           FLOPs: 118482796544 flops

         Runtime: 0.711496 ms
          Memory: 1479.99 GiB/s

            Math: 166526 GFLOP/s

=============================
...
```

### Building one Convolution CUDA kernel

To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with FP32 accumulation
and FP32 input targeting the NVIDIA Ampere and Turing architectures, use the below cmake command line:

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc
...
$ make cutlass_profiler -j16
```

An example command line for profiling one CUDA Core convolution kernel:

```bash
$ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3

=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: conv2d
       Operation: cutlass_simt_sfprop_optimized_128x128_8x2_nhwc

          Status: Success
    Verification: ON
     Disposition: Passed

reference_device: Passed

       Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \
                  --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f32:nhwc --Filter=f32:nhwc --Output=f32:nhwc \
                  --conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \
                  --eq_gemm_provider=none --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
                  --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024

           Bytes: 2055798784 bytes
           FLOPs: 118482796544 flops

         Runtime: 7.34266 ms
          Memory: 260.752 GiB/s

            Math: 16136.2 GFLOP/s

=============================
```

The `test/perf/` directory contains a legacy command-line utility, `cutlass_perf_test`, for launching each of the GEMM kernels. Its usage is shown below.

```
cutlass_perf_test [options]

  --help
  --append=<true|false*>                        If true, appends output to existing CSV file. If false, overwrites.
  --alpha=<alpha>                               Value for alpha to be used in GEMM experiments
  --beta=<beta>                                 Value for beta to be used in GEMM experiments
  --dist=<distribution>                         Describes the random distribution of each of the input matrix operands.
  --execution_mode=<mode>                       Specifies execution mode: profile, verify, single
  --output=<filename.csv>                       Writes summary of profiling to specified .csv file
  --iterations=<timing iterations>              Maximum number of iterations to execute when profiling
  --m=<height>[:max height[:step]]              Height of GEMM problem (number of rows of C). May specify a range with optional step size.
  --n=<width>[:max width[:step]]                Width of GEMM problem (number of columns of C). May specify a range with optional step size.
  --k=<depth>[:max depth[:step]]                Size of inner dimension of A and B. May specify a range with optional step size.
  --kernels=<{s|d|h|i|wmma_}gemm_{nn,nt,tn,tt}> Select GEMM datatype and layout to use for tests
  --peak=<bool>                                 If true, only reports peak performance per kernel after profiling specified problem space.
  --save_workspace={*never,incorrect,always}    Specifies when to save the GEMM inputs and results to the filesystem.
  --seed=<seed>                                 Random seed used by the random number generator in initializing input matrices.
  --tags=<column:tag,...>                       Inserts leading columns in output table and uniform values for each column.

Example usage:

# Runs one problem size for all kernels
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=1024 --k=1024

# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn

# Executes GEMM kernel on Volta Tensor Cores
$ ./tools/test/perf/cutlass_perf_test --kernels=s884gemm_nt
```

## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler

- Please follow these links for more CMake examples on selectively compiling CUTLASS kernels:
  - [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
  - [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)

# About

CUTLASS is released by NVIDIA Corporation as Open Source software under the
[3-clause "New" BSD license](LICENSE.txt).

# Contributors

The official list of CUTLASS developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md).

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

bin2hex.cmake (new file, 54 lines)

# Copyright (c) 2019 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# A small utility function which generates a C-header from an input file
function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
  FILE(READ "${FILENAME}" HEX_INPUT HEX)
  if (${ZERO_TERMINATED})
    string(APPEND HEX_INPUT "00")
  endif()

  string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
  string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT})

  set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n  ${HEX_OUTPUT}\n};\n")

  set(${OUTPUT_STRING} "${HEX_OUTPUT}" PARENT_SCOPE)
endfunction()

# message("Create header file for ${FILE_IN}")
# message("Create header file for ${FILE_OUT}")
file_to_c_string(${FILE_IN} ${VARIABLE_NAME} OUTPUT_STRING ZERO_TERMINATED)

set(RESULT "#pragma once\n")
string(APPEND RESULT "namespace cutlass {\n")
string(APPEND RESULT "namespace nvrtc {\n")
string(APPEND RESULT "${OUTPUT_STRING}")
string(APPEND RESULT "} // namespace nvrtc\n")
string(APPEND RESULT "} // namespace cutlass\n")
file(WRITE "${FILE_OUT}" "${RESULT}")
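
A sketch of how this script might be invoked in CMake script mode (the input and output names below are hypothetical; the script expects `FILE_IN`, `VARIABLE_NAME`, and `FILE_OUT` to be defined on the command line):

```bash
# Hypothetical invocation: convert kernel.cu into a C header exposing a
# zero-terminated char array named kKernelSource inside cutlass::nvrtc.
cmake -DFILE_IN=kernel.cu \
      -DVARIABLE_NAME=kKernelSource \
      -DFILE_OUT=kernel.cu.h \
      -P bin2hex.cmake
```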

cmake/CTestTestfile.configure.cmake (new file, 52 lines)

# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Generated file

set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)

if (NOT DEFINED ENV{CUTLASS_TEST_SETS})
  set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@)
endif()

foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS})
  if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED)
    message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].")
    return()
  endif()
endforeach()

set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)

if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
  set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
  set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()
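
Because the generated test file reads these settings from the environment at CTest time, a run might be narrowed or wrapped roughly like this (a sketch only; the test-set name and wrapper command are placeholders, since the valid values are substituted in at configure time):

```bash
# Hypothetical: restrict ctest to one test set and run every test binary under
# a wrapper command supplied via CUTLASS_TEST_EXECUTION_ENVIRONMENT.
CUTLASS_TEST_SETS=default \
CUTLASS_TEST_EXECUTION_ENVIRONMENT="compute-sanitizer" \
ctest --test-dir build --output-on-failure
```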

cmake/CTestTestfile.test.configure.cmake (new file, 43 lines)

# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT)
  # The longform/extended format allows generator expressions to be
  # expanded properly and is useful in contexts where the files need
  # to be immediately included into being-processed cmake code.
  add_test(NAME @TESTCASE_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
else()
  add_test(@TESTCASE_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
endif()

if (TEST_EXE_WORKING_DIRECTORY)
  set_tests_properties(@TESTCASE_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
endif()

set_tests_properties(@TESTCASE_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)

cmake/NvidiaCutlassConfig.cmake.in (new file, 9 lines)

get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)

include(CMakeFindDependencyMacro)

if(TARGET nvidia::cutlass::CUTLASS)
  return()
endif()

include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
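
A client project whose CMakeLists.txt calls `find_package(NvidiaCutlass)` and links against the imported `nvidia::cutlass::CUTLASS` target would typically be pointed at an installed CUTLASS like this (a sketch; the install prefix is a placeholder):

```bash
# Hypothetical: configure a client project against an NvidiaCutlass installation.
cmake -S . -B build -DCMAKE_PREFIX_PATH=/opt/nvidia/cutlass
```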

cmake/NvidiaCutlassPackageConfig.cmake (new file, 42 lines)

# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

set(CPACK_PACKAGE_NAME NvidiaCutlass)
set(CPACK_PACKAGE_VENDOR NVIDIA)
set(CPACK_PACKAGE_CONTACT info@nvidia.com)
set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CUTLASS CUDA C++ Template Linear Algebra Library")
set(CPACK_PACKAGE_INSTALL_DIRECTORY ${CPACK_PACKAGE_NAME})
set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR})
set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR})
set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH})
set(CPACK_VERBATIM_VARIABLES YES)
# set(CPACK_PACKAGE_DESCRIPTION_FILE ${CMAKE_CURRENT_LIST_DIR}/Description.txt)
# set(CPACK_RESOURCE_FILE_WELCOME ${CMAKE_CURRENT_LIST_DIR}/Welcome.txt)
# set(CPACK_RESOURCE_FILE_LICENSE ${CMAKE_CURRENT_LIST_DIR}/License.txt)
# set(CPACK_RESOURCE_FILE_README ${CMAKE_CURRENT_LIST_DIR}/Readme.txt)
include(CPack)
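
With this CPack configuration included, a redistributable archive can be produced from an existing build tree, for instance (a sketch; the generator choice is arbitrary):

```bash
# Hypothetical: build the CPack "package" target from an already-configured build directory,
# or invoke CPack directly against the generated configuration.
cmake --build build --target package
cpack --config build/CPackConfig.cmake -G TGZ
```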

cmake/googletest.cmake (new file, 52 lines)

# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

include(FetchContent)

set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")

if(GOOGLETEST_DIR)
  set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
endif()

set(GTEST_REPOSITORY "https://github.com/google/googletest.git" CACHE STRING "GoogleTest repo to fetch")
FetchContent_Declare(
  googletest
  GIT_REPOSITORY ${GTEST_REPOSITORY}
  GIT_TAG v1.14.0
)

FetchContent_GetProperties(googletest)

if(NOT googletest_POPULATED)
  FetchContent_Populate(googletest)
  if (MSVC)
    set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
  endif()
  add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
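
Because the repository is fetched at configure time, an offline or pinned checkout can be supplied through the cache variable defined above, e.g. (a sketch; the local path is a placeholder):

```bash
# Hypothetical: build unit tests against a pre-downloaded GoogleTest checkout
# instead of fetching https://github.com/google/googletest.git at configure time.
cmake .. -DGOOGLETEST_DIR=/path/to/local/googletest
```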

cmake/nop.cu (new file, 49 lines)

/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Basic CUDA file for testing compiler flags.
*/

__device__ int inner()
{
  return -1;
}

__global__ void test()
{
  inner();
}

int main()
{
  test<<<1,1>>>();
  return 0;
}
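
As a rough illustration (an assumption about how such a probe might be used, not taken from the build scripts), this trivial translation unit can be compiled on its own to check whether a given NVCC flag or architecture is accepted:

```bash
# Hypothetical flag probe: succeeds only if nvcc accepts the given architecture flag.
nvcc -arch=sm_90a -c cmake/nop.cu -o /dev/null && echo "flag supported"
```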

cmake/version_extended.h.in (new file, 34 lines)

/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_REVISION "@CUTLASS_REVISION@"

cuBLAS.cmake (new file, 152 lines)

# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

message(STATUS "Configuring cublas ...")

if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
   (DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED))

  # Don't add cuBLAS if it's defined and false, assume it's not found.

  set(CUBLAS_FOUND OFF)
  message(STATUS "cuBLAS Disabled.")

elseif(NOT TARGET cublas)

  find_path(
    _CUBLAS_INCLUDE_DIR
    NAMES cublas_v2.h
    HINTS
      ${CUBLAS_INCLUDE_PATH}
      ENV CUBLAS_INCLUDE_PATH
      ${CUBLAS_PATH}
      ENV CUBLAS_PATH
      ${CUDA_TOOLKIT_ROOT_DIR}
    PATH_SUFFIXES
      include
    )

  find_library(
    _CUBLAS_LIBRARY
    NAMES cublas
    HINTS
      ${CUBLAS_LIBRARY_PATH}
      ENV CUBLAS_LIBRARY_PATH
      ${_CUBLAS_INCLUDE_DIR}/..
      ${CUBLAS_PATH}
      ENV CUBLAS_PATH
      ${CUDA_TOOLKIT_ROOT_DIR}
    PATH_SUFFIXES
      lib64
      lib/x64
      lib
    )

  if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY)

    message(STATUS "cuBLAS: ${_CUBLAS_LIBRARY}")
    message(STATUS "cuBLAS: ${_CUBLAS_INCLUDE_DIR}")

    set(CUBLAS_FOUND ON CACHE INTERNAL "cublas Library Found")
    set(CUBLAS_LIBRARY ${_CUBLAS_LIBRARY})
    set(CUBLAS_INCLUDE_DIR ${_CUBLAS_INCLUDE_DIR})

  else()

    message(STATUS "cublas not found.")
    set(CUBLAS_FOUND OFF CACHE INTERNAL "cublas Library Found")

  endif()

endif()

set(CUTLASS_ENABLE_CUBLAS ${CUBLAS_FOUND} CACHE BOOL "Enable CUTLASS to build with cuBLAS library.")

if(CUTLASS_ENABLE_CUBLAS AND NOT CUBLAS_FOUND)
  message(FATAL_ERROR "CUTLASS_ENABLE_CUBLAS enabled but cuBLAS library could not be found.")
endif()

if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)

  if(WIN32)
    add_library(cublas STATIC IMPORTED GLOBAL)
  else()
    add_library(cublas SHARED IMPORTED GLOBAL)
  endif()

  add_library(nvidia::cublas ALIAS cublas)

  set_property(
    TARGET cublas
    PROPERTY IMPORTED_LOCATION
    ${CUBLAS_LIBRARY})

  target_include_directories(
    cublas
    INTERFACE
    $<INSTALL_INTERFACE:include>
    $<BUILD_INTERFACE:${CUBLAS_INCLUDE_DIR}>)

  find_library(
    _CUBLASLT_LIBRARY
    NAMES cublasLt
    HINTS
      ${CUBLAS_LIBRARY_PATH}
      ENV CUBLAS_LIBRARY_PATH
      ${_CUBLAS_INCLUDE_DIR}/..
      ${CUBLAS_PATH}
      ENV CUBLAS_PATH
      ${CUDA_TOOLKIT_ROOT_DIR}
    PATH_SUFFIXES
      lib64
      lib/x64
      lib
    )

  if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)

    if(WIN32)
      add_library(cublasLt STATIC IMPORTED GLOBAL)
    else()
      add_library(cublasLt SHARED IMPORTED GLOBAL)
    endif()

    set_property(
      TARGET cublasLt
      PROPERTY IMPORTED_LOCATION
      ${_CUBLASLT_LIBRARY})

    add_library(nvidia::cublasLt ALIAS cublasLt)

    target_link_libraries(cublas INTERFACE cublasLt)

  endif()

endif()

message(STATUS "Configuring cuBLAS ... done.")
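
Given the search order above, a non-default cuBLAS installation can be selected when configuring, for example (a sketch; the install path is a placeholder):

```bash
# Hypothetical: force-enable cuBLAS and point the header/library search at a specific install.
CUBLAS_PATH=/opt/cuda-custom \
cmake .. -DCUTLASS_ENABLE_CUBLAS=ON
```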

cuDNN.cmake (new file, 112 lines)

# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

if(DEFINED CUDNN_ENABLED)
  set(CUTLASS_ENABLE_CUDNN ${CUDNN_ENABLED} CACHE BOOL "Enable CUTLASS to build with cuDNN library.")
endif()

if(DEFINED CUTLASS_ENABLE_CUDNN AND NOT CUTLASS_ENABLE_CUDNN)
  return()
endif()

message(STATUS "Configuring cuDNN ...")

find_path(
  _CUDNN_INCLUDE_DIR cudnn.h
  PATHS
    ${CUDA_TOOLKIT_ROOT_DIR}/include
    $ENV{CUDNN_PATH}/include
    $ENV{CUDA_PATH}/include
    ${CUDNN_PATH}/include
    /usr/include)

find_library(
  _CUDNN_LIBRARY cudnn
  HINTS
    ${CUDA_TOOLKIT_ROOT_DIR}/lib64
    ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
    ${CUDA_TOOLKIT_ROOT_DIR}/lib
    $ENV{CUDNN_PATH}/lib64
    $ENV{CUDNN_PATH}/lib/x64
    $ENV{CUDNN_PATH}/lib
    $ENV{CUDA_PATH}/lib64
    $ENV{CUDA_PATH}/lib/x64
    $ENV{CUDA_PATH}/lib
    ${CUDNN_PATH}/lib64
    ${CUDNN_PATH}/lib/x64
    ${CUDNN_PATH}/lib
    /usr/lib/x86_64-linux-gnu
    /usr/lib)

if(_CUDNN_INCLUDE_DIR AND _CUDNN_LIBRARY)

  message(STATUS "cuDNN: ${_CUDNN_LIBRARY}")
  message(STATUS "cuDNN: ${_CUDNN_INCLUDE_DIR}")

  set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")

else()

  message(STATUS "cuDNN not found.")
  set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Found")

endif()

set(CUTLASS_ENABLE_CUDNN ${CUDNN_FOUND} CACHE BOOL "Enable CUTLASS to build with cuDNN library.")

if (CUTLASS_ENABLE_CUDNN AND NOT TARGET cudnn)

  set(CUDNN_INCLUDE_DIR ${_CUDNN_INCLUDE_DIR})
  set(CUDNN_LIBRARY ${_CUDNN_LIBRARY})

  if(WIN32)
    add_library(cudnn STATIC IMPORTED GLOBAL)
  else()
    add_library(cudnn SHARED IMPORTED GLOBAL)
  endif()

  add_library(nvidia::cudnn ALIAS cudnn)

  set_property(
    TARGET cudnn
    PROPERTY IMPORTED_LOCATION
    ${CUDNN_LIBRARY})

  target_include_directories(
    cudnn
    INTERFACE
    $<INSTALL_INTERFACE:include>
    $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>)

endif()

if(CUTLASS_ENABLE_CUDNN AND NOT CUDNN_FOUND)
  message(FATAL_ERROR "CUTLASS_ENABLE_CUDNN enabled but cuDNN library could not be found.")
endif()

message(STATUS "Configuring cuDNN ... done.")
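
Similarly, a cuDNN installation outside the default search paths can be supplied at configure time, for example (a sketch; the install path is a placeholder):

```bash
# Hypothetical: enable cuDNN support and point the header/library search at a specific install.
CUDNN_PATH=/opt/cudnn \
cmake .. -DCUTLASS_ENABLE_CUDNN=ON
```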
92
customConfigs.cmake
Normal file
92
customConfigs.cmake
Normal file
@@ -0,0 +1,92 @@
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


# Profiler based functional testing
set(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS OFF CACHE BOOL "Utilize profiler-based functional regressions")
set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "Profiler functional regression test level")

find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)

function(cutlass_generate_kernel_filter_and_testlist_files)

  set(options)
  set(oneValueArgs TEST_SET_NAME)
  set(multiValueArgs)
  cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  execute_process(
    COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR}
      ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py
      --generator-target=${__TEST_SET_NAME}
      --cuda-version=${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}
      --architectures=${CUTLASS_NVCC_ARCHS}
      --kernels=\*
      --disable-cutlass-package-imports
    WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
    RESULT_VARIABLE cutlass_FILTER_GENERATION_RESULT
    OUTPUT_VARIABLE cutlass_FILTER_GENERATION_OUTPUT
    OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log
    ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log
  )

  if(NOT cutlass_FILTER_GENERATION_RESULT EQUAL 0)
    message(FATAL_ERROR "Error generating kernel filters and testlist files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
  endif()
endfunction()

if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)

  set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
  foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
    if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
      message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")
    endif()
  endforeach()

  if(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 0)

    message(STATUS "Building for L0 profiler-based functional regressions")
    cutlass_generate_kernel_filter_and_testlist_files(TEST_SET_NAME kernel_testlist_l0)
    set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
    set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")

  elseif (CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 1)

    message(STATUS "Building for L1 profiler-based functional regressions")
    cutlass_generate_kernel_filter_and_testlist_files(TEST_SET_NAME kernel_testlist_l1)
    set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
    set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")

  endif()
endif()
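A hedged sketch of a configure invocation that would exercise the file above; the build directory and architecture value are illustrative only, and test level 0 selects the branch that generates the kernel_testlist_l0 filter and CSV files.

# Illustrative only: adjust source/build paths and architectures to your environment.
cmake -S . -B build \
  -DCUTLASS_NVCC_ARCHS=100a \
  -DCUTLASS_BUILD_FOR_PROFILER_REGRESSIONS=ON \
  -DCUTLASS_PROFILER_REGRESSION_TEST_LEVEL=0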
@@ -1,380 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates wrapping direct issue of MMA instructions to Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/shape.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specifies internal data type for computation
|
||||
struct ComputeType {
|
||||
enum Kind {
|
||||
kBegin,
|
||||
kDefault, /// Compute type implied by operand and accumulator types
|
||||
kEnd
|
||||
};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Direct wrapper for native MMA instruction
|
||||
template <
|
||||
/// Warp-level matrix multiply-accumulate operation
|
||||
typename WmmaTile,
|
||||
/// Layout of A multiplicand
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// Data type of A multiplicand
|
||||
typename ScalarA,
|
||||
/// Layout of B multiplicand
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// Data type of B multiplicand
|
||||
typename ScalarB,
|
||||
/// Data type of accumulators
|
||||
typename ScalarC,
|
||||
/// Specifies particular compute type, overriding data types of operands
|
||||
ComputeType::Kind ComputeTy>
|
||||
inline __device__ void mma(ScalarA const A[], ScalarB const B[], ScalarC const C[], ScalarC D[]);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// 16x16x4
|
||||
//
|
||||
|
||||
//
|
||||
// FP16 accumulation
|
||||
//
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
half,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
half const c[],
|
||||
half d[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
unsigned const *C = reinterpret_cast<unsigned const *>(c);
|
||||
unsigned *D = reinterpret_cast<unsigned *>(d);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
|
||||
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
||||
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
half,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
half const c[],
|
||||
half d[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
unsigned const *C = reinterpret_cast<unsigned const *>(c);
|
||||
unsigned *D = reinterpret_cast<unsigned *>(d);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
|
||||
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
||||
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
half,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
half const c[],
|
||||
half d[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
unsigned const *C = reinterpret_cast<unsigned const *>(c);
|
||||
unsigned *D = reinterpret_cast<unsigned *>(d);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
|
||||
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
||||
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
half,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
half const c[],
|
||||
half d[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
unsigned const *C = reinterpret_cast<unsigned const *>(c);
|
||||
unsigned *D = reinterpret_cast<unsigned *>(d);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
|
||||
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
|
||||
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// FP32 accumulation
|
||||
//
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
float,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
float const C[],
|
||||
float D[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
||||
"{%12,%13,%14,%15,%16,%17,%18,%19};"
|
||||
: "=f"(D[0]),
|
||||
"=f"(D[1]),
|
||||
"=f"(D[2]),
|
||||
"=f"(D[3]),
|
||||
"=f"(D[4]),
|
||||
"=f"(D[5]),
|
||||
"=f"(D[6]),
|
||||
"=f"(D[7])
|
||||
: "r"(A[0]),
|
||||
"r"(A[1]),
|
||||
"r"(B[0]),
|
||||
"r"(B[1]),
|
||||
"f"(C[0]),
|
||||
"f"(C[1]),
|
||||
"f"(C[2]),
|
||||
"f"(C[3]),
|
||||
"f"(C[4]),
|
||||
"f"(C[5]),
|
||||
"f"(C[6]),
|
||||
"f"(C[7]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
float,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
float const C[],
|
||||
float D[]) {
|
||||
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
||||
"{%12,%13,%14,%15,%16,%17,%18,%19};"
|
||||
: "=f"(D[0]),
|
||||
"=f"(D[1]),
|
||||
"=f"(D[2]),
|
||||
"=f"(D[3]),
|
||||
"=f"(D[4]),
|
||||
"=f"(D[5]),
|
||||
"=f"(D[6]),
|
||||
"=f"(D[7])
|
||||
: "r"(A[0]),
|
||||
"r"(A[1]),
|
||||
"r"(B[0]),
|
||||
"r"(B[1]),
|
||||
"f"(C[0]),
|
||||
"f"(C[1]),
|
||||
"f"(C[2]),
|
||||
"f"(C[3]),
|
||||
"f"(C[4]),
|
||||
"f"(C[5]),
|
||||
"f"(C[6]),
|
||||
"f"(C[7]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
float,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
float const C[],
|
||||
float D[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
||||
"{%12,%13,%14,%15,%16,%17,%18,%19};"
|
||||
: "=f"(D[0]),
|
||||
"=f"(D[1]),
|
||||
"=f"(D[2]),
|
||||
"=f"(D[3]),
|
||||
"=f"(D[4]),
|
||||
"=f"(D[5]),
|
||||
"=f"(D[6]),
|
||||
"=f"(D[7])
|
||||
: "r"(A[0]),
|
||||
"r"(A[1]),
|
||||
"r"(B[0]),
|
||||
"r"(B[1]),
|
||||
"f"(C[0]),
|
||||
"f"(C[1]),
|
||||
"f"(C[2]),
|
||||
"f"(C[3]),
|
||||
"f"(C[4]),
|
||||
"f"(C[5]),
|
||||
"f"(C[6]),
|
||||
"f"(C[7]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Volta mma.sync instruction
|
||||
template <>
|
||||
inline __device__ void mma<Shape<4, 16, 16>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
half,
|
||||
MatrixLayout::kRowMajor,
|
||||
half,
|
||||
float,
|
||||
ComputeType::kDefault>(half const a[],
|
||||
half const b[],
|
||||
float const C[],
|
||||
float D[]) {
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
unsigned const *A = reinterpret_cast<unsigned const *>(a);
|
||||
unsigned const *B = reinterpret_cast<unsigned const *>(b);
|
||||
|
||||
asm volatile ("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
|
||||
"{%12,%13,%14,%15,%16,%17,%18,%19};"
|
||||
: "=f"(D[0]),
|
||||
"=f"(D[1]),
|
||||
"=f"(D[2]),
|
||||
"=f"(D[3]),
|
||||
"=f"(D[4]),
|
||||
"=f"(D[5]),
|
||||
"=f"(D[6]),
|
||||
"=f"(D[7])
|
||||
: "r"(A[0]),
|
||||
"r"(A[1]),
|
||||
"r"(B[0]),
|
||||
"r"(B[1]),
|
||||
"f"(C[0]),
|
||||
"f"(C[1]),
|
||||
"f"(C[2]),
|
||||
"f"(C[3]),
|
||||
"f"(C[4]),
|
||||
"f"(C[5]),
|
||||
"f"(C[6]),
|
||||
"f"(C[7]));
|
||||
|
||||
#else
|
||||
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
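For reference, a hedged sketch of how the removed warp-level wrappers above were invoked; the function name and fragment sizes are illustrative, not code from the tree, and the operand fragments are assumed to already sit in the per-thread register layout the m8n8k4 instruction expects.

// Sketch only: one FP16-accumulate step through the removed cutlass::arch::mma wrapper.
__device__ void mma_step(half const a[4], half const b[4],
                         half const c[8], half d[8]) {
  cutlass::arch::mma<cutlass::Shape<4, 16, 16>,
                     cutlass::MatrixLayout::kRowMajor, half,
                     cutlass::MatrixLayout::kColumnMajor, half,
                     half,
                     cutlass::arch::ComputeType::kDefault>(a, b, c, d);
}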
@@ -1,102 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Defines conversion operations among Fragments of different base type.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputFragment_, typename OutputFragment_>
|
||||
struct Convert {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputScalar_, int kScalars_>
|
||||
struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
|
||||
/// The input fragment.
|
||||
typedef Fragment<InputScalar_, kScalars_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<OutputScalar_, kScalars_> OutputFragment;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Convert() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
for (int i = 0; i < kScalars_; ++i) {
|
||||
dst[i] = static_cast<OutputScalar_>(src[i + offset]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Fragment_>
|
||||
struct Copy {
|
||||
/// The input fragment.
|
||||
typedef Fragment_ InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment_ OutputFragment;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Copy() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename InputFragment_>
|
||||
CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
|
||||
if (sizeof(typename Fragment_::Element) == 8) {
|
||||
uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
|
||||
uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
|
||||
for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
} else {
|
||||
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
|
||||
for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
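For reference, a hedged sketch of how the removed Convert helper above was typically used; the surrounding function and the 8-element fragment width are illustrative, not code from the tree.

// Sketch only: element-wise conversion of an 8-wide half fragment to float.
__device__ void convert_example(cutlass::Fragment<half, 8> const& src,
                                cutlass::Fragment<float, 8>& dst) {
  cutlass::Convert<cutlass::Fragment<half, 8>, cutlass::Fragment<float, 8> > op;
  op.transform(src, dst);  // dst[i] = static_cast<float>(src[i])
}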
403  cutlass/coord.h
@@ -1,403 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief A Coord is a coordinate of arbitrary rank into a tensor or matrix
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Describes identity elements
|
||||
struct Identity {
|
||||
/// Enumeration describing identity elements. Value assignments are significant.
|
||||
/// Feel free to add or multiply by these, respectively.
|
||||
enum Kind { Additive = 0, Multiplicative = 1 };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically-sized array specifying Coords within a tensor
|
||||
template <int Rank_, typename Index_ = int>
|
||||
struct Coord {
|
||||
//
|
||||
// Type and constant definitions
|
||||
//
|
||||
|
||||
/// Number of elements in Coord
|
||||
static int const kRank = Rank_;
|
||||
|
||||
/// Number of elements in Coord, aliased for compatibility
|
||||
static int const N = Rank_;
|
||||
|
||||
/// Index type used to store elements
|
||||
typedef Index_ Index;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Indices
|
||||
Index idx[kRank];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor initializes uniformly
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(Index value = 0) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from an array of integers
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(Index _idx[]) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = _idx[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(Coord<kRank> const &coord) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = coord[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a slice of the Coord which may be larger or smaller in rank
|
||||
/// than this.
|
||||
template <int Slice>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Slice> slice(int start = 0, Index identity = 0) const {
|
||||
Coord<Slice> result;
|
||||
for (int i = 0; i < Slice; ++i) {
|
||||
if (i + start < kRank) {
|
||||
result[i] = idx[i + start];
|
||||
}
|
||||
else {
|
||||
result[i] = identity;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns true if Coord is non-zero.
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator bool() const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (idx[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if Coord is uniformly zero.
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!() const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (idx[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator+(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] + b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator-(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] - b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator*(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] * b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator/(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] / b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator+=(Coord const& b) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] += b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator-=(Coord const& b) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] -= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator*=(Coord const& b) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] *= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator/=(Coord const& b) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] /= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; }
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; }
|
||||
|
||||
/// Computes the dot product of two Coord instances
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// Computes the dot product of two Coord instances
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
|
||||
T sum = T(0);
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE Index& at() {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index& at(int dim) { return idx[dim]; }
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE Index const& at() const {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const& at(int dim) const { return idx[dim]; }
|
||||
|
||||
/// Determines if two Coord<> objects are equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Coord<kRank> const& b) const {
|
||||
bool equal = true;
|
||||
for (int i = 0; equal && i < kRank; ++i) {
|
||||
equal = (idx[i] == b.idx[i]);
|
||||
}
|
||||
return equal;
|
||||
}
|
||||
|
||||
/// Not equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Coord<kRank> const& b) const { return !(*this == b); }
|
||||
|
||||
/// Clamps a coordinate to a range specified by maximum and minimum values
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& clamp(Coord<kRank> const& max, Coord<kRank> const& min = Coord<kRank>()) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the product of all elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index count() const {
|
||||
Index product = idx[0];
|
||||
for (int i = 1; i < kRank; ++i) {
|
||||
product *= idx[i];
|
||||
}
|
||||
return product;
|
||||
}
|
||||
|
||||
/// Less than operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<(Coord<kRank> const &b) const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (!(idx[i] < b[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Less than or equals operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<=(Coord<kRank> const &b) const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (!(idx[i] <= b[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator*(T s, Coord<Rank, Index> coord) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] *= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator*(Coord<Rank, Index> coord, T s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] *= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar division
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator/(T s, Coord<Rank, Index> coord) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] = s / coord[i];
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar division
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator/(Coord<Rank, Index> coord, T s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] /= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Integer-valued make_Coord
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to make a 1-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<1> make_Coord(int _0) {
|
||||
int values[1] = {_0};
|
||||
return Coord<1>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 2-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> make_Coord(int _0, int _1) {
|
||||
int values[2] = {_0, _1};
|
||||
return Coord<2>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 3-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> make_Coord(int _0, int _1, int _2) {
|
||||
int values[3] = {_0, _1, _2};
|
||||
return Coord<3>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 4-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
|
||||
int values[4] = {_0, _1, _2, _3};
|
||||
return Coord<4>(values);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape_>
|
||||
CUTLASS_HOST_DEVICE Coord<3> make_Coord_from_shape() {
|
||||
return make_Coord(Shape_::kD, Shape_::kH, Shape_::kW);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
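For reference, a hedged sketch of typical use of the removed Coord<> and make_Coord helpers above; the values and the host-side function are illustrative only.

// Sketch only: linearizing a 3-D coordinate with the removed Coord<> API.
void coord_example() {
  cutlass::Coord<3> extent = cutlass::make_Coord(4, 8, 16);
  cutlass::Coord<3> stride = cutlass::make_Coord(128, 16, 1);
  cutlass::Coord<3> point  = cutlass::make_Coord(1, 2, 3);
  int offset  = point.dot<int>(stride);  // 1*128 + 2*16 + 3*1 = 163
  int volume  = extent.count();          // 4 * 8 * 16 = 512
  bool inside = point < extent;          // element-wise strict less-than
  (void)offset; (void)volume; (void)inside;
}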
@@ -1,126 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Helpers for printing cutlass/core objects
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iosfwd>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Rank>
|
||||
std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
out << (i ? ", " : "") << coord.idx[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to enable formatted printing of CUTLASS scalar types to an ostream
|
||||
template <typename T>
|
||||
struct ScalarIO {
|
||||
|
||||
/// Value to print
|
||||
T value;
|
||||
|
||||
/// Default ctor
|
||||
ScalarIO() { }
|
||||
|
||||
/// Constructs from a value
|
||||
ScalarIO(T value): value(value) {}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Default printing to ostream
|
||||
template <typename T>
|
||||
inline std::ostream &operator<<(std::ostream &out, ScalarIO<T> const &scalar) {
|
||||
return out << scalar.value;
|
||||
}
|
||||
|
||||
/// Printing to ostream of int8_t as integer rather than character
|
||||
template <>
|
||||
inline std::ostream &operator<<(std::ostream &out, ScalarIO<int8_t> const &scalar) {
|
||||
return out << int(scalar.value);
|
||||
}
|
||||
|
||||
/// Printing to ostream of uint8_t as integer rather than character
|
||||
template <>
|
||||
inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scalar) {
|
||||
return out << unsigned(scalar.value);
|
||||
}
|
||||
|
||||
/// Printing to ostream of vector of 1b elements
|
||||
template <>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &out,
|
||||
ScalarIO<cutlass::Vector<cutlass::bin1_t, 32> > const &scalar) {
|
||||
|
||||
for (int i = 0; i < 32; i++) {
|
||||
out << int(scalar.value[i]);
|
||||
out << ((i != 31) ? ", " : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Printing to ostream of vector of 4b signed integer elements
|
||||
template <>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &out,
|
||||
ScalarIO<cutlass::Vector<cutlass::int4_t, 8> > const &scalar) {
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out << int(scalar.value[i]);
|
||||
out << ((i != 7) ? ", " : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Printing to ostream of vector of 4b unsigned integer elements
|
||||
template <>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &out,
|
||||
ScalarIO<cutlass::Vector<cutlass::uint4_t, 8> > const &scalar) {
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out << unsigned(scalar.value[i]);
|
||||
out << ((i != 7) ? ", " : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
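For reference, a hedged sketch of the removed ScalarIO helper above, whose purpose is to print narrow integer types as numbers rather than characters; the function is illustrative, not code from the tree.

// Sketch only: int8_t printed as "-3" instead of an unprintable glyph.
#include <iostream>
void print_example(int8_t x) {
  std::cout << cutlass::ScalarIO<int8_t>(x) << "\n";
}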
@@ -1,105 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Basic include for CUTLASS macros
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUTLASS_MAJOR 1
|
||||
#define CUTLASS_MINOR 3
|
||||
#define CUTLASS_PATCH 2
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
|
||||
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
|
||||
#define CUTLASS_DEVICE __forceinline__ __device__
|
||||
#elif defined(__CUDACC_RTC__)
|
||||
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
|
||||
#define CUTLASS_DEVICE __forceinline__ __device__
|
||||
#else
|
||||
#define CUTLASS_HOST_DEVICE
|
||||
// CUTLASS_DEVICE is an error if not compiling device code
|
||||
#endif
|
||||
|
||||
// CUDA 10.1 introduces the mma instruction
|
||||
#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0
|
||||
#endif
|
||||
|
||||
// CUTLASS assert
|
||||
#define CUTLASS_ASSERT(x) assert(x)
|
||||
|
||||
// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#ifdef __NVCC__
|
||||
#define CUTLASS_PRAGMA_UNROLL #pragma unroll
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1
|
||||
#elif defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
|
||||
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
||||
#endif
|
||||
|
||||
#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
|
||||
|
||||
#define CUTLASS_GEMM_LOOP_HEADER \
|
||||
asm volatile (".pragma \"nounroll\";\n");
|
||||
#else
|
||||
|
||||
#define CUTLASS_PRAGMA_UNROLL
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL
|
||||
#define CUTLASS_GEMM_LOOP_HEADER
|
||||
#define CUTLASS_GEMM_LOOP
|
||||
|
||||
#endif
|
||||
|
||||
// A small helper class to dump a type at compile time
|
||||
// Usage: DebugType<Class>::Class
|
||||
template <typename T>
|
||||
struct DebugType {};
|
||||
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
void DebugTypeFunc(T const& t) {
|
||||
T::t;
|
||||
}
|
||||
|
||||
// A small helper class to dump a compile time constant at compile time
|
||||
// Usage: DebugValue<Class::kConstant>::kConstant
|
||||
template <int Value>
|
||||
struct DebugValue {};
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/// NVIDIA GPU Warp size
|
||||
static const int kWarpSize = 32;
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
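For reference, a hedged sketch of how the removed macros above were used in device code; scale_warp is an illustrative function, not one from the tree.

// Sketch only: unroll hint plus host/device annotation from the removed header.
CUTLASS_HOST_DEVICE void scale_warp(float* data, float alpha) {
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < cutlass::kWarpSize; ++i) {
    data[i] *= alpha;
  }
}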
@@ -1,280 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines Fragment, a statically-sized array for storing parts of matrices within a
|
||||
thread's registers.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/util/cutlass_math.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup fragment_concept Fragment Concept
|
||||
@{
|
||||
|
||||
\ref fragment_concept is a statically sized array for storing parts of tiles held by individual CUDA
|
||||
threads.
|
||||
|
||||
@par \ref fragment_concept
|
||||
Types satisfying \ref fragment_concept define the following members
|
||||
- <b>Element</b> - type of each access held within the fragment
|
||||
- <b>kElements</b> - number of elements stored by the fragment
|
||||
- <b>clear()</b> - overwrites the fragment storage with zeros
|
||||
- <b>Element & operator[](int i)</b> - by-reference access of the ith element
|
||||
- <b>Element const & operator[](int i) const</b> - const by-reference access of the ith element
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup fragment_iterator_concept Fragment Iterator Concept
|
||||
@{
|
||||
|
||||
\ref fragment_iterator_concept provides structured access to the elements within a fragment with an
|
||||
optional bitcast to the desired access type
|
||||
|
||||
@par \ref fragment_iterator_concept
|
||||
Types satisfying \ref fragment_iterator_concept define the following members
|
||||
- <b>AccessType& operator[](int i)</b> - provides access to the ith element of the fragment
|
||||
- <b>AccessType& at(int d, int h, int w, int c)</b> - applies \ref layout_concept to fragment and
|
||||
provides access to element at (d, h, w, c)
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int alignment>
|
||||
struct StorageType {
|
||||
typedef uint64_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<4> {
|
||||
typedef uint32_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<2> {
|
||||
typedef uint16_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<1> {
|
||||
typedef uint8_t Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_concept
|
||||
* @concept{fragment_concept}
|
||||
*/
|
||||
template <typename Element_, int kElements_, size_t kAlignment_ = 16>
|
||||
struct Fragment : public AlignedStruct<kAlignment_> {
|
||||
/// Make sure the alignment makes sense wrt the size of elements.
|
||||
static_assert(int(kAlignment_) == 16 || int(kAlignment_) >= sizeof(Element_), "Alignment is too small");
|
||||
/// Alignment must be a power of two
|
||||
static_assert(is_pow2<int(kAlignment_)>::value, "Alignment must be a power of two");
|
||||
|
||||
/// This class.
|
||||
typedef Fragment<Element_, kElements_> This_;
|
||||
/// The element.
|
||||
typedef Element_ Element;
|
||||
/// The number of elements.
|
||||
static int const kElements = kElements_;
|
||||
/// Alignment
|
||||
static int const kAlignment = int(kAlignment_);
|
||||
|
||||
/// Clear a fragment.
|
||||
CUTLASS_HOST_DEVICE void clear() {
|
||||
// Avoid element-wise access for sub 32b element type
|
||||
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
|
||||
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
|
||||
ptr[i] = uint64_t(0);
|
||||
}
|
||||
} else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
|
||||
uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
|
||||
ptr[i] = uint32_t(0);
|
||||
}
|
||||
} else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
|
||||
uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
|
||||
ptr[i] = uint16_t(0);
|
||||
}
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kElements; ++i) {
|
||||
storage[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast<Element*>(storage)[i]; }
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE Element const& operator[](int i) const {
|
||||
return reinterpret_cast<Element const*>(storage)[i];
|
||||
}
|
||||
|
||||
private:
|
||||
/// Storage type to use for Elements
|
||||
typedef typename StorageType<int(kAlignment_)>::Type StorageType;
|
||||
|
||||
/// Number of elements in the storage
|
||||
static int const kStorageCount =
|
||||
(sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
|
||||
/// The storage.
|
||||
StorageType storage[kStorageCount];
|
||||
|
||||
/// Ensure that there's enough storage for all elements
|
||||
static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_iterator_concept
|
||||
* @concept{fragment_iterator_concept}
|
||||
*/
|
||||
template <typename Fragment_, typename Iterations_, typename AccessType_>
|
||||
struct FragmentIterator {
|
||||
/// This class.
|
||||
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
|
||||
/// The fragment.
|
||||
typedef Fragment_ Fragment;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The access type.
|
||||
typedef AccessType_ AccessType;
|
||||
|
||||
/// The element.
|
||||
typedef typename Fragment::Element Element;
|
||||
/// The number of elements per access.
|
||||
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
|
||||
/// The shape of the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE AccessType& operator[](int i) {
|
||||
return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element* pointer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Fragment_, typename Iterations_, typename AccessType_>
|
||||
struct FragmentConstIterator {
|
||||
/// This class.
|
||||
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
|
||||
/// The fragment.
|
||||
typedef Fragment_ Fragment;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The access type.
|
||||
typedef AccessType_ AccessType;
|
||||
|
||||
/// The element.
|
||||
typedef typename Fragment::Element Element;
|
||||
/// The number of elements per access.
|
||||
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
|
||||
/// The shape of the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
/// Create from non-constant FragmentIterator
|
||||
CUTLASS_HOST_DEVICE FragmentConstIterator(
|
||||
FragmentIterator<Fragment_, Iterations_, AccessType_> const& rhs_)
|
||||
: pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element const* pointer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
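For reference, the iterator above is a strided view over a Fragment's backing storage: at(d, h, w, c) folds the four indices into one linear offset and reinterprets kElementsPerAccess consecutive scalars as one AccessType. A minimal plain-C++ sketch of that addressing (the strides and sizes below are illustrative assumptions, not values from the library):

#include <cstdio>

int main() {
  const int kElementsPerAccess = 4;             // scalars per AccessType
  const int S_d = 64, S_h = 16, S_w = 4;        // assumed linear strides
  float storage[128] = {};                      // the fragment's backing store
  int d = 1, h = 2, w = 3, c = 0;
  int imm = d * S_d + h * S_h + w * S_w + c;    // ComputeOffsetFromStrides analogue
  float* access = &storage[imm];                // plays the role of AccessType&
  for (int i = 0; i < kElementsPerAccess; ++i)
    access[i] = 1.0f;                           // touch one access-worth of scalars
  std::printf("linear offset = %d\n", imm);
  return 0;
}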
@@ -1,155 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines multiply-add operations on fragments within a thread.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template < typename ScalarAlphaBeta_,
           typename ScalarAccum_,
           bool fragMul2 = true /* the number of elements per fragment is a multiple of 2 */
          >
struct FragmentMultiplyAdd {
  /// The shape of the instruction.
  typedef Shape<1, 1, 1, 1> InstructionShape;
  /// The type for alpha and beta.
  typedef ScalarAlphaBeta_ ScalarAlphaBeta;
  /// The type for the accumulator.
  typedef ScalarAccum_ ScalarAccum;

  /// Ctor.
  CUTLASS_DEVICE FragmentMultiplyAdd() {}

  /// Multiply : d = a*b.
  template <typename FragmentB_, typename FragmentCd_>
  CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
    int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
    for (int j = 0; j < FragmentCd_::kElements; ++j) {
      d[j] = b[j * kReduction + 0];
      for (int k = 1; k < kReduction; ++k) {
        d[j] += b[j * kReduction + k];
      }
      d[j] = a * ScalarAlphaBeta(d[j]);
    }
  }

  /// Multiply : d = a*b + c.
  template <typename FragmentB_, typename FragmentCd_>
  CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a,
                                   FragmentB_ const& b,
                                   FragmentCd_ const& c,
                                   FragmentCd_& d) {
    int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
    for (int j = 0; j < FragmentCd_::kElements; ++j) {
      d[j] = b[j * kReduction + 0];
      for (int k = 1; k < kReduction; ++k) {
        d[j] += b[j * kReduction + k];
      }
      d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
    }
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
template <>
|
||||
struct FragmentMultiplyAdd<half, half, true> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The type for alpha and beta
|
||||
typedef half ScalarAlphaBeta;
|
||||
/// The type for the accumulator
|
||||
typedef half ScalarAccum;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// The input.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
|
||||
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(half a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// The inputs.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
|
||||
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
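FragmentMultiplyAdd contracts kReduction = FragmentB::kElements / FragmentCd::kElements partial products into each output element before applying the scaling. A standalone plain-C++ model of multiply_add with assumed fragment sizes (not library code):

#include <cstdio>

int main() {
  const int kCd = 4, kReduction = 2;            // FragmentCd::kElements, B/Cd ratio
  const int kB = kCd * kReduction;              // FragmentB::kElements
  float a = 2.0f;                               // alpha
  float b[kB] = {1, 2, 3, 4, 5, 6, 7, 8};
  float c[kCd] = {1, 1, 1, 1};
  float d[kCd];
  for (int j = 0; j < kCd; ++j) {
    d[j] = b[j * kReduction + 0];
    for (int k = 1; k < kReduction; ++k)
      d[j] += b[j * kReduction + k];            // fold partial products
    d[j] = a * d[j] + c[j];                     // scale, then add the C term
  }
  for (int j = 0; j < kCd; ++j) std::printf("%.1f ", d[j]);   // 7 15 23 31
  std::printf("\n");
  return 0;
}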
@@ -1,58 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for efficiently clearing accumulator tiles.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_ = 1>
struct ClearAccumulators {
  /// The shared storage.
  struct SharedStorage {};

  /// Ctor.
  CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}

  /// Ctor.
  CUTLASS_DEVICE ClearAccumulators() {}

  /// Clear the fragment.
  template <typename Fragment_>
  CUTLASS_DEVICE void clear(Fragment_& fragment) {
    fragment.clear();
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,70 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
    \brief Device-level GEMM implemented by more than one kernel.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template <typename DeviceGemmTraits_>
struct DeviceGemm {
  /// The traits.
  typedef DeviceGemmTraits_ Traits;
  /// Use the params object defined in traits.
  typedef typename Traits::Params Params;

  /// Support for NVRTC
#if !defined(__CUDACC_RTC__)
  /// Launch the kernels in order.
  static __host__ cudaError_t launch(Params const& params) {
    // Traits::GemmTraits::KernelClass::launch(params.GemmParams);
    Gemm<typename Traits::GemmTraits>::launch(params.GemmParams);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
      return err;
    Traits::ReductionTraits::KernelClass::launch(params.ReductionParams);
    return cudaGetLastError();
  }
#endif

  ///
  /// Methods
  ///

  /// Ctor.
  CUTLASS_DEVICE DeviceGemm() {}
};
} // namespace gemm
} // namespace cutlass
|
||||
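DeviceGemm implements split-K as two launches: the GEMM kernel writes ReductionSize partial M-by-N products into a workspace, and the reduction kernel sums them into D. A plain-C++ model of that second step (sizes are illustrative assumptions):

#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 2, R = 3;                  // R plays the role of ReductionSize
  std::vector<float> workspace(M * N * R, 1.0f);  // R stacked M*N partial products
  std::vector<float> c(M * N, 0.0f), d(M * N, 0.0f);
  float alpha = 1.0f, beta = 0.0f;
  for (int i = 0; i < M * N; ++i) {
    float sum = 0.0f;
    for (int r = 0; r < R; ++r)
      sum += workspace[r * M * N + i];            // reduction_stride = M * N
    d[i] = alpha * sum + beta * c[i];
  }
  std::printf("d[0] = %f\n", d[0]);               // 3.0 for this toy problem
  return 0;
}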
@@ -1,174 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
#include <assert.h>
|
||||
#include "cutlass/gemm/device_gemm.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/gemm/gemm_desc.h"
|
||||
#include "tools/util/type_traits.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template <
|
||||
/// The Traits for the first kernel
|
||||
typename GemmTraits_,
|
||||
/// The Traits for the second kernel
|
||||
typename ReductionTraits_
|
||||
>
|
||||
struct SplitkPIGemmTraits {
|
||||
typedef GemmTraits_ GemmTraits;
|
||||
typedef ReductionTraits_ ReductionTraits;
|
||||
typedef SplitkPIGemmTraits<GemmTraits_, ReductionTraits_> This_;
|
||||
typedef typename cutlass::gemm::DeviceGemm<This_> KernelClass;
|
||||
|
||||
///
|
||||
typedef typename GemmTraits::Index Index;
|
||||
///
|
||||
typedef typename ReductionTraits::ScalarAlphaBeta Scalar;
|
||||
///
|
||||
typedef typename GemmTraits::ScalarA ScalarA;
|
||||
///
|
||||
typedef typename GemmTraits::ScalarB ScalarB;
|
||||
///
|
||||
typedef typename GemmTraits::ScalarD ScalarAccum;
|
||||
///
|
||||
typedef typename ReductionTraits::ScalarC ScalarC;
|
||||
///
|
||||
typedef typename ReductionTraits::ScalarD ScalarD;
|
||||
/// The layout of A. can be deduced from the layout set in batched gemm
|
||||
static MatrixLayout::Kind const kLayoutA = GemmTraits::kLayoutA;
|
||||
/// The layout of B. can be deduced from the layout set in batched gemm
|
||||
static MatrixLayout::Kind const kLayoutB = GemmTraits::kLayoutB;
|
||||
|
||||
struct Params {
|
||||
/// The dimensions of the GEMM in K, N, M order
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// Check if params are init
|
||||
bool problem_size_initialized;
|
||||
/// The pointer to workspace memory
|
||||
ScalarAccum *workspace_ptr;
|
||||
///
|
||||
size_t workspace_size;
|
||||
/// The Params for the first kernel
|
||||
typename GemmTraits::Params GemmParams;
|
||||
/// The Params for the second kernel
|
||||
typename ReductionTraits::Params ReductionParams;
|
||||
|
||||
/// ctor
|
||||
Params() :
|
||||
workspace_size(0),
|
||||
problem_size_initialized(false) {}
|
||||
/// ctor
|
||||
Params(Index m_,
|
||||
Index n_,
|
||||
Index k_
|
||||
):
|
||||
problem_size(k_, n_, m_, 1),
|
||||
workspace_size(0),
|
||||
problem_size_initialized(true) {
|
||||
|
||||
}
|
||||
|
||||
/// init problem is needed if using default ctor
|
||||
void init_problem(Index m_,
|
||||
Index n_,
|
||||
Index k_){
|
||||
problem_size = GemmCoord(k_, n_, m_, 1);
|
||||
problem_size_initialized = true;
|
||||
}
|
||||
|
||||
int initialize(Scalar alpha_,
|
||||
ScalarA const* d_a_,
|
||||
Index lda_,
|
||||
ScalarB const* d_b_,
|
||||
Index ldb_,
|
||||
Scalar beta_,
|
||||
ScalarC const* d_c_,
|
||||
Index ldc_,
|
||||
ScalarD* d_d_,
|
||||
Index ldd_,
|
||||
ScalarAccum *workspace_ptr_,
|
||||
Index partitionK_multiple = 1) {
|
||||
|
||||
workspace_ptr = workspace_ptr_;
|
||||
|
||||
//call GemmTraits (first kernel) param
|
||||
//for the first kernel A is A, B is B, C and D are workspace
|
||||
//alpha is one, beta is zero, partitionK_count is reductionTraits::reductionSize
|
||||
typename cutlass::gemm::GemmDesc<typename GemmTraits::ScalarA,
|
||||
typename GemmTraits::ScalarB,
|
||||
typename GemmTraits::ScalarC,
|
||||
typename GemmTraits::ScalarD,
|
||||
typename GemmTraits::Epilogue::Scalar>
|
||||
desc(
|
||||
problem_size,
|
||||
typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(1.0f), /*alpha*/
|
||||
TensorRef<typename GemmTraits::ScalarA const, 2>(d_a_, lda_),
|
||||
TensorRef<typename GemmTraits::ScalarB const, 2>(d_b_, ldb_),
|
||||
typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(0.0f), /*beta*/
|
||||
TensorRef<typename GemmTraits::ScalarC const, 2>(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
|
||||
TensorRef<typename GemmTraits::ScalarD, 2>(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
|
||||
);
|
||||
GemmParams.initialize(desc, ReductionTraits::ReductionSize, partitionK_multiple);
|
||||
|
||||
|
||||
//call batched reduction (second kernel) param
|
||||
ReductionParams.initialize(problem_size.m(), /*m*/
|
||||
problem_size.n(), /*n*/
|
||||
alpha_, /*alpha*/
|
||||
beta_, /*beta*/
|
||||
problem_size.n() * problem_size.m() /*reduction_stride*/,
|
||||
workspace_ptr,
|
||||
problem_size.m(),
|
||||
d_c_,
|
||||
ldc_,
|
||||
d_d_,
|
||||
ldd_);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm)
|
||||
// note typedef typename GemmTraits::ScalarD ScalarAccum;
|
||||
// workspace of size of M * N * Reduction
|
||||
size_t required_workspace_memory_in_byte(){
|
||||
assert(problem_size_initialized == true);
|
||||
workspace_size = static_cast<size_t>(problem_size.n()) *
|
||||
static_cast<size_t>(problem_size.m()) *
|
||||
static_cast<size_t>(ReductionTraits::ReductionSize) *
|
||||
sizeof(ScalarAccum);
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
} // namespace gemm
} // namespace cutlass
|
||||
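Putting the traits above together, a hedged host-side sketch of driving the split-K path. DeviceTraits stands for a concrete SplitkPIGemmTraits<GemmTraits, ReductionTraits> instantiation, and the relevant cutlass/gemm headers are assumed to be included; the call sequence follows the Params methods shown in this file:

#include <cuda_runtime.h>

template <typename DeviceTraits>
cudaError_t run_splitk(int m, int n, int k,
                       typename DeviceTraits::Scalar alpha,
                       typename DeviceTraits::ScalarA const* d_a, int lda,
                       typename DeviceTraits::ScalarB const* d_b, int ldb,
                       typename DeviceTraits::Scalar beta,
                       typename DeviceTraits::ScalarC const* d_c, int ldc,
                       typename DeviceTraits::ScalarD* d_d, int ldd) {
  typename DeviceTraits::Params params;
  params.init_problem(m, n, k);
  // Workspace holds ReductionSize partial M*N products of type ScalarAccum.
  void* workspace = nullptr;
  cudaError_t err = cudaMalloc(&workspace, params.required_workspace_memory_in_byte());
  if (err != cudaSuccess) return err;
  params.initialize(alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd,
                    static_cast<typename DeviceTraits::ScalarAccum*>(workspace));
  cudaError_t result = cutlass::gemm::DeviceGemm<DeviceTraits>::launch(params);
  cudaFree(workspace);
  return result;
}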
@@ -1,134 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural traits of double-precision GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for threadblock-level GEMM (K-by-N-by-M).
|
||||
typename OutputTile_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct DgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
double,
|
||||
/// The scalar type for B.
|
||||
double,
|
||||
/// The scalar type for C.
|
||||
double,
|
||||
/// The scalar type for D.
|
||||
double,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, double, double, double>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
2,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
2,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
2,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
false,
|
||||
/// kLaunchBounds
|
||||
false
|
||||
>{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for threadblock-level GEMM (K-by-N-by-M)
|
||||
typename OutputTile_ = Shape<8, 64, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<double>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of doubles loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of doubles loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The DGEMM config.
|
||||
typename GemmConfig_ =
|
||||
DgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct DgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
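A hedged example of how DgemmTraits is typically consumed; the header path and the Params::initialize(m, n, k, alpha, ...) convention follow the usual CUTLASS 1.x pattern and are assumptions here, not taken from this diff:

// Hedged sketch; Params::initialize's exact signature is not shown in this file.
#include "cutlass/gemm/dgemm_traits.h"
#include "cutlass/gemm/gemm.h"

cudaError_t run_dgemm(int m, int n, int k, double alpha,
                      double const* A, int lda, double const* B, int ldb,
                      double beta, double* C, int ldc) {
  typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                     cutlass::MatrixLayout::kColumnMajor>
      Traits;
  Traits::Params params;
  params.initialize(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, C, ldc);
  return cutlass::gemm::Gemm<Traits>::launch(params);
}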
@@ -1,86 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template implementing matrix multiply-add operations on fragments.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
template <typename ThreadGemmShape_,
          typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
  /// The shape of the instruction.
  typedef Shape<1, 1, 1, 1> InstructionShape;
  /// The shape of a thread-level matrix multiply-accumulate.
  typedef ThreadGemmShape_ ThreadGemmShape;
  /// Aliased to "AccumulatorsPerThread" for compatibility. Expected to be renamed in CUTLASS v2.0.
  typedef ThreadGemmShape AccumulatorsPerThread;
  /// The number of threads per warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of accumulators per warp.
  typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
  /// The type for A, specialized to half.
  typedef half ScalarA;
  /// The fragment for A.
  typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
  /// The type for B, specialized to half.
  typedef half ScalarB;
  /// The fragment for B.
  typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
  /// The type for C and D, specialized to float.
  typedef float ScalarC;
  /// The accumulators.
  typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE ThreadMultiplyAdd() {}

  /// Multiply : d = a*b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
      for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
        d[j * AccumulatorsPerThread::kW + i] =
            static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) +
            c[j * AccumulatorsPerThread::kW + i];
      }
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
|
||||
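The specialization above widens each half operand to float and accumulates in fp32. A plain-C++ model of the thread-level loop with an assumed 2x2 accumulator tile:

#include <cstdio>

int main() {
  const int kW = 2, kH = 2;                     // AccumulatorsPerThread::kW / kH
  float a[kW] = {1.5f, 2.5f};                   // would be half operands in the kernel
  float b[kH] = {0.5f, 4.0f};
  float c[kW * kH] = {0.0f, 0.0f, 0.0f, 0.0f};
  float d[kW * kH];
  for (int j = 0; j < kH; ++j)
    for (int i = 0; i < kW; ++i)
      d[j * kW + i] = a[i] * b[j] + c[j * kW + i];   // widen-and-accumulate in fp32
  std::printf("d[3] = %.1f\n", d[3]);           // a[1] * b[1] = 2.5 * 4.0 = 10.0
  return 0;
}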
@@ -1,152 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
    \brief Defines structural properties of single-precision GEMM in which any of the input/output
           operands may be fp16 or fp32. The accumulator type stays fp32.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/fp16_sgemm_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The type for A
|
||||
typename ScalarA_,
|
||||
/// The type for B
|
||||
typename ScalarB_,
|
||||
/// The type for C
|
||||
typename ScalarC_,
|
||||
/// The type for D
|
||||
typename ScalarD_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct Fp16SgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
ScalarA_,
|
||||
/// The scalar type for B.
|
||||
ScalarB_,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, ScalarA_, ScalarB_, float /*for sgemm accum is float*/>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
4,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
4,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The type for A
|
||||
typename ScalarA_ = half,
|
||||
/// The type for B
|
||||
typename ScalarB_ = half,
|
||||
/// The type for C
|
||||
typename ScalarC_ = half,
|
||||
/// The type for D
|
||||
typename ScalarD_ = half,
|
||||
/// the Type for alpha and beta,
|
||||
typename Scalar_ = half,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<Scalar_, FragmentMultiplyAdd<Scalar_, float/*accumulator type*/> >,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
Fp16SgemmConfig<OutputTile_,
|
||||
ThreadGemmShape_,
|
||||
ScalarA_,
|
||||
ScalarB_,
|
||||
ScalarC_,
|
||||
ScalarD_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct Fp16SgemmSgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
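A hedged instantiation example for the mixed-precision traits; the header name is an assumption based on the file in this diff, and only the two layout arguments are supplied so the defaults above apply:

#include "cutlass/gemm/fp16_sgemm_traits.h"
#include "cutlass/gemm/gemm.h"

// half in, half out, accumulated in float as enforced by Fp16SgemmConfig.
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                            cutlass::MatrixLayout::kColumnMajor>
    MixedTraits;

cudaError_t run_fp16_sgemm(MixedTraits::Params const& params) {
  return cutlass::gemm::Gemm<MixedTraits>::launch(params);
}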
@@ -1,250 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM kernel with launch bounds specified
template <typename Gemm_>
__global__ __launch_bounds__(Gemm_::kThreads)
void gemm_kernel(typename Gemm_::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int GemmSharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  typename Gemm_::SharedStorage *shared_storage =
      reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

  // Construct the GEMM object.
  Gemm_ gemm(params, *shared_storage);

  // Run GEMM.
  gemm.multiply_add();
}

////////////////////////////////////////////////////////////////////////////////////////////////////

/// GEMM kernel without launch bounds specified
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int GemmSharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  typename Gemm_::SharedStorage *shared_storage =
      reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

  // Construct the GEMM object.
  Gemm_ gemm(params, *shared_storage);

  // Run GEMM.
  gemm.multiply_add();
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Partial specialization for launching the GEMM kernel with or without launch bounds
|
||||
template <typename Gemm, bool WithLaunchBounds>
|
||||
struct Launch {
|
||||
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
|
||||
|
||||
int smem_size = int(sizeof(typename Gemm::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
gemm_kernel<Gemm>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size
|
||||
);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
gemm_kernel<Gemm>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout,
|
||||
100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
gemm_kernel<Gemm><<< grid, block, sizeof(typename Gemm::SharedStorage), stream >>>(params);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for launching the GEMM kernel with or without launch bounds
|
||||
template <typename Gemm>
|
||||
struct Launch<Gemm, false> {
|
||||
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
|
||||
int smem_size = int(sizeof(typename Gemm::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
gemm_kernel_nolb<Gemm>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size
|
||||
);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
gemm_kernel_nolb<Gemm>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout,
|
||||
100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
// throw exception?
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
gemm_kernel_nolb<Gemm><<<
|
||||
grid,
|
||||
block,
|
||||
smem_size,
|
||||
stream >>>(params);
|
||||
}
|
||||
|
||||
// Use device API to launch kernel
|
||||
Launch(cudaError_t &result, CUfunction kernel,
|
||||
typename Gemm::Params params, dim3 grid, dim3 block, CUstream stream = CU_STREAM_LEGACY) {
|
||||
void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(¶ms))};
|
||||
|
||||
int smem_size = int(sizeof(typename Gemm::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
kernel,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size
|
||||
);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
kernel,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout,
|
||||
100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
CUresult launch_result = cuLaunchKernel(
|
||||
kernel,
|
||||
grid.x, grid.y, grid.z,
|
||||
block.x, block.y, block.z,
|
||||
smem_size, stream, params_, 0);
|
||||
|
||||
if (launch_result != CUDA_SUCCESS) {
|
||||
result = cudaErrorLaunchFailure;
|
||||
return;
|
||||
}
|
||||
|
||||
result = cudaSuccess;
|
||||
return;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Traits_>
|
||||
struct Gemm {
|
||||
|
||||
/// The traits.
|
||||
typedef Traits_ Traits;
|
||||
|
||||
/// Use the params object defined in traits
|
||||
typedef typename Traits::Params Params;
|
||||
|
||||
typedef typename Traits::KernelClass KernelClass;
|
||||
|
||||
//
|
||||
// Static function members
|
||||
//
|
||||
|
||||
/// Support for NVRTC
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(Params const& params,
|
||||
cudaStream_t stream = cudaStreamDefault) {
|
||||
|
||||
// Launch the kernel.
|
||||
Launch<KernelClass, Traits::GemmConfig::kLaunchBounds>(
|
||||
params, params.grid, params.block, stream);
|
||||
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(CUfunction kernel,
|
||||
Params const& params,
|
||||
CUstream stream = CU_STREAM_LEGACY) {
|
||||
cudaError_t result;
|
||||
|
||||
// Launch the kernel.
|
||||
Launch<KernelClass, Traits::GemmConfig::kLaunchBounds>(
|
||||
result, kernel, params, params.grid, params.block, stream);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
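The Launch helpers mainly exist to opt a kernel into more than 48 KB of dynamic shared memory before launching it. A self-contained CUDA sketch of that pattern (toy kernel and sizes are assumptions; requesting more than 48 KB needs a GPU that supports the opt-in):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void toy_kernel(int* out) {
  extern __shared__ int smem[];   // dynamic shared memory, sized at launch
  smem[threadIdx.x] = static_cast<int>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) *out = smem[31];
}

int main() {
  int smem_size = 64 << 10;       // 64 KB, above the default 48 KB limit
  if (smem_size >= (48 << 10)) {
    // Same opt-in the Launch helpers perform before the kernel launch.
    cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_size);
  }
  int* out = nullptr;
  cudaMalloc(&out, sizeof(int));
  toy_kernel<<<1, 32, smem_size>>>(out);
  std::printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}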
@@ -1,145 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines properties of GEMM computation that impose some constraints on caller.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/shape.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
    /// The scalar type for A.
    typename ScalarA_,
    /// The scalar type for B.
    typename ScalarB_,
    /// The scalar type for C.
    typename ScalarC_,
    /// The scalar type for D.
    typename ScalarD_,
    /// The threadblock tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// The functor to do the math.
    typename MultiplyAdd_,
    /// The number of scalars per LDG for A.
    int kScalarsPerLdgA_,
    /// The number of scalars per STS for A.
    int kScalarsPerStsA_,
    /// The number of scalars per LDS for A.
    int kScalarsPerLdsA_,
    /// The number of scalars per LDG for B.
    int kScalarsPerLdgB_,
    /// The number of scalars per STS for B.
    int kScalarsPerStsB_,
    /// The number of scalars per LDS for B.
    int kScalarsPerLdsB_,
    /// The number of scalars per LDG for C and STG for D.
    int kScalarsPerLdgCAndStgD_,
    /// The number of scalars per STS for D.
    int kScalarsPerStsD_,
    /// The number of scalars per LDS for D.
    int kScalarsPerLdsD_,
    /// The number of stages in shared memory to do single/double/triple-buffering.
    int kStages_,
    /// If true, residue is computed in the mainloop. If false, separate loops are instantiated.
    bool kResidueSeparate_ = false,
    /// Is residue performed in the prologue?
    bool kResidueInProlog_ = false,
    /// If true, the kernel is launched with CUDA launch bounds specified.
    bool kLaunchBounds_ = true>
struct GemmConfig {
|
||||
//
|
||||
/// The scalar for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The scalar for C.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef ScalarD_ ScalarD;
|
||||
|
||||
/// The tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The functor to do D = A*B + C.
|
||||
typedef MultiplyAdd_ MultiplyAdd;
|
||||
/// The shape of the instruction.
|
||||
typedef typename MultiplyAdd::InstructionShape InstructionShape;
|
||||
/// The shape of warp-level GEMM
|
||||
typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
|
||||
/// The accumulators.
|
||||
typedef typename MultiplyAdd::Accumulators Accumulators;
|
||||
|
||||
/// The number of warps.
|
||||
typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
|
||||
/// The default warp size (32 threads per warp).
|
||||
static int const kWarpSize = cutlass::kWarpSize;
|
||||
/// The number of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerLdgA = kScalarsPerLdgA_;
|
||||
static int const kScalarsPerStsA = kScalarsPerStsA_;
|
||||
static int const kScalarsPerLdsA = kScalarsPerLdsA_;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerLdgB = kScalarsPerLdgB_;
|
||||
static int const kScalarsPerStsB = kScalarsPerStsB_;
|
||||
static int const kScalarsPerLdsB = kScalarsPerLdsB_;
|
||||
|
||||
/// The number of scalars per LDG for C.
|
||||
static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
|
||||
|
||||
/// The number of scalars per STS/LDS/STG for D.
|
||||
static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
|
||||
static int const kScalarsPerStsD = kScalarsPerStsD_;
|
||||
static int const kScalarsPerLdsD = kScalarsPerLdsD_;
|
||||
|
||||
/// The number of accumulators that are going to be fed from one LDS A/B.
|
||||
static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
|
||||
static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
|
||||
|
||||
/// The number of stages in shared memory to implement double, triple, more-buffering.
|
||||
static int const kStages = kStages_;
|
||||
|
||||
/// If true, mainloop is instantiated twice. The first instantiation contains no predicate
|
||||
// updates and is more efficient for some kernels. If false, only a single mainloop is
|
||||
// instantiated.
|
||||
static bool const kResidueSeparate = kResidueSeparate_;
|
||||
|
||||
/// If true, residue is computed in the prologue.
|
||||
static bool const kResidueInProlog = kResidueInProlog_;
|
||||
|
||||
/// If true, kernel is launched with launch bounds specified
|
||||
static bool const kLaunchBounds = kLaunchBounds_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
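The thread-count bookkeeping in GemmConfig is plain shape arithmetic: Warps = OutputTile / AccumulatorsPerWarp and kThreads = count(Warps) * 32. A worked example with SGEMM-like tile sizes (the warp tile below is an assumed value):

#include <cstdio>

int main() {
  // OutputTile and AccumulatorsPerWarp in (K, N, M) order.
  int tile[3]  = {8, 128, 128};                 // threadblock tile
  int accum[3] = {8, 32, 64};                   // assumed warp-level tile
  int warps[3], warp_count = 1;
  for (int i = 0; i < 3; ++i) {
    warps[i] = tile[i] / accum[i];              // ShapeDiv
    warp_count *= warps[i];                     // ShapeCount
  }
  const int kWarpSize = 32;
  int kThreads = warp_count * kWarpSize;        // 1 * 4 * 2 * 32 = 256 threads
  std::printf("warps = %dx%dx%d, kThreads = %d\n",
              warps[0], warps[1], warps[2], kThreads);
  return 0;
}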
@@ -1,209 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief GemmCoord is a structure derived from Coord<4> that specifies a location within the
|
||||
coordinate system of a GEMM problem.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GemmCoord is a structure derived from Coord<4> that specifies a location within the
|
||||
/// coordinate space of a GEMM problem.
|
||||
struct GemmCoord : public Coord<4, int> {
|
||||
|
||||
/// Integer-valued index
|
||||
typedef int Index;
|
||||
|
||||
/// Base type is a Coord of rank=4
|
||||
typedef Coord<4, Index> Base;
|
||||
|
||||
/// GEMM K dimension - inner dimension of the GEMM problem
|
||||
static int const kK = 0;
|
||||
|
||||
/// GEMM N dimension - columns of the output C matrix
|
||||
static int const kN = 1;
|
||||
|
||||
/// GEMM M dimension - rows of the output C matrix
|
||||
static int const kM = 2;
|
||||
|
||||
/// Batch dimension - for generalizing to larger problems
|
||||
static int const kBatch = 3;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord() { }
|
||||
|
||||
/// Constructs from Coord<3> and a batch
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Coord<3, Index> const &coord, Index _batch = 0): Base(make_Coord(coord[0], coord[1], coord[2], _batch)) { }
|
||||
|
||||
/// Constructs from Coord<4>
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Coord<4, Index> const &coord): Base(coord) { }
|
||||
|
||||
/// Constructs from an array of coordinate elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Index coord[4]): Base(coord) { }
|
||||
|
||||
/// Helper to construct from a K, N, M, batch variables
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Index k, Index n, Index m, Index batch = 0): Base(make_Coord(k, n, m, batch)) { }
|
||||
|
||||
/// Returns the GEMM M coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & m() const { return this->at(kM); }
|
||||
|
||||
/// Returns reference to the GEMM M coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & m() { return this->at(kM); }
|
||||
|
||||
/// Returns the GEMM N coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & n() const { return this->at(kN); }
|
||||
|
||||
/// Returns reference to the GEMM N coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & n() { return this->at(kN); }
|
||||
|
||||
/// Returns the GEMM K coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & k() const { return this->at(kK); }
|
||||
|
||||
/// Returns reference to the GEMM K coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & k() { return this->at(kK); }
|
||||
|
||||
/// Returns the GEMM batch coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & batch() const { return this->at(kBatch); }
|
||||
|
||||
/// Returns reference to the GEMM batch coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & batch() { return this->at(kBatch); }
|
||||
|
||||
/// Obtains a Coord<3> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> knm() const {
|
||||
return make_Coord(k(), n(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> nm() const {
|
||||
return make_Coord(n(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> mn() const {
|
||||
return make_Coord(m(), n());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> km() const {
|
||||
return make_Coord(k(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> kn() const {
|
||||
return make_Coord(k(), n());
|
||||
}
|
||||
|
||||
//
|
||||
// Coord operators
|
||||
//
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator+(Base const& b) const {
|
||||
return GemmCoord(Base::operator+(b));
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator-(Base const& b) const {
|
||||
return GemmCoord(Base::operator-(b));
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator*(Base const& b) const {
|
||||
return GemmCoord(Base::operator*(b));
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator/(Base const& b) const {
|
||||
return GemmCoord(Base::operator/(b));
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator+=(Base const& b) {
|
||||
Base::operator+=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator-=(Base const& b) {
|
||||
Base::operator-=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator*=(Base const& b) {
|
||||
Base::operator*=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator/=(Base const& b) {
|
||||
Base::operator/=(b);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
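// --- Illustrative usage (not part of the original header) --------------------------------------
// A minimal sketch of how the GemmCoord above is meant to be used, assuming only the CUTLASS 1.x
// headers included by this file. The constructor and storage order is (K, N, M, batch), so the
// named accessors and the nm()/knm() helpers are the safest way to read components back.

#include "cutlass/gemm/gemm_coord.h"

inline void gemm_coord_example() {
  // A problem with m=256, n=128, k=64 and 2 batches; arguments are (k, n, m, batch).
  cutlass::gemm::GemmCoord problem(64, 128, 256, 2);

  int m = problem.m();      // 256
  int n = problem.n();      // 128
  int k = problem.k();      // 64
  int b = problem.batch();  // 2

  cutlass::Coord<2> extent_nm = problem.nm();    // (n, m) = (128, 256)
  cutlass::Coord<3> extent_knm = problem.knm();  // (k, n, m) = (64, 128, 256)

  (void)m; (void)n; (void)k; (void)b; (void)extent_nm; (void)extent_knm;
}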
@@ -1,205 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements a software-pipelined efficient GEMM.
*/
#pragma once

#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm_coord.h"

namespace cutlass {
namespace gemm {

/// GEMM problem description
template <
    /// Type of the A operand
    typename AType_,
    /// Type of the B operand
    typename BType_,
    /// Type of the C operand (source accumulator matrix)
    typename CType_,
    /// Type of the D operand (destination accumulator matrix)
    typename DType_,
    /// Scalar type for alpha and beta
    typename SType_,
    /// Index type for dimensions and strides
    typename Index_ = int
> struct GemmDesc {
  //
  // Type definitions
  //

  /// Index type for dimensions and strides
  typedef Index_ Index;

  /// Type of the A operand
  typedef AType_ AType;

  /// Tensor reference to A operand
  typedef TensorRef<AType const, 2> TensorRefA;

  /// Type of the B operand
  typedef BType_ BType;

  /// Tensor reference to B operand
  typedef TensorRef<BType const, 2> TensorRefB;

  /// Type of the C operand (source accumulator matrix)
  typedef CType_ CType;

  /// Tensor reference to C operand
  typedef TensorRef<CType const, 2> TensorRefC;

  /// Type of the D operand (destination accumulator matrix)
  typedef DType_ DType;

  /// Tensor reference to D operand
  typedef TensorRef<DType, 2> TensorRefD;

  /// Scalar type for alpha and beta
  typedef SType_ SType;

  //
  // Data members
  //

  /// The dimensions of the GEMM.
  GemmCoord problem_size;

  /// The alpha scaling value.
  SType alpha;

  /// The source matrix A.
  TensorRefA A;

  /// Batch stride for the A operand
  long long batch_stride_A;

  /// The source matrix B.
  TensorRefB B;

  /// Batch stride for the B operand
  long long batch_stride_B;

  /// The beta scaling value.
  SType beta;

  /// The source matrix C.
  TensorRefC C;

  /// Batch stride for the C operand
  long long batch_stride_C;

  /// The destination matrix D.
  TensorRefD D;

  /// Batch stride for the D operand
  long long batch_stride_D;

  //
  // Methods
  //

  /// Default ctor
  CUTLASS_HOST_DEVICE
  GemmDesc(): problem_size(0, 0, 0, 1), alpha(1), beta(0) {}

  /// Constructor for basic GEMM with batch count = 1
  CUTLASS_HOST_DEVICE
  GemmDesc(Coord<3> _problem_size,
           SType _alpha,
           TensorRefA const &_A,
           TensorRefB const &_B,
           SType _beta,
           TensorRefC const &_C,
           TensorRefD const &_D
  ):
    problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1),
    alpha(_alpha),
    A(_A),
    batch_stride_A(0),
    B(_B),
    batch_stride_B(0),
    beta(_beta),
    C(_C),
    batch_stride_C(0),
    D(_D),
    batch_stride_D(0) {}

  /// Constructor for basic GEMM with batch count = 1
  CUTLASS_HOST_DEVICE
  GemmDesc(GemmCoord _problem_size,
           SType _alpha,
           TensorRefA const &_A,
           TensorRefB const &_B,
           SType _beta,
           TensorRefC const &_C,
           TensorRefD const &_D
  ):
    problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1),
    alpha(_alpha),
    A(_A),
    batch_stride_A(0),
    B(_B),
    batch_stride_B(0),
    beta(_beta),
    C(_C),
    batch_stride_C(0),
    D(_D),
    batch_stride_D(0) {

    assert(_problem_size.batch() == 1);
  }

  /// Constructor for strided batched GEMM
  CUTLASS_HOST_DEVICE
  GemmDesc(GemmCoord _problem_size,
           SType _alpha,
           TensorRefA const &_A,
           long long _batch_stride_A,
           TensorRefB const &_B,
           long long _batch_stride_B,
           SType _beta,
           TensorRefC const &_C,
           long long _batch_stride_C,
           TensorRefD const &_D,
           long long _batch_stride_D
  ):
    problem_size(_problem_size),
    alpha(_alpha),
    A(_A),
    batch_stride_A(_batch_stride_A),
    B(_B),
    batch_stride_B(_batch_stride_B),
    beta(_beta),
    C(_C),
    batch_stride_C(_batch_stride_C),
    D(_D),
    batch_stride_D(_batch_stride_D) {}
};

} // namespace gemm
} // namespace cutlass
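// --- Illustrative usage (not part of the original file) ----------------------------------------
// A sketch of how the strided-batched GemmDesc constructor above is expected to be filled in.
// The TensorRef arguments are assumed to be provided by the caller; the problem size, scaling
// factors, and packed batch strides below are example values, not taken from this repository.

template <typename Desc>
Desc make_batched_gemm_desc_example(typename Desc::TensorRefA ref_A,
                                    typename Desc::TensorRefB ref_B,
                                    typename Desc::TensorRefC ref_C,
                                    typename Desc::TensorRefD ref_D) {
  // m=256, n=128, k=64, batch=8; GemmCoord is ordered (k, n, m, batch).
  cutlass::gemm::GemmCoord problem(64, 128, 256, 8);

  // Batch strides in elements, assuming densely packed matrices.
  long long batch_stride_A = 64LL * 256;   // k * m
  long long batch_stride_B = 64LL * 128;   // k * n
  long long batch_stride_C = 256LL * 128;  // m * n
  long long batch_stride_D = 256LL * 128;  // m * n

  return Desc(problem,
              typename Desc::SType(1),   // alpha
              ref_A, batch_stride_A,
              ref_B, batch_stride_B,
              typename Desc::SType(0),   // beta
              ref_C, batch_stride_C,
              ref_D, batch_stride_D);
}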
@@ -1,223 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
      with the computed matrix product.
*/
#pragma once

#include "cutlass/convert.h"
#include "cutlass/coord.h"
#include "cutlass/fragment.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmEpilogueTraits_>
struct GemmEpilogue {
  /// The traits class.
  typedef GemmEpilogueTraits_ Traits;
  /// The params.
  typedef typename Traits::Params Params;
  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// The output tile.
  typedef typename Traits::OutputTile OutputTile;
  /// The number of iterations.
  typedef typename Traits::Iterations Iterations;
  /// The accumulators.
  typedef typename Traits::Accumulators Accumulators;
  /// The scalar.
  typedef typename Traits::Scalar Scalar;
  /// The functor in charge of the math.
  typedef typename Traits::Functor Functor;

  /// We do not support 3D or 4D shapes.
  static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");

  /// The iterator for C in global memory.
  typedef typename Traits::GlobalLoadIteratorC GlobalLoadIteratorC;
  /// The transformer for C.
  typedef typename Traits::GlobalTransformerC GlobalTransformerC;
  /// The transformer for D.
  typedef typename Traits::GlobalTransformerD GlobalTransformerD;
  /// The iterator for D in global memory.
  typedef typename Traits::GlobalStoreIteratorD GlobalStoreIteratorD;
  /// The iterator to store D in shared memory.
  typedef typename Traits::SharedStoreIteratorD SharedStoreIteratorD;
  /// The shared store transformer for D.
  typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
  /// The iterator to load D in shared memory.
  typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;

  /// The index.
  typedef typename Traits::Index Index;

  /// The scalar for C.
  typedef typename GlobalLoadIteratorC::Scalar ScalarC;
  /// The scalar for D.
  typedef typename GlobalStoreIteratorD::Scalar ScalarD;

  /// Ctor.
  CUTLASS_DEVICE GemmEpilogue(Params const& params_,
                              SharedStorage& shared_storage_,
                              Coord<3> const& _problem_size)
      : params(params_), shared_storage(shared_storage_), problem_size(_problem_size), functor(params_.functor) {}

  /// Execute the epilogue.
  CUTLASS_DEVICE void epilogue(Accumulators& accumulators,
                               Coord<3> const& block = make_Coord(0, 0, 0),
                               int batch_id = 0) {
    if (functor.source_required()) {
      epilogue_with_or_without_beta<true>(accumulators, block, batch_id);
    } else {
      epilogue_with_or_without_beta<false>(accumulators, block, batch_id);
    }
  }

  template <bool kSourceRequired>
  CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators& accumulators,
                                                    Coord<3> const& block,
                                                    int batch_id) {
    // The C fragment.
    typename GlobalLoadIteratorC::Fragment fragment_c;
    // The transformed C fragment.
    typename GlobalTransformerC::OutputFragment transformed_c;

    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      // Compute pointer and predicate offsets for C and D global iterators.
      int const pointer_offset =
          ((params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
            params.iterator_d.inc_advance) *
               Iterations::kW +
           params.stride_h) *
          h;

      int const predicate_offset =
          ((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
            params.iterator_d.predicate_inc_advance) *
               Iterations::kW +
           Traits::Delta::kH) *
          h;

      // The iterator to load the elements of the C matrix.
      GlobalLoadIteratorC global_load_iterator(
          params.iterator_c, problem_size, block, pointer_offset, predicate_offset);

      // Update the C pointer offset based on batch_id and the batch stride.
      global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_C);

      // The transformer for C.
      GlobalTransformerC transformer_c;
      // The transformer for D.
      GlobalTransformerD transformer_d;

      // The iterator to store into the D matrix.
      GlobalStoreIteratorD global_store_iterator(
          params.iterator_d, problem_size, block, pointer_offset, predicate_offset);

      // Update the D pointer offset based on batch_id and the batch stride.
      global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_D);

      SharedStoreTransformerD shared_store_transformer;
      typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;

      SharedStoreIteratorD shared_store_iterator(
          params.shared_store_iterator_d,
          reinterpret_cast<typename SharedStoreIteratorD::Scalar*>(shared_storage.data()));

      SharedLoadStreamD shared_load_stream(
          params.shared_load_stream_d,
          reinterpret_cast<typename SharedLoadStreamD::Scalar*>(shared_storage.data()));

      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {
        // Load the C matrix into a fragment.
        if (kSourceRequired) {
          global_load_iterator.load_post_increment(fragment_c);
        }

        // Make sure we can write to shared memory.
        shared_load_fence();

        // Copy the accumulators to shared memory.
        int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;

        shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
        shared_store_iterator.store_post_increment(shared_store_transformed_d);

        // Make sure the data is in shared memory.
        shared_store_fence();

        // Copy the accumulators back to registers from shared memory.
        shared_load_stream.copy();
        shared_load_stream.commit();

        // Do the math.
        typename GlobalTransformerD::InputFragment fragment_d;
        if (kSourceRequired) {
          // Transform the C fragment.
          transformer_c.transform(fragment_c, transformed_c);
          // Do the math.
          functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d);
        } else {
          functor.evaluate(shared_load_stream.fragment(), fragment_d);
        }

        // Transform the D fragment.
        typename GlobalTransformerD::OutputFragment global_transformed_d;
        transformer_d.transform(fragment_d, global_transformed_d);

        // Copy the results to global memory.
        global_store_iterator.store_post_increment(global_transformed_d);
      }
    }
  }

  /// The memory fence for shared loads.
  CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }

  /// The memory fence for shared stores.
  CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }

  /// The params.
  Params const& params;
  /// The shared storage.
  SharedStorage& shared_storage;
  /// The dimensions of the GEMM.
  Coord<3> problem_size;
  /// The functor.
  Functor functor;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
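// --- Illustrative sketch (not part of the original file) ---------------------------------------
// The Functor used by GemmEpilogue above only needs to expose source_required() and the two
// evaluate() overloads called from epilogue_with_or_without_beta(). The element-wise math it is
// expected to perform is of the linear-scaling form D = alpha * accum (+ beta * C). The toy
// functor below is an assumption-based illustration of that contract, not the CUTLASS functor.

struct ToyLinearScalingFunctor {
  float alpha;
  float beta;

  // True when the C operand must be read (i.e. beta != 0).
  bool source_required() const { return beta != 0.0f; }

  // Source not required: D = alpha * accum.
  template <typename FragmentAccum, typename FragmentD>
  void evaluate(FragmentAccum const& accum, FragmentD& d) const {
    for (int i = 0; i < FragmentD::kElements; ++i) {
      d[i] = alpha * accum[i];
    }
  }

  // Source required: D = alpha * accum + beta * C.
  template <typename FragmentAccum, typename FragmentC, typename FragmentD>
  void evaluate(FragmentAccum const& accum, FragmentC const& c, FragmentD& d) const {
    for (int i = 0; i < FragmentD::kElements; ++i) {
      d[i] = alpha * accum[i] + beta * c[i];
    }
  }
};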
@@ -1,371 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties of the GEMM epilogue.
*/
#pragma once

#include "cutlass/convert.h"
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The output tile.
    typename OutputTile_,
    /// The accumulators.
    typename Accumulators_,
    /// The iterator to load C from global memory.
    typename GlobalLoadIteratorC_,
    /// The transformer for C.
    typename GlobalTransformerC_,
    /// The transformer for D.
    typename GlobalTransformerD_,
    /// The iterator to store D to global memory.
    typename GlobalStoreIteratorD_,
    /// The iterator to store D to shared memory.
    typename SharedStoreIteratorD_,
    /// The shared store transformer for D.
    typename SharedStoreTransformerD_,
    /// The stream to load D from shared memory.
    typename SharedLoadStreamD_,
    /// The number of iterations in the epilogue.
    typename Iterations_,
    /// The iteration strides.
    typename Delta_,
    /// The functor to be used in the epilogue.
    typename Functor_,
    /// The index.
    typename Index_ = int>
struct GemmEpilogueTraits {
  /// The output tile.
  typedef OutputTile_ OutputTile;
  /// The accumulators.
  typedef Accumulators_ Accumulators;
  /// The iterator for C in global memory.
  typedef GlobalLoadIteratorC_ GlobalLoadIteratorC;
  /// The transformer for C.
  typedef GlobalTransformerC_ GlobalTransformerC;
  /// The transformer for D.
  typedef GlobalTransformerD_ GlobalTransformerD;
  /// The iterator for D in global memory.
  typedef GlobalStoreIteratorD_ GlobalStoreIteratorD;
  /// The iterator to store D in shared memory.
  typedef SharedStoreIteratorD_ SharedStoreIteratorD;
  /// The shared store transformer for D.
  typedef SharedStoreTransformerD_ SharedStoreTransformerD;
  /// The stream to load D from shared memory.
  typedef SharedLoadStreamD_ SharedLoadStreamD;
  /// The number of iterations in the epilogue.
  typedef Iterations_ Iterations;
  /// The iteration strides.
  typedef Delta_ Delta;

  /// The functor in charge of the math.
  typedef Functor_ Functor;
  /// The index.
  typedef Index_ Index;
  /// The long index.
  typedef long long LongIndex;

  /// We do not support 3D or 4D shapes.
  static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");

  /// The scalar.
  typedef typename Functor::Scalar Scalar;
  /// The scalar for C.
  typedef typename GlobalLoadIteratorC::Scalar ScalarC;
  /// The scalar for D.
  typedef typename GlobalStoreIteratorD::Scalar ScalarD;

  /// The params.
  struct Params {
    /// The strides for H and W in the different iterations of the epilogue.
    Index stride_h, stride_w;
    /// The params for the C iterator.
    typename GlobalLoadIteratorC::Params iterator_c;

    /// Batch stride for the C matrix
    LongIndex batch_stride_C;

    /// The params for the D global iterator.
    typename GlobalStoreIteratorD::Params iterator_d;

    /// Batch stride for the D matrix
    LongIndex batch_stride_D;

    /// The params for the D shared store iterator.
    typename SharedStoreIteratorD::Params shared_store_iterator_d;
    /// The params for the D shared load stream.
    typename SharedLoadStreamD::Params shared_load_stream_d;
    /// The functor params.
    typename Functor::Params functor;

    /// Setup the params.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {

      // The parameters for the functor.
      int error_code = functor.initialize(desc);
      if (error_code) {
        return error_code;
      }

      // At the end of the H iteration, we jump over a number of columns.
      this->stride_h = desc.D.leading_dim() * Delta::kH;
      // Nothing to do here.
      this->stride_w = 0;

      // Setup the params for the global memory iterator for C.
      error_code = iterator_c.initialize(desc.C.data(),
                                         desc.C.leading_dim(),
                                         desc.C.leading_dim(),
                                         desc.problem_size[1],
                                         stride_w,
                                         Delta::kW);

      batch_stride_C = desc.batch_stride_C;

      if (error_code) {
        return error_code;
      }

      // Setup the params for the global memory iterator for D.
      error_code = iterator_d.initialize(desc.D.data(),
                                         desc.D.leading_dim(),
                                         desc.D.leading_dim(),
                                         desc.problem_size[1],
                                         stride_w,
                                         Delta::kW);

      batch_stride_D = desc.batch_stride_D;

      return error_code;
    }
  };

  /// The shared memory storage to exchange data.
  union StreamSharedStorage {
    // The storage for the store iterator.
    typename SharedStoreIteratorD::SharedStorage store;
    // The storage for the load stream.
    typename SharedLoadStreamD::SharedStorage load;
  };

  /// The shared memory to swizzle the data in the epilogue.
  struct SharedStorage {
    // The storage for the shared stream D.
    StreamSharedStorage shared_stream;

    CUTLASS_DEVICE
    ScalarD* data() { return reinterpret_cast<ScalarD*>(&shared_stream.load); }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
struct GemmEpilogueTraitsHelper {
  /// The scalar.
  typedef typename EpilogueFunctor_::Scalar Scalar;
  /// The output tile.
  typedef typename GemmConfig_::OutputTile OutputTile;

  /// The number of iterations in the epilogue.
  typedef Shape<1,
                GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH /
                    GemmConfig_::kAccumulatorsPerLdsB,
                GemmConfig_::kAccumulatorsPerLdsB>
      Iterations;
  // The iteration strides in the H/W dimension.
  typedef Shape<0,
                GemmConfig_::kAccumulatorsPerLdsB*(
                    GemmConfig_::Warps::kH* GemmConfig_::MultiplyAdd::ThreadsPerWarp::kH - 1),
                0>
      Delta;
  /// The functor to do the math in the epilogue.
  typedef EpilogueFunctor_ Functor;

  /// The traits class to build the iterator to store to shared memory for D.
  typedef GemmSharedStoreTileDTraits<
      // Functor::Scalar is the alpha/beta type. In mixed precision, alpha and beta may not have
      // the same type as the accumulators, so Functor::ScalarAccum is used here instead.
      typename Functor::ScalarAccum,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The number of scalars per STS.
      GemmConfig_::kScalarsPerStsD,
      // The skew -- 128 / sizeof(ScalarD) / kScalarsPerStsD is the number of threads involved in
      // a single STS. We divide by 2 as our objective is to add a skew to the odd threads to
      // avoid bank conflicts between odd and even threads.
      128 / sizeof(typename GemmConfig_::ScalarD) / GemmConfig_::kScalarsPerStsD / 2 *
          GemmConfig_::kScalarsPerStsD>
      SharedStoreTileTraits;
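  // Worked example of the skew expression above (values are assumptions for illustration, not
  // taken from any particular GemmConfig): with ScalarD = float (4 bytes) and kScalarsPerStsD = 4,
  // 128 / 4 / 4 = 8 threads take part in one STS, half of them (8 / 2 = 4) receive the skew, and
  // multiplying back by kScalarsPerStsD gives a skew of 4 * 4 = 16 scalars.
  static_assert(128 / sizeof(float) / 4 / 2 * 4 == 16, "worked example of the skew computation");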
  /// The iterator to store D to shared memory.
  typedef TileStoreIterator<SharedStoreTileTraits,
                            typename SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorD;

  /// The shared store transformer for D.
  typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;

  /// The traits class to build the iterator to load from shared memory for D.
  typedef GemmSharedLoadTileDTraits<
      // Functor::Scalar is the alpha/beta type. In mixed precision, alpha and beta may not have
      // the same type as the accumulators, so Functor::ScalarAccum is used here instead.
      typename Functor::ScalarAccum,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The number of columns of the output tile written by iteration.
      GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
      // The number of scalars per LDS.
      GemmConfig_::kScalarsPerLdsD,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;

  /// The iterator to load D from shared memory.
  typedef TileLoadIterator<SharedLoadTileTraits,
                           typename SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorD;
  /// The stream to load D.
  typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;

  /// The traits class to build the iterator to load data from global memory for C^N.
  typedef GemmGlobalTileCdTraits<
      // The pointer is float const.
      typename GemmConfig_::ScalarC const,
      // The tile has size (N / Iterations)xM in GEMM's terminology.
      Shape<1,
            GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
            GemmConfig_::OutputTile::kW>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // How many elements do we jump over at each iteration?
      Iterations::kW,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgC>
      GlobalLoadTileTraits;

  /// The iterator to load C.
  typedef GemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
  /// The transformer for C.
  typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;

  /// The traits class to build the iterator to store data to global memory for D^N.
  typedef GemmGlobalTileCdTraits<
      // The pointer is float.
      typename GemmConfig_::ScalarD,
      // The tile has size (N / Iterations)xM in GEMM's terminology.
      Shape<1,
            GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
            GemmConfig_::OutputTile::kW>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // How many elements do we jump over at each iteration?
      Iterations::kW,
      // The number of scalars per STG (STG.32 or STG.128, etc).
      GemmConfig_::kScalarsPerStgD>
      GlobalStoreTileTraits;

  /// The iterator to store D.
  typedef GemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
  /// The transformer for D.
  typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The GEMM config.
    typename GemmConfig_,
    /// The epilogue functor to do the math in the epilogue.
    typename EpilogueFunctor_,
    /// The index.
    typename Index_ = int,
    /// The helper to create the traits class.
    typename Helper_ = GemmEpilogueTraitsHelper<GemmConfig_, EpilogueFunctor_, Index_> >
struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits<
                                          // The output tile.
                                          typename GemmConfig_::OutputTile,
                                          // The accumulators.
                                          typename GemmConfig_::Accumulators,
                                          // The global iterator for C.
                                          typename Helper_::GlobalLoadIteratorC,
                                          // The transformer for C.
                                          typename Helper_::GlobalTransformerC,
                                          // The transformer for D.
                                          typename Helper_::GlobalTransformerD,
                                          // The global iterator for D.
                                          typename Helper_::GlobalStoreIteratorD,
                                          // The iterator to store D to shared memory.
                                          typename Helper_::SharedStoreIteratorD,
                                          // The shared store transformer for D.
                                          typename Helper_::SharedStoreTransformerD,
                                          // The stream to load D from shared memory.
                                          typename Helper_::SharedLoadStreamD,
                                          // The number of iterations.
                                          typename Helper_::Iterations,
                                          // The strides between iterations.
                                          typename Helper_::Delta,
                                          // The functor to be used in the epilogue.
                                          EpilogueFunctor_,
                                          // The index.
                                          Index_> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -1,255 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements efficient loading of the thread block-level tile from global memory and
      storing to shared memory.
*/
#pragma once

#include "cutlass/coord.h"
#include "cutlass/convert.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/tile_allocation.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Identifies the multiplicand
    GemmOperand::Kind Operand,
    /// The load iterator.
    typename LoadIterator_,
    /// The store iterator to copy to shared memory.
    typename StoreIterator_,
    /// The transformer to be applied after the data has been copied from global memory.
    typename Transformer_>
struct GlobalLoadStream {
  /// Indicates the type of GEMM operand
  static GemmOperand::Kind const kOperand = Operand;
  /// The load iterator.
  typedef LoadIterator_ LoadIterator;
  /// The transformer.
  typedef Transformer_ Transformer;
  /// The store iterator to write to shared memory.
  typedef StoreIterator_ StoreIterator;

  /// The fragment that is fetched from global memory.
  typedef typename LoadIterator::Fragment FetchedFragment;
  /// The fragment that is obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;
  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "");
  /// The output fragment.
  typedef TransformedFragment Fragment;
  /// Make sure the transformed fragment is the same as the store fragment.
  static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
                "");

  /// The layout.
  static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
  /// The scalar type of the iterator.
  typedef typename LoadIterator::Scalar Scalar;
  /// The pointer.
  typedef typename LoadIterator::Pointer Pointer;
  /// The index.
  typedef typename LoadIterator::Index Index;
  /// The long index.
  typedef typename LoadIterator::LongIndex LongIndex;
  /// The tile.
  typedef typename LoadIterator::Tile Tile;

  /// Shared memory allocation for the tile
  typedef TileAllocation<typename StoreIterator::Scalar, typename StoreIterator::Tile>
      ThreadblockTileStorage;

  /// Tensor reference to the threadblock tile
  typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;

  /// The params.
  struct Params {
    // The load iterator.
    typename LoadIterator::Params load_iterator;

    /// Batch stride in global memory
    LongIndex batch_stride;

    // The store iterator.
    typename StoreIterator::Params store_iterator;

    // Offset to the residue.
    Index offset_to_residue;

    // Offset to the residue for the last partition.
    Index offset_to_residue_last_partition;

    /// Setup the params.
    CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
                                       LongIndex batch_stride_,
                                       Index ldm,
                                       Index offset_to_residue_,
                                       Index offset_to_residue_last_partition_) {

      int error_code = load_iterator.initialize(pointer, ldm, ldm);
      if (error_code) {
        return error_code;
      }

      batch_stride = batch_stride_;
      offset_to_residue = offset_to_residue_;
      offset_to_residue_last_partition = offset_to_residue_last_partition_;

      return store_iterator.initialize();
    }

    CUTLASS_DEVICE Index get_offset_to_residue() {
      if (blockIdx.z == gridDim.z - 1) {  // last partition
        return offset_to_residue_last_partition;
      }
      else {
        return offset_to_residue;
      }
    }
  };

  /// Contains private storage in shared memory needed by the objects within this class. Note,
  /// this is *NOT* the shared memory allocation for the GEMM threadblock tile. That necessarily
  /// exists outside this class, as it is also needed by the warp-level shared=>RF stream.
  struct SharedStorage {};

  //
  // Static member functions
  //

  /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
  CUTLASS_HOST_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
    bool const kKstrided =
        GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;

    Coord<3> tile_coord = ProjectOperand<kOperand, kKstrided>::project(coord);

    return make_Coord(
        tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
  }

  /// Ctor.
  CUTLASS_DEVICE GlobalLoadStream(
      Params const& _params,
      SharedStorage& shared_storage,
      ThreadblockTileRef const& threadblock_tile_ref,
      Coord<3> const bounds,
      Coord<3> const& _threadblock_offset)
      : params(_params),
        threadblock_offset(project_coordinate(_threadblock_offset)),
        multiplicand_bounds(project_coordinate(bounds, 1)),
        load_iterator(params.load_iterator, threadblock_offset),
        transformer(),
        store_iterator(params.store_iterator, threadblock_tile_ref.data()) {
    load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
    fetched_fragment.clear();
  }

  /// Load the data from global memory into the fetched fragment.
  CUTLASS_DEVICE void copy() {
    load_iterator.load_post_increment(fetched_fragment);
  }

  /// Commit the data.
  CUTLASS_DEVICE void commit() {
    transformer.transform(fetched_fragment, transformed_fragment);
    store_iterator.store_post_increment(transformed_fragment);
    store_iterator.inc_stage();
  }

  /// Execute the residue code.
  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
    load_iterator.residue(k);
    if (!skip_clear) {
      fetched_fragment.clear();
    }
  }

  /// Move to the residue portion.
  CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
    Index kResidue = k % kTileK;
    if (kResidue) {
      residue(kResidue);
      Index this_offset_residue = params.get_offset_to_residue();
      load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
    }
  }

  /// Rollback to the beginning of the first tile
  CUTLASS_DEVICE void rollback(void) {
    load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);

    int const kBlock = kOperand == GemmOperand::kA
                           ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
                           : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);

    Index this_offset_residue = params.get_offset_to_residue();
    load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
                                     load_iterator.stride_advance());
  }

  /// Adds a Coord<3> to the underlying global load iterator
  CUTLASS_DEVICE GlobalLoadStream &operator+=(Coord<3> const &offset) {
    load_iterator += offset;
    return *this;
  }

  /// Adds an offset based on the batch stride
  CUTLASS_DEVICE GlobalLoadStream &add_batch_offset(int batch_id) {
    load_iterator.add_pointer_offset(batch_id * params.batch_stride);
    return *this;
  }

  //
  // Data members
  //

  /// Parameters
  Params params;
  /// Threadblock offset
  Coord<3> threadblock_offset;
  /// Multiplicand bounds
  Coord<3> multiplicand_bounds;
  /// The load iterator.
  LoadIterator load_iterator;
  /// The fragment fetched from global memory.
  FetchedFragment fetched_fragment;
  /// The transformer.
  Transformer transformer;
  /// The fragment holding the data after it has been transformed.
  TransformedFragment transformed_fragment;
  /// The store iterator.
  StoreIterator store_iterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
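// --- Illustrative call pattern (not part of the original file) ---------------------------------
// A minimal sketch of how a mainloop is expected to drive the stream defined above: each stage
// pairs one copy() with one commit(), batched GEMMs first point the stream at their batch, and
// move_to_residue()/rollback() bracket the handling of a partial leading K tile. The helper name
// and the surrounding control flow are assumptions for exposition only.

template <typename Stream>
CUTLASS_DEVICE void prefetch_one_stage_example(Stream& stream, int batch_id) {
  stream.add_batch_offset(batch_id);  // offset the load iterator by batch_id * batch_stride
  stream.copy();                      // issue the global loads into the fetched fragment
  stream.commit();                    // transform the fragment and store it to the shared tile
}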
@@ -1,614 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators for efficiently loading and storing to global memory.
*/
#pragma once

#include "cutlass/coord.h"
#include "cutlass/util/platform.h"

#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

// The following functor reshapes a tile of threads to match a tile of data. The idea is that when
// the user builds the iterator traits, they may want to specify the tile independently from the
// number of scalars loaded/stored per instruction. For example, in the row-major version with a
// tile of size 128x8, the user may want the iterator to work with 32x8 threads if each thread
// loads 1 scalar per LDG. If the user switches to 4 scalars per LDG, the tile of threads has to
// change. The code below detects that case and corrects the thread arrangement automatically - it
// is a helper for when the user does not specify the right configuration.

template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
struct ReshapeThreads {
  typedef Threads_ Threads;
};

template <typename Tile_, typename Threads_>
struct ReshapeThreads<Tile_, Threads_, true> {
  typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
};
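// Worked example of ReshapeThreads (shapes are assumptions chosen to mirror the comment above,
// not taken from a shipped configuration): a row-major 128x8 tile read with 4 scalars per LDG is
// reshaped by ReshapeTile into a 128x2 tile of 4-wide accesses. Its width (2) is smaller than the
// requested 32x8 thread arrangement, so ReshapeThreads folds the surplus threads into H:
// Shape<1, 32 * 8 / 2, 2, 1> == Shape<1, 128, 2, 1>.
static_assert(platform::is_same<ReshapeThreads<Shape<1, 128, 2, 4>, Shape<1, 32, 8> >::Threads,
                                Shape<1, 128, 2, 1> >::value,
              "worked example of ReshapeThreads");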
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct GemmGlobalTileTraits {
|
||||
/// Identity of the operand
|
||||
static GemmOperand::Kind const kOperand = kOperand_;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kAccessSize_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;
|
||||
/// The tile shape
|
||||
typedef Tile_ Tile;
|
||||
/// The vectorized tile shape
|
||||
typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile VectorizedTile;
|
||||
/// The threads shape
|
||||
typedef typename ReshapeThreads<VectorizedTile, Threads_>::Threads Threads;
|
||||
/// The relative offset between two elements in the H/W dimension in adjacent threads.
|
||||
typedef Shape<1, 1, VectorizedTile::kC> ThreadsDelta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;
|
||||
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1,
|
||||
VectorizedTile::kH / Threads::kH,
|
||||
VectorizedTile::kW / Threads::kW,
|
||||
VectorizedTile::kC / kAccessSize>
|
||||
Iterations;
|
||||
|
||||
typedef GemmMultiplicandTraits<Tile, kOperand, kLayout> MultiplicandTraits;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
|
||||
struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_>
|
||||
Base;
|
||||
|
||||
/// The stride in the H dimension.
|
||||
static int const kStrideH = kStrideH_;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
|
||||
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
typedef typename Base::Threads Threads;
|
||||
|
||||
typedef typename Base::ThreadsDelta ThreadsDelta;
|
||||
|
||||
typedef typename Base::ImmediateOffsetStrides ImmediateOffsetStrides;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct GemmGlobalIteratorAb
|
||||
: public TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
: IteratorAdvance::kW,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_; /// The base class.
|
||||
typedef TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
: IteratorAdvance::kW,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
/// The tile
|
||||
typedef typename TileTraits_::Tile Tile;
|
||||
/// Fragment type loaded by the iterator
|
||||
typedef typename Base::Fragment Fragment;
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// Long index
|
||||
typedef long long LongIndex;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
typedef cutlass::PredicateVector<ShapeCount<typename Base::Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Iterator parameters type
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
struct Params : public BaseParams {
|
||||
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
|
||||
CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr,
|
||||
Index stride_d,
|
||||
Index stride_h) {
|
||||
return BaseParams::initialize(ptr, stride_d, stride_h, kAdvance == IteratorAdvance::kH ? 0 : 1);
|
||||
}
|
||||
};
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// The predicates.
|
||||
PredicateVector predicates;
|
||||
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) {
|
||||
// Setup the masks to control loads.
|
||||
predicates.fill(0);
|
||||
|
||||
// Fill in the bits of the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2];
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
flag =
|
||||
flag &&
|
||||
(h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] <
|
||||
bounds[1];
|
||||
} else {
|
||||
flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1];
|
||||
}
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
predicates.set(bit, flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& threadblock_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Setup the pointer.
|
||||
params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h +
|
||||
(threadblock_offset[2] + thread_offset[2]));
|
||||
|
||||
}
|
||||
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_w() { Base::inc_w(); }
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Loads a single fragment element from memory
|
||||
CUTLASS_HOST_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW,
|
||||
Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// That's the residue! Update the predicates.
|
||||
CUTLASS_HOST_DEVICE void residue(Index k) {
|
||||
// Update the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
Index offset = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
offset += thread_offset[1] + h * Base::Delta::kH + d * Base::Delta::kD;
|
||||
} else {
|
||||
offset += thread_offset[2] + w * Base::Delta::kW;
|
||||
}
|
||||
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
if (offset >= k) {
|
||||
predicates.set(bit, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the (d, h, w, c) access is within bounds.
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
  return predicates[bit];
}
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord<3> const &offset) {
|
||||
|
||||
LongIndex _offset = offset.template dot<LongIndex>(
|
||||
make_Coord(params.stride_d, params.stride_h, params.stride_w)
|
||||
);
|
||||
|
||||
params.pointer += _offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
|
||||
CUTLASS_HOST_DEVICE Index stride_advance(void) {
|
||||
Index stride = params.stride_h;
|
||||
if (kAdvance == IteratorAdvance::kW) {
|
||||
stride = params.stride_w;
|
||||
}
|
||||
return stride;
|
||||
}
|
||||
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
|
||||
typename Base::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
if (valid(d, h, w, c)) {
|
||||
load_element(
|
||||
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < Base::Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Base::Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Base::Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The base class.
|
||||
typedef TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename TileTraits_::Pointer Pointer;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The long index.
typedef long long LongIndex;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the D dimension
|
||||
long long stride_d;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
Index inc_advance, inc_h;
|
||||
/// The strides to increment the predicate offset
|
||||
Index predicate_inc_advance, predicate_inc_h;
|
||||
/// The column offset to compute the predicate for the columns.
|
||||
Index predicate_offset;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
int stride_d_,
|
||||
Index ldm,
|
||||
Index bound,
|
||||
Index epilogue_stride_w,
|
||||
Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Stride per batch
|
||||
stride_d = stride_d_;
|
||||
// Each column of the matrix.
|
||||
stride_h = TileTraits_::ThreadsDelta::kH * ldm;
|
||||
// Each thread outputs one column per iteration. The stride between columns is given by
// the number of scalars that are loaded per LDS for B.
inc_h = ldm * TileTraits_::kStrideH;
|
||||
inc_advance =
|
||||
(ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
|
||||
|
||||
predicate_offset = bound;
|
||||
predicate_inc_h = TileTraits_::kStrideH;
|
||||
predicate_inc_advance =
|
||||
-((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long _stride_d, Index _stride_h,
|
||||
Index _inc_advance, Index _inc_h, Index _predicate_inc_advance, Index _predicate_inc_h,
|
||||
Index _predicate_offset) {
|
||||
this->pointer = pointer;
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
inc_advance = _inc_advance;
|
||||
inc_h = _inc_h;
|
||||
predicate_inc_advance = _predicate_inc_advance;
|
||||
predicate_inc_h = _predicate_inc_h;
|
||||
predicate_offset = _predicate_offset;
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters.
|
||||
Params params;
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
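// Note: valid() combines the per-column predicate (indexed by w, computed in the
// constructor against bounds[2]) with params.predicate_offset, which is decremented as
// the iterator advances in H and guards the remaining extent in that dimension.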
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int offset = 0,
|
||||
int pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
params.pointer += ((h * params.stride_h + w) + offset);
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_HOST_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord<3> const &offset) {
|
||||
LongIndex _offset = offset.template dot<LongIndex>(
|
||||
make_Coord(params.stride_d, params.stride_h, 1)
|
||||
);
|
||||
params.pointer += _offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Loads a single fragment element from memory.
|
||||
CUTLASS_HOST_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW,
|
||||
Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Stores a single fragment element into memory.
|
||||
CUTLASS_HOST_DEVICE void store_element(
|
||||
typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW,
|
||||
Base::kAccessSize * sizeof(Scalar)>::store(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Tests the validity of the (d, h, w, c) access.
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
  return predicates.at(w) && params.predicate_offset > 0;
}
|
||||
|
||||
/// add pointer offset
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }
|
||||
|
||||
/// Loads and increments iterator
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
|
||||
typename Base::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
if (valid(d, h, w, c)) {
|
||||
load_element(
|
||||
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < Base::Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Base::Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Base::Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment& fragment) {
|
||||
typename Base::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
if (valid(d, h, w, c)) {
|
||||
store_element(
|
||||
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < Base::Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Base::Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Base::Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,274 +0,0 @@
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Traits_>
|
||||
struct GemmMainloop {
|
||||
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// The traits.
|
||||
typedef Traits_ Traits;
|
||||
|
||||
/// The GEMM mainloop
|
||||
typedef typename Traits::KernelClass KernelClass;
|
||||
|
||||
/// The shared storage.
|
||||
typedef typename Traits::SharedStorage SharedStorage;
|
||||
|
||||
/// The scalar for A.
|
||||
typedef typename Traits::ScalarA ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef typename Traits::ScalarB ScalarB;
|
||||
/// The scalar in the epilogue.
|
||||
typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
|
||||
/// The scalar for C.
|
||||
typedef typename Traits::Epilogue::ScalarC ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename Traits::Epilogue::ScalarD ScalarD;
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// Define the mainloop iteration size
|
||||
typedef typename Traits::MultiplyAdd MultiplyAdd;
|
||||
|
||||
/// The number of threads.
|
||||
static int const kThreads = Traits::GemmConfig::kThreads;
|
||||
|
||||
// Number of warp-level multiply-accumulate steps executed by each warp.
|
||||
static Index const kWarpGemmSteps =
|
||||
Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
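// Worked example with hypothetical values (not from the original source): if
// AccumulatorsPerWarp::kD == 32 and InstructionShape::kD == 4, each warp executes
// kWarpGemmSteps == 32 / 4 == 8 multiply-accumulate steps per mainloop iteration.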
|
||||
|
||||
/*
|
||||
// Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
|
||||
static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
|
||||
*/
|
||||
|
||||
/// Use the params object defined in traits
|
||||
typedef typename Traits::Params Params;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
|
||||
/// SharedStorage object
|
||||
SharedStorage& shared_storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
|
||||
: params(params_), shared_storage(shared_storage_) {}
|
||||
|
||||
/// Fetches global stream pair
|
||||
template <bool Residue>
|
||||
CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
|
||||
Index outer_k) {
|
||||
// If residue portion and not calculating residue in prolog, update residue predicates now.
|
||||
if (Residue) {
|
||||
global_to_shared_stream.residue(outer_k);
|
||||
}
|
||||
global_to_shared_stream.copy();
|
||||
}
|
||||
|
||||
/// Computes a warp-level GEMM on data held in shared memory
|
||||
template <bool Residue, bool LastIteration>
|
||||
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
|
||||
typename Traits::SharedStream& shared_load_stream,
|
||||
typename MultiplyAdd::Accumulators& accumulators,
|
||||
Index outer_k) {
|
||||
|
||||
// Whether to load global stream before loading shared stream
|
||||
const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);
|
||||
|
||||
// Load data for the next iteration of the main loop (unless it's the last iteration).
|
||||
if (kGlobalStreamFirst && !LastIteration) {
|
||||
fetch_global<Residue>(global_to_shared_stream, outer_k);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int step = 0; step < kWarpGemmSteps; ++step) {
|
||||
|
||||
// Trigger the copy from shared memory for the next A/B values.
|
||||
shared_load_stream.copy((step + 1) % kWarpGemmSteps);
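// Note: fragments are double-buffered by step % 2, so the shared-memory loads issued here
// for step + 1 overlap the multiply-add performed below for the current step.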
|
||||
|
||||
// Load data for the next iteration of the main loop (unless it's the last iteration).
|
||||
if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
|
||||
fetch_global<Residue>(global_to_shared_stream, outer_k);
|
||||
}
|
||||
|
||||
if (step == kWarpGemmSteps - 2) {
|
||||
// Make sure the data from shared memory has been entirely consumed.
|
||||
Traits::shared_load_fence(true);
|
||||
|
||||
global_to_shared_stream.commit();
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(true);
|
||||
|
||||
// Move to the next stage for the load (if it makes sense).
|
||||
shared_load_stream.inc_stage();
|
||||
}
|
||||
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(step);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
|
||||
shared_load_stream.fragment_b(step),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
}
|
||||
|
||||
/// Do the GEMM.
|
||||
CUTLASS_DEVICE void multiply_add() {
|
||||
// Swizzle the IDs of the block (to enable better cache behavior).
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
Coord<3> threadblock_offset =
|
||||
block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());
|
||||
|
||||
// We may want to use shared memory to clear the registers.
|
||||
typedef typename Traits::ClearAccumulators ClearAccumulators;
|
||||
|
||||
// Get the bounds for this threadblock; they may differ from problem_size.
|
||||
Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
|
||||
params.partitionK_range);
|
||||
|
||||
// The streams to read A/B from global memory to shared memory.
|
||||
typename Traits::GlobalLoadStream global_to_shared_stream(
|
||||
params.global_to_shared_stream,
|
||||
shared_storage.main_loop.global_to_shared_stream,
|
||||
shared_storage.main_loop.threadblock_tile.reference(),
|
||||
bounds,
|
||||
threadblock_offset);
|
||||
|
||||
// update A and B pointer offset based on batch_id and batch_stride_offset
|
||||
global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
|
||||
|
||||
// Create the accumulator clear.
|
||||
ClearAccumulators clear;
|
||||
|
||||
// Deal with residue in prolog.
|
||||
// global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
|
||||
global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
|
||||
|
||||
// Fetch the fragments for A and B from global memory.
|
||||
global_to_shared_stream.copy();
|
||||
|
||||
// Copy the elements to shared memory (after transformation if needed).
|
||||
global_to_shared_stream.commit();
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(false);
|
||||
|
||||
// Rollback to the beginning of the first tile (if residue exists).
|
||||
// global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
|
||||
global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
|
||||
|
||||
// The stream of data from shared memory to fragments.
|
||||
typename Traits::SharedStream shared_load_stream(
|
||||
params.shared_stream,
|
||||
shared_storage.main_loop.threadblock_tile.reference());
|
||||
|
||||
// Trigger the copy from shared memory for the 1st stream.
|
||||
shared_load_stream.copy(0);
|
||||
|
||||
// Allocate the accumulators.
|
||||
typename MultiplyAdd::Accumulators accumulators;
|
||||
|
||||
// Clear the accumulators.
|
||||
clear.clear(accumulators);
|
||||
|
||||
// Initial index
|
||||
// Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
|
||||
// problem_size[0] might be bigger than bounds[0]
|
||||
Index outer_k = bounds[0] - Traits::OutputTile::kD;
|
||||
// Check if we are computing residue in prolog or not.
|
||||
if (Traits::GemmConfig::kResidueInProlog) {
|
||||
// Execute all mainloop iterations but the last one.
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
|
||||
CUTLASS_GEMM_LOOP_HEADER
|
||||
consume_tile<false, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
consume_tile<false, true>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
} else {
|
||||
// When kResidueSeparate = true, execute all mainloop iterations but the last two without any
|
||||
// consideration for K-residue or predicate updates. This improves the steady state of some
|
||||
// kernels.
|
||||
if (Traits::GemmConfig::kResidueSeparate) {
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
CUTLASS_GEMM_LOOP_HEADER
|
||||
consume_tile<false, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
}
|
||||
|
||||
// Execute remaining tiles with K-residue predicate updates enabled.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
CUTLASS_GEMM_LOOP_HEADER
|
||||
consume_tile<true, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
}
|
||||
|
||||
typedef typename Traits::Epilogue Epilogue;
|
||||
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
|
||||
epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,141 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear
|
||||
memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to describe attributes of GEMM matrix operands
|
||||
template <GemmOperand::Kind kOperand_, MatrixLayout::Kind kLayout_>
|
||||
struct GemmOperandTraitsAb {
|
||||
static const bool Congruous =
|
||||
(kOperand_ == GemmOperand::kA ^ kLayout_ == MatrixLayout::kRowMajor);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_, typename Tile_>
|
||||
struct GetExtent;
|
||||
|
||||
template <typename Tile_>
|
||||
struct GetExtent<GemmOperand::kA, Tile_> {
|
||||
static const int kExtent = Tile_::kW;
|
||||
};
|
||||
|
||||
template <typename Tile_>
|
||||
struct GetExtent<GemmOperand::kB, Tile_> {
|
||||
static const int kExtent = Tile_::kH;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Determines the shape of a multiplicand tile in terms of strided (H) and contiguous (W)
|
||||
/// dimensions
|
||||
template <typename ThreadBlockTile_, GemmOperand::Kind Usage, MatrixLayout::Kind Layout>
|
||||
struct GemmMultiplicandTraits {
|
||||
// Only defined for A or B
|
||||
static_assert(Usage == GemmOperand::kA || Usage == GemmOperand::kB,
|
||||
"MultiplicandTileShape defined only for A or B operands.");
|
||||
|
||||
/// Shape of GEMM thread block tile (K, N, M)
|
||||
typedef ThreadBlockTile_ ThreadBlockTile;
|
||||
|
||||
/// Identifies multiplicand
|
||||
static GemmOperand::Kind const kUsage = Usage;
|
||||
|
||||
/// Layout of tile
|
||||
static MatrixLayout::Kind const kLayout = Layout;
|
||||
|
||||
// True if K is the strided dimension
|
||||
static bool const kKstrided = (kUsage == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor);
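// Truth table for the expression above:
//   (A, column-major) -> true    (A, row-major) -> false
//   (B, column-major) -> false   (B, row-major) -> true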
|
||||
|
||||
/// Map the ThreadBlockShape onto (kH, kW) dimensions for A and B operand
|
||||
typedef typename platform::conditional<
|
||||
kKstrided,
|
||||
Shape<1, ThreadBlockTile::kD, GetExtent<Usage, ThreadBlockTile>::kExtent>,
|
||||
Shape<1, GetExtent<Usage, ThreadBlockTile>::kExtent, ThreadBlockTile::kD> >::type Shape;
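// That is, when K is strided the tile maps K onto the H (strided) dimension and the
// operand extent (M or N) onto the W (contiguous) dimension; otherwise the two are swapped.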
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Projects a coordinate (K, N, M) onto inner and outer dimensions defined for each
|
||||
/// operand.
|
||||
template <GemmOperand::Kind operand, bool Kstrided = true>
|
||||
struct ProjectOperand;
|
||||
|
||||
/// Project A operand - (0, K, M)
|
||||
template <bool Kstrided>
|
||||
struct ProjectOperand<GemmOperand::kA, Kstrided> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) {
|
||||
if (Kstrided) {
|
||||
return make_Coord(0, coord[0], coord[2]);
|
||||
} else {
|
||||
return make_Coord(0, coord[2], coord[0]);
|
||||
}
|
||||
}
|
||||
};
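// Example: for operand A with Kstrided == true, a coordinate (K, N, M) projects to
// (0, K, M); with Kstrided == false it projects to (0, M, K).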
|
||||
|
||||
/// Project B operand - (0, K, N)
|
||||
template <bool Kstrided>
|
||||
struct ProjectOperand<GemmOperand::kB, Kstrided> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) {
|
||||
if (Kstrided) {
|
||||
return make_Coord(0, coord[0], coord[1]);
|
||||
} else {
|
||||
return make_Coord(0, coord[1], coord[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Project C operand - (0, N, M)
|
||||
template <>
|
||||
struct ProjectOperand<GemmOperand::kC, true> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
|
||||
};
|
||||
|
||||
/// Project D operand - (0, N, M)
|
||||
template <>
|
||||
struct ProjectOperand<GemmOperand::kD, true> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,142 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for managing loading and storing fragments to shared memory in the
|
||||
efficient GEMM pipeline.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename Iterator_,
|
||||
/// The transformer to be applied after the data has been copied from shared memory.
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
|
||||
struct SharedLoadStream {
|
||||
/// The load iterator.
|
||||
typedef Iterator_ Iterator;
|
||||
/// The transformer.
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename Iterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
/// Scalar data type
|
||||
typedef typename Iterator::Scalar Scalar;
|
||||
|
||||
/// Reference type to a tensor
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The iterator params.
|
||||
typename Iterator::Params iterator;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize() { return iterator.initialize(); }
|
||||
};
|
||||
|
||||
/// The storage in shared memory needed by that stream.
|
||||
typedef typename Iterator::Storage SharedStorage;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const &params, TensorRef const &ref) {
|
||||
this->initialize(params, ref);
|
||||
}
|
||||
|
||||
/// Initialize the stream.
|
||||
CUTLASS_DEVICE void initialize(Params const &params, TensorRef const &ref) {
|
||||
// The iterator.
|
||||
iterator = Iterator(params.iterator, ref.data());
|
||||
// The transformer.
|
||||
transformer = Transformer();
|
||||
}
|
||||
|
||||
/// Clears the fragment
|
||||
CUTLASS_DEVICE void clear() {
|
||||
fetched[0].clear();
|
||||
fetched[1].clear();
|
||||
transformed[0].clear();
|
||||
transformed[1].clear();
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
iterator.load_post_increment(fetched[0]);
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int step) { iterator.load(fetched[step % 2], step); }
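// Note: fetched[] and transformed[] are double-buffered and indexed by step % 2, so the
// load for one step can overlap the transform and consumption of the previous step.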
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() { transformer.transform(fetched[0], transformed[0]); }
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
transformer.transform(fetched[step % 2], transformed[step % 2]);
|
||||
}
|
||||
|
||||
/// Returns the fragment for the given step
|
||||
CUTLASS_DEVICE TransformedFragment &fragment(int step = 0) { return transformed[step % 2]; }
|
||||
|
||||
/// Returns the fragment for the given step
|
||||
CUTLASS_DEVICE TransformedFragment const &fragment(int step = 0) const {
|
||||
return transformed[step % 2];
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() { iterator.inc_stage(); }
|
||||
|
||||
/// The iterator.
|
||||
Iterator iterator;
|
||||
/// Fetched fragment
|
||||
FetchedFragment fetched[2];
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
/// Transformed fragment
|
||||
TransformedFragment transformed[2];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,417 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators for efficiently loading and storing tiles to and from shared memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_>
|
||||
struct GemmSharedStoreTileAbTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef Threads_ Threads;
|
||||
/// The strides to compute the base position of the thread.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Tile::kC, kScalarsPerSts_> ThreadsStrides;
|
||||
/// The skew.
|
||||
static int const kSkew = 0;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1,
|
||||
Tile::kH / Threads::kH,
|
||||
Tile::kW / Threads::kW,
|
||||
Tile::kC / Threads::kC / kAccessSize>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize>
|
||||
ImmediateOffsetStrides;
|
||||
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_, int kSkew_>
|
||||
struct GemmSharedStoreWithSkewTileAbTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile without skews.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile TileWithoutSkew;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Shape<Tile_::kD, Tile_::kH, Tile_::kW + kSkew_>,
|
||||
kScalarsPerSts_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef Threads_ Threads;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The number of scalars per STS.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides;
|
||||
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
/// The strides to compute the base position of the thread.
|
||||
typedef Shape<0, kScalarsPerSts_, ShapeCount<Tile>::kHwc / Threads::kW> ThreadsStrides;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename InstructionShape_,
|
||||
int kStages_,
|
||||
int kScalarsPerLds_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedLoadTileATraits {
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile without skew.
|
||||
typedef Shape<kStages_,
|
||||
OutputTile_::kD / InstructionShape_::kD,
|
||||
GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
|
||||
TileWithoutSkew_;
|
||||
/// The tile with skew.
|
||||
typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
|
||||
/// The tile without skew after reshaping.
|
||||
typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in a warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
// static int const kScalarsPerLds = kScalarsPerLds_;
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of warps.
|
||||
static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
|
||||
/// The number of threads in one dimension of the warp.
|
||||
static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
ImmediateOffsetStrides;
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Extract the warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// Extract the slice.
|
||||
int const slice = warp / (Warps::kH * Warps::kW);
|
||||
// Compute the row offset for each warp.
|
||||
int const warp_row = warp % Warps::kW;
|
||||
// Compute the row offset for each thread.
|
||||
int const lane_row = (threadIdx.x & 0x0e) / 2;
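// lane_row maps lane pairs {0,1}->0, {2,3}->1, ..., {14,15}->7; the pattern repeats for
// lanes 16-31.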
|
||||
// The offset.
|
||||
int const offset =
|
||||
slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
|
||||
// Embed the offset in a 4D coordinate vector.
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename InstructionShape_,
|
||||
int kStages_,
|
||||
int kScalarsPerLds_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedLoadTileBTraits {
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kB;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile without skew.
|
||||
typedef Shape<kStages_,
|
||||
OutputTile_::kD / InstructionShape_::kD,
|
||||
GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
|
||||
TileWithoutSkew_;
|
||||
/// The tile with skew.
|
||||
typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
|
||||
/// The tile without skew after reshaping.
|
||||
typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in a warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of warps.
|
||||
static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
|
||||
/// The number of threads in one dimension of the warp.
|
||||
static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
ImmediateOffsetStrides;
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Extract the warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// Extract the slice.
|
||||
int const slice = warp / (Warps::kH * Warps::kW);
|
||||
// The warp in the slice.
|
||||
int const warp_in_slice = warp % (Warps::kH * Warps::kW);
|
||||
// Compute the row offset for each warp.
|
||||
int const warp_col = warp_in_slice / Warps::kW;
|
||||
// Compute the row offset for each thread.
|
||||
int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
|
||||
// The offset.
|
||||
int const offset =
|
||||
slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
|
||||
// Embed the offset in a 4D coordinate.
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
int kScalarsPerSts_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedStoreTileDTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The dimension of the output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The warps in the tile.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in the warps.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of scalars per thread.
|
||||
static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
|
||||
/// The number of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
|
||||
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
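// Worked example with a hypothetical configuration (not from the original source): with
// kThreads == 128, kScalarsPerThread == 4 and kSkew == 0, kScalarsPerRow == 64 * 4 == 256.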
|
||||
|
||||
/// The tile.
|
||||
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<1, 1, kScalarsPerThread / kAccessSize> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> ImmediateOffsetStrides;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// The warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
|
||||
// The position of the warp in the 2D tile.
|
||||
int const warp_row = warp % Warps::kW;
|
||||
int const warp_col = warp / Warps::kW;
|
||||
|
||||
// We assume that the elements are distributed in a warp as 4 columns of 8 elements. The
// columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, ..., 15],
// col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
|
||||
int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
|
||||
int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
|
||||
|
||||
// Odd threads go to the second half of shared memory.
|
||||
int const row = threadIdx.x & 0x01;
|
||||
int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
|
||||
lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
|
||||
// Embed the offset in a 4D coordinate.
|
||||
return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
int kTileH_,
|
||||
int kScalarsPerLds_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedLoadTileDTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The dimension of the output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The warps in the tile.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in the warps.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of scalars per thread.
|
||||
static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
|
||||
/// The number of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
|
||||
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
|
||||
|
||||
/// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank
|
||||
/// conflicts in the epilogue.
|
||||
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
|
||||
|
||||
// Compute the number of iterations per warp in the Tile::kH dimension.
|
||||
static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;
|
||||
|
||||
// As explained above, the shared memory tile is composed of 2 rows and each row is made of
|
||||
// kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
|
||||
// back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
|
||||
// However, in some cases, we have only 1 iteration per warp. In that case, we must define the
|
||||
// shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
|
||||
// to keep the number of elements to reduce for split-K.
|
||||
static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
|
||||
// As soon as we know kIterationsH, it is trivial to compute kIterationsD:
|
||||
static int const kIterationsD = kIterationsInHPerWarp / kIterationsH;
|
||||
|
||||
// If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
|
||||
// threads to grab the other elements.
|
||||
static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
|
||||
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK>
|
||||
ImmediateOffsetStrides;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Each warp works on a different column.
|
||||
int const h = threadIdx.x / kWarpSize;
|
||||
// Compute the row.
|
||||
int const w = (threadIdx.x & (kWarpSize - 1)) * kAccessSize;
|
||||
int offset = 0;
|
||||
if (Iterations::kH == 1) {
|
||||
int const row = h & 0x1;
|
||||
int const col = h / 2;
|
||||
offset = row * ShapeCount<Tile>::kWc + col * OutputTile::kW * Iterations::kD + w;
|
||||
} else {
|
||||
offset = h * OutputTile::kW * Iterations::kD + w;
|
||||
}
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,270 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a pair of GEMM tile streams
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_allocation.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/clear_accumulators.h"
|
||||
#include "cutlass/gemm/gemm_config.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/threadblock_swizzle.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Collect the global load streams for multiplicands.
|
||||
template <typename StreamA_, typename StreamB_, bool kResidueInProlog_>
|
||||
struct GlobalLoadStreamPair {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Stream for A multiplicand
|
||||
typedef StreamA_ StreamA;
|
||||
|
||||
/// Stream for B multiplicand
|
||||
typedef StreamB_ StreamB;
|
||||
|
||||
/// Parameters object
|
||||
struct Params {
|
||||
/// Parameters object for StreamA
|
||||
typename StreamA::Params stream_a;
|
||||
|
||||
/// Parameters object for StreamB
|
||||
typename StreamB::Params stream_b;
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Constructs a global load stream pair Params object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(typename StreamA::Params const &_params_A, typename StreamB::Params const &_params_B)
|
||||
: stream_a(_params_A), stream_b(_params_B) {}
|
||||
};
|
||||
|
||||
/// Assumes the A stream defines the index type
|
||||
typedef typename StreamA::Index Index;
|
||||
|
||||
/// Shared memory allocation for threadblock-scoped GEMM tile
|
||||
typedef ZipTileAllocation<typename StreamA::ThreadblockTileStorage,
|
||||
typename StreamB::ThreadblockTileStorage>
|
||||
ThreadblockTileStorage;
|
||||
|
||||
/// ZipTensorRef to threadblock tiles
|
||||
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
|
||||
|
||||
/// Defines a structure containing shared storage for each pair
|
||||
struct SharedStorage {
|
||||
typename StreamA::SharedStorage stream_a;
|
||||
typename StreamB::SharedStorage stream_b;
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stream for A multiplicand
|
||||
StreamA stream_a;
|
||||
|
||||
/// Stream for B multiplicand
|
||||
StreamB stream_b;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStreamPair(Params const &params,
|
||||
SharedStorage &shared_storage,
|
||||
ThreadblockTileRef const &threadblock_tile_ref,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0))
|
||||
: stream_a(params.stream_a,
|
||||
shared_storage.stream_a,
|
||||
threadblock_tile_ref.first,
|
||||
bounds,
|
||||
block_offset),
|
||||
stream_b(params.stream_b,
|
||||
shared_storage.stream_b,
|
||||
threadblock_tile_ref.second,
|
||||
bounds,
|
||||
block_offset) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GlobalLoadStreamPair & operator+=(Coord<3> const offset) {
|
||||
stream_a += offset;
|
||||
stream_b += offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GlobalLoadStreamPair & add_batch_offset(int batch_id) {
|
||||
stream_a.add_batch_offset(batch_id);
|
||||
stream_b.add_batch_offset(batch_id);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
|
||||
stream_a.copy();
|
||||
|
||||
stream_b.copy();
|
||||
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
stream_a.commit();
|
||||
|
||||
stream_b.commit();
|
||||
|
||||
}
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
stream_a.residue(k, skip_clear);
|
||||
stream_b.residue(k, skip_clear);
|
||||
}
|
||||
|
||||
/// Move to residue.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
|
||||
if (kResidueInProlog_) {
|
||||
stream_a.move_to_residue(k, kTileK);
|
||||
stream_b.move_to_residue(k, kTileK);
|
||||
} else if (k < kTileK) {
|
||||
residue(k, true);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile.
|
||||
CUTLASS_DEVICE void rollback(bool kRollback) {
|
||||
if (kResidueInProlog_ && kRollback) {
|
||||
stream_a.rollback();
|
||||
stream_b.rollback();
|
||||
}
|
||||
}
|
||||
};
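// Illustrative usage only (a sketch, not part of this header): a mainloop built on this pair
// typically constructs it once per threadblock and then alternates global copies with commits
// into shared memory, for example:
//
//   GlobalLoadStreamPair<StreamA, StreamB, true> global_stream(
//       params, shared_storage, threadblock_tile_ref, bounds);
//   global_stream.copy();     // issue the global loads for the first tile of A and B
//   global_stream.commit();   // store the fetched fragments to shared memory
//
// The exact call sequence is owned by the GemmMainloop implementation.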
|
||||
|
||||
/// Collect the global load streams for multiplicands.
|
||||
template <typename StreamA_, typename StreamB_>
|
||||
struct SharedStreamPair {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Stream for A multiplicand
|
||||
typedef StreamA_ StreamA;
|
||||
|
||||
/// Stream for B multiplicand
|
||||
typedef StreamB_ StreamB;
|
||||
|
||||
/// Parameters object passed to load iterators
|
||||
struct Params {
|
||||
///
|
||||
typename StreamA::Params stream_a;
|
||||
|
||||
///
|
||||
typename StreamB::Params stream_b;
|
||||
};
|
||||
|
||||
/// Shared memory allocation for threadblock-scoped GEMM tile
|
||||
typedef ZipTensorRef<typename StreamA::TensorRef,
|
||||
typename StreamB::TensorRef >
|
||||
ThreadblockTileRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// The stream for A.
|
||||
StreamA stream_a;
|
||||
|
||||
/// The stream for B.
|
||||
StreamB stream_b;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construct with the composable structure
|
||||
CUTLASS_DEVICE SharedStreamPair(Params const &params, ThreadblockTileRef const &threadblock_tile_ref)
|
||||
: stream_a(params.stream_a, threadblock_tile_ref.first),
|
||||
stream_b(params.stream_b, threadblock_tile_ref.second) {}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
stream_a.copy(step);
|
||||
stream_b.copy(step);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
stream_a.commit(step);
|
||||
stream_b.commit(step);
|
||||
}
|
||||
|
||||
/// Clears all fragments
|
||||
CUTLASS_DEVICE
|
||||
void clear() {
|
||||
stream_a.clear();
|
||||
stream_b.clear();
|
||||
}
|
||||
|
||||
/// The fragment A.
|
||||
CUTLASS_DEVICE
|
||||
typename StreamA::TransformedFragment const &fragment_a(int step) const {
|
||||
return stream_a.fragment(step);
|
||||
}
|
||||
|
||||
/// The fragment B.
|
||||
CUTLASS_DEVICE
|
||||
typename StreamB::TransformedFragment const &fragment_b(int step) const {
|
||||
return stream_b.fragment(step);
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
stream_a.inc_stage();
|
||||
stream_b.inc_stage();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,808 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of complete GEMM computation.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_allocation.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "cutlass/kernel_launch.h"
|
||||
|
||||
#include "cutlass/gemm/clear_accumulators.h"
|
||||
#include "cutlass/gemm/gemm_config.h"
|
||||
#include "cutlass/gemm/gemm_desc.h"
|
||||
#include "cutlass/gemm/gemm_stream_pair.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/gemm_mainloop.h"
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind, typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is column-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^N.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsA,
|
||||
// The skew.
|
||||
0>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The number of scalars in 4B.
|
||||
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
|
||||
/// The skew for A.
|
||||
static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B;
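// Worked example (hypothetical configuration, for illustration only): for half-precision data
// sizeof(MultiplyAddScalar) == 2, so kScalarsIn4B = 2. With kScalarsPerStsA = 2 and
// GlobalTileTraits::Threads::kW = 8, kSkewA = 128 / 2 / 2 / 8 * 2 = 8 scalars; the
// SharedStoreTileTraits below then clamps the skew from below by kScalarsPerLdsA.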
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsA,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewA < GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^T.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsA,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind, typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperB {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// B is column-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The number of scalars in 4B.
|
||||
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsB,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewB < GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsB,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// B is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^T.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^T.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsB,
|
||||
// The skew.
|
||||
0>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM configuration.
|
||||
typename GemmConfig_,
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typename GlobalLoadStreamA_,
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typename GlobalLoadStreamB_,
|
||||
/// The stream to load A from shared memory.
|
||||
typename SharedLoadStreamA_,
|
||||
/// The stream to load B from shared memory.
|
||||
typename SharedLoadStreamB_,
|
||||
/// The epilogue.
|
||||
typename Epilogue_,
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typename BlockSwizzle_ = IdentityBlockSwizzle,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The tool used to clear accumulators.
|
||||
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Element> >
|
||||
|
||||
struct GemmTraits {
|
||||
/// This traits
|
||||
typedef GemmTraits<GemmConfig_,
|
||||
GlobalLoadStreamA_,
|
||||
GlobalLoadStreamB_,
|
||||
SharedLoadStreamA_,
|
||||
SharedLoadStreamB_,
|
||||
Epilogue_,
|
||||
BlockSwizzle_,
|
||||
Index_,
|
||||
ClearAccumulators_> This_;
|
||||
|
||||
/// The struct that consumes this Traits
|
||||
typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;
|
||||
|
||||
/// The configuration.
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig::OutputTile OutputTile;
|
||||
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
|
||||
/// The layout of A.
|
||||
static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
|
||||
/// The scalar for A.
|
||||
typedef typename GlobalLoadStreamA_::Scalar ScalarA;
|
||||
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStreamB_ GlobalLoadStreamB;
|
||||
/// The layout of B.
|
||||
static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
|
||||
/// The scalar for B.
|
||||
typedef typename GlobalLoadStreamB_::Scalar ScalarB;
|
||||
|
||||
/// The iterator for A to load from shared memory.
|
||||
typedef SharedLoadStreamA_ SharedLoadStreamA;
|
||||
/// The iterator for B to load from shared memory.
|
||||
typedef SharedLoadStreamB_ SharedLoadStreamB;
|
||||
|
||||
/// The multiply-add functor.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The epilogue.
|
||||
typedef Epilogue_ Epilogue;
|
||||
/// The scalars in the epilogue.
|
||||
typedef typename Epilogue::ScalarC ScalarC;
|
||||
typedef typename Epilogue::ScalarD ScalarD;
|
||||
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typedef BlockSwizzle_ BlockSwizzle;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// Clear the accumulators.
|
||||
typedef ClearAccumulators_ ClearAccumulators;
|
||||
|
||||
/// Assemble the global load streams for A/B.
|
||||
typedef GlobalLoadStreamPair<GlobalLoadStreamA,
|
||||
GlobalLoadStreamB,
|
||||
GemmConfig::kResidueInProlog>
|
||||
GlobalLoadStream;
|
||||
|
||||
/// Memory needed to store the threadblock-scoped GEMM tile
|
||||
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
|
||||
|
||||
/// Assemble the shared load streams for A/B.
|
||||
typedef SharedStreamPair<SharedLoadStreamA, SharedLoadStreamB> SharedStream;
|
||||
|
||||
/// Parameters object constructable on the host.
|
||||
struct Params : public KernelLaunchConfiguration {
|
||||
|
||||
/// GEMM problem size
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// The K range for every partition except the last one
|
||||
int partitionK_range;
|
||||
|
||||
/// Parameters object for the global load stream
|
||||
typename GlobalLoadStream::Params global_to_shared_stream;
|
||||
|
||||
/// Parameters object for the shared load stream
|
||||
typename SharedStream::Params shared_stream;
|
||||
|
||||
/// The params for the epilogue.
|
||||
typename Epilogue::Params epilogue;
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
// Set the problem size.
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// there is no partitionK in the default case
|
||||
partitionK_range = problem_size[0];
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue.
|
||||
// partitionK_range <= problem_size[0]
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
|
||||
|
||||
// Initialize parameters objects for the A and B global load streams.
|
||||
int error_code = global_to_shared_stream.stream_a.initialize(
|
||||
desc.A.data(),
|
||||
desc.batch_stride_A,
|
||||
desc.A.leading_dim(),
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
error_code = global_to_shared_stream.stream_b.initialize(
|
||||
desc.B.data(),
|
||||
desc.batch_stride_B,
|
||||
desc.B.leading_dim(),
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a GEMM params using a BLAS-like API
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, 1),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
TensorRef<ScalarD, 2>(d_d, ldd)
|
||||
);
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a batched GEMM params
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
long long int batch_stride_A,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
long long int batch_stride_B,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
long long int batch_stride_C,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
long long int batch_stride_D,
|
||||
Index batch_count) {
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, batch_count),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
batch_stride_A,
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
batch_stride_B,
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
batch_stride_C,
|
||||
TensorRef<ScalarD, 2>(d_d, ldd),
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a partitionedK GEMM params
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc,
|
||||
Index partitionK_count_,
|
||||
Index partitionK_multiple_ = 1 // each partition will be a multiple of partitionK_multiple_
|
||||
) {
|
||||
// partitionK GEMM is a specialized batched strided GEMM with different K ranges per batch;
// the problem_size of each batch is (lastK_range, n, m).
// The K range for every batch except the last one:
|
||||
//assert(partitionK_count_ > 0);
|
||||
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
|
||||
partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);
|
||||
// the k range of the last batch
|
||||
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
|
||||
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
|
||||
|
||||
assert((partitionK_range % partitionK_multiple_) == 0);
|
||||
assert(partitionK_range > 0);
|
||||
assert((lastK_range % partitionK_multiple_) == 0);
|
||||
assert(lastK_range > 0);
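// Worked example (hypothetical sizes, for illustration only): with problem_size.k() = 96,
// partitionK_count_ = 3 and partitionK_multiple_ = 8, partitionK_range starts at 96 / 3 = 32,
// already a multiple of 8, and lastK_range = 96 - 32 * 2 = 32, so every partition reduces 32
// elements of K. With k() = 100 the same parameters would give partitionK_range = 33 -> 32 and
// lastK_range = 100 - 64 = 36, which fails the (lastK_range % partitionK_multiple_) == 0
// assertion; callers must pick compatible sizes.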
|
||||
|
||||
int k_size = lastK_range;
|
||||
int lda = partitonK_desc.A.stride(0);
|
||||
int ldb = partitonK_desc.B.stride(0);
|
||||
int ldc = partitonK_desc.C.stride(0);
|
||||
int ldd = partitonK_desc.D.stride(0);
|
||||
int n = partitonK_desc.problem_size.n();
|
||||
|
||||
|
||||
long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
|
||||
long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
|
||||
long long int batch_stride_C = ldc * n;
|
||||
long long int batch_stride_D = ldd * n;
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
// we pass lastK_range as the per-batch K; there is also a range that will match partitionK_range
|
||||
GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_),
|
||||
partitonK_desc.alpha,
|
||||
partitonK_desc.A,
|
||||
batch_stride_A,
|
||||
partitonK_desc.B,
|
||||
batch_stride_B,
|
||||
partitonK_desc.beta,
|
||||
partitonK_desc.C,
|
||||
batch_stride_C,
|
||||
partitonK_desc.D,
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
// Set the problem size.
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue.
|
||||
// partitionK_range <= problem_size[0]
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
|
||||
|
||||
// Initialize parameters objects for the A and B global load streams.
|
||||
int error_code = global_to_shared_stream.stream_a.initialize(
|
||||
desc.A.data(),
|
||||
desc.batch_stride_A,
|
||||
desc.A.leading_dim(),
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
error_code = global_to_shared_stream.stream_b.initialize(
|
||||
desc.B.data(),
|
||||
desc.batch_stride_B,
|
||||
desc.B.leading_dim(),
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
|
||||
|
||||
|
||||
/// Helper to construct a partitionedK GEMM params
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
Index partitionK_count_,
|
||||
Index partitionK_multiple_ = 1) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, 1),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
TensorRef<ScalarD, 2>(d_d, ldd)
|
||||
);
|
||||
|
||||
|
||||
return this->initialize(desc, partitionK_count_, partitionK_multiple_);
|
||||
}
|
||||
};
|
||||
|
||||
// The storage for the main loop + prologue.
|
||||
struct MainLoopSharedStorage {
|
||||
/// Stores the threadblock tile
|
||||
ThreadblockTileStorage threadblock_tile;
|
||||
|
||||
/// Storage for GEMM global stream
|
||||
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
|
||||
|
||||
/// Storage for clearing accumulators
|
||||
typename ClearAccumulators::SharedStorage clear;
|
||||
};
|
||||
|
||||
/// The storage in shared memory.
|
||||
union SharedStorage {
|
||||
// The storage for the main loop.
|
||||
MainLoopSharedStorage main_loop;
|
||||
// The storage for the epilogue.
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
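// Note: using a union here is what lets the epilogue reuse the shared memory that held the
// threadblock tiles during the main loop; the two phases never occupy shared memory at the
// same time.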
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
|
||||
if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
|
||||
SharedLoadStreamB::Iterator::kRequiresLoadFence) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
|
||||
struct SimplifiedGemmTraitsHelper {
|
||||
/// The global iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The data converter for A before storing to shared memory.
|
||||
typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GemmOperand::kA,
|
||||
GlobalLoadIteratorA,
|
||||
SharedStoreIteratorA,
|
||||
GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The global iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
/// The data converter for B before storing to shared memory.
|
||||
typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GemmOperand::kB,
|
||||
GlobalLoadIteratorB,
|
||||
SharedStoreIteratorB,
|
||||
GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The config for the GEMM.
|
||||
typename GemmConfig_,
|
||||
/// The epilogue.
|
||||
typename Epilogue_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
// The configuration for the A matrix.
|
||||
typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
|
||||
// The configuration for the B matrix.
|
||||
typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
|
||||
// The helper class to create the streams and iterators.
|
||||
typename Helper_ =
|
||||
SimplifiedGemmTraitsHelper<GemmTileTraitsHelperA_, GemmTileTraitsHelperB_, Index_> >
|
||||
struct SimplifiedGemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
Epilogue_,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,90 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tile traits used to construct global tile iterator for HGEMM. This is intended to
|
||||
partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate
|
||||
memory accesses larger than 16 bits.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct HgemmCrosswiseGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
// Which GEMM operand?
|
||||
kOperand_,
|
||||
// The layout.
|
||||
kLayout_,
|
||||
// The scalar.
|
||||
Scalar_,
|
||||
// The tile.
|
||||
Tile_,
|
||||
// The threads.
|
||||
Threads_,
|
||||
// The number of scalars per LDG/STG.
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
|
||||
/// The threads.
|
||||
typedef typename Base::Threads Threads;
|
||||
/// The threads strides.
|
||||
typedef Shape<1, 2, Base::VectorizedTile::kC> ThreadsDelta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<Base::Threads::kH * 2, 1, Base::Threads::kW, Base::kAccessSize> Delta;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 2,
|
||||
2,
|
||||
Base::VectorizedTile::kW / Base::Threads::kW,
|
||||
Base::VectorizedTile::kC / Base::kAccessSize>
|
||||
Iterations;
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,110 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Specialization implementing multiply-add operation on half-precision floating point
|
||||
fragments.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef ThreadGemmShape_ ThreadGemmShape;
|
||||
/// Aliased for compatibility. Will be removed for CUTLASS v2.0.
|
||||
typedef ThreadGemmShape AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
|
||||
/// The type for B.
|
||||
typedef half ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
|
||||
/// The type for C and D.
|
||||
typedef half ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;
|
||||
|
||||
/// Make sure there's an even number of elements in both dimensions.
|
||||
static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
|
||||
static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
|
||||
static_assert(AccumulatorsPerThread::kH >= 2 && AccumulatorsPerThread::kW >= 2,
|
||||
"HGEMM expects at least 2x2 accmulator tiles per thread.");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// The inputs.
|
||||
__half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
|
||||
// The offsets in the output fragment.
|
||||
int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
|
||||
int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
|
||||
|
||||
// Compute the product a[i] * b[j].low.
|
||||
d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
|
||||
// Compute the product a[i] * b[j].high.
|
||||
d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,94 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in
|
||||
shared memory for multiplicands.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GlobalIterator_>
|
||||
struct HgemmSwizzle {
|
||||
/// The global iterator.
|
||||
typedef GlobalIterator_ GlobalIterator;
|
||||
/// The source fragment.
|
||||
typedef typename GlobalIterator::Fragment Fragment;
|
||||
/// The shape of the source fragment.
|
||||
typedef typename GlobalIterator::FragmentShape FragmentShape;
|
||||
|
||||
/// The input fragment.
|
||||
typedef Fragment InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment OutputFragment;
|
||||
|
||||
/// The src/dst must be half fragments.
|
||||
static_assert((platform::is_same<typename Fragment::Element, half>::value), "Works on half");
|
||||
|
||||
/// The number of elements must be a multiple of 2.
|
||||
static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE HgemmSwizzle() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
|
||||
// Expose src/dst as int arrays.
|
||||
int const* src_int = reinterpret_cast<int const*>(&src[0]);
|
||||
int* dst_int = reinterpret_cast<int*>(&dst[0]);
|
||||
|
||||
// Transpose the data.
|
||||
for (int d = 0; d < FragmentShape::kD; ++d) {
|
||||
// The indices to read two consecutive "rows".
|
||||
int const i0 = 2 * d + 0;
|
||||
int const i1 = 2 * d + 1;
|
||||
|
||||
int a0 = src_int[i0];
|
||||
int a1 = src_int[i1];
|
||||
|
||||
int b0, b1;
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
|
||||
|
||||
// The indices to store with "strides".
|
||||
int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
|
||||
int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
|
||||
|
||||
dst_int[j0] = b0;
|
||||
dst_int[j1] = b1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,408 +0,0 @@
|
||||
/*! \file
    \brief Defines structural properties of half-precision GEMM computation.
*/
#pragma once

#include "cutlass/convert.h"
#include "cutlass/reshape_tile.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_epilogue.h"
#include "cutlass/gemm/gemm_epilogue_traits.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/gemm/gemm_shared_tile.h"
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/hgemm_global_tile.h"
#include "cutlass/gemm/hgemm_multiply_add.h"
#include "cutlass/layout/thread/transform.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// Tile size for thread-level GEMM (K-by-N-by-M)
    typename ThreadGemmShape_,
    /// The number of scalars per LDG for A.
    int kScalarsPerLdgA_ = 2,
    /// The number of scalars per LDG for B.
    int kScalarsPerLdgB_ = 2>
struct HgemmConfig : public GemmConfig<
    /// The scalar type for A.
    half,
    /// The scalar type for B.
    half,
    /// The scalar type for C.
    half,
    /// The scalar type for D.
    half,
    /// The tile size for the GEMM KxNxM.
    OutputTile_,
    /// The functor to do the math in the main loop.
    ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, half, half, half>,
    /// The number of scalars per LDG for A.
    kScalarsPerLdgA_,
    /// The number of scalars per STS for A.
    kScalarsPerLdgA_,
    /// The number of scalars per LDS for A.
    8,
    /// The number of scalars per LDG for B.
    kScalarsPerLdgB_,
    /// The number of scalars per STS for B.
    kScalarsPerLdgB_,
    /// The number of scalars per LDS for B.
    8,
    /// The number of scalars per LDG for C and STG for D.
    2,
    /// The number of scalars per STS for D.
    8,
    /// The number of scalars per LDS for D.
    2,
    /// The number of stages in shared memory.
    2,
    /// kResidueSeparate
    false,
    /// kResidueInPrologue
    true,
    /// kLaunchBounds
    false
    > {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct HgemmTransformerA {};

template <typename Iterator_>
struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
  typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
};

template <typename Iterator_>
struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape, 2, half, cutlass::MatrixLayout::RowMajor, half, cutlass::MatrixLayout::ColumnMajor> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct HgemmTransformerB {};

template <typename Iterator_>
struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
  typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
};

template <typename Iterator_>
struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape, 2, half, cutlass::MatrixLayout::RowMajor, half, cutlass::MatrixLayout::ColumnMajor> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_>
struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
    : public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
  /// The base config.
  typedef GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> Base;

  /// The traits class to build the iterator to load data from global memory for A^T.
  typedef HgemmCrosswiseGlobalTileTraits<
      GemmOperand::kA,
      // The layout.
      MatrixLayout::kRowMajor,
      // The pointer.
      half const,
      // The tile has size MxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc)
      GemmConfig_::kScalarsPerLdgA>
      GlobalTileTraits;

  static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;

  /// The traits class to build the iterator to store data to shared memory for A^T.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer.
      half,
      // The tile has size KxM in GEMM's terminology.
      Shape<GemmConfig_::kStages,
            GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
            GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      2,
      // The skew to avoid bank conflicts added in the tile W dimension.
      kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for A^T.
  typedef GemmSharedLoadTileATraits<
      // The pointer.
      half const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      8,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_>
struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
    : public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
  /// The base config.
  typedef GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> Base;

  /// The traits class to build the iterator to load data from global memory for B^N.
  typedef HgemmCrosswiseGlobalTileTraits<
      GemmOperand::kB,
      // The layout.
      MatrixLayout::kColumnMajor,
      // The pointer.
      half const,
      // The tile has size KxN in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc)
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;

  /// The traits class to build the iterator to store data to shared memory for B^N.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer.
      half,
      // The tile has size KxN in GEMM's terminology.
      Shape<GemmConfig_::kStages,
            GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
            GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      2,
      // The skew to avoid bank conflicts added in the tile W dimension.
      kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for B^N.
  typedef GemmSharedLoadTileBTraits<
      // The pointer.
      half const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      8,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_,
    /// Tile size for thread-level GEMM (K-by-N-by-M)
    typename ThreadGemmShape_,
    /// The number of halfs loaded in one LDG for A.
    int kScalarsPerLdgA_ = 2,
    /// The number of halfs loaded in one LDG for B.
    int kScalarsPerLdgB_ = 2,
    /// The index.
    typename Index_ = int>
struct HgemmTraitsHelper {
  /// The HGEMM config.
  typedef HgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_> GemmConfig;
  /// The GEMM config for A.
  typedef HgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
  /// The GEMM config for B.
  typedef HgemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;

  /// The iterator to load A from global memory.
  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
      GlobalLoadIteratorA;
  /// The default transformer for A.
  typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
                                     GlobalLoadIteratorA>::Transformer GlobalTransformerA;
  /// The iterator to store A to shared memory.
  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
                            typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorA;
  /// The stream to load A from global memory to shared memory.
  typedef GlobalLoadStream<GemmOperand::kA,
                           GlobalLoadIteratorA,
                           SharedStoreIteratorA,
                           GlobalTransformerA>
      GlobalLoadStreamA;

  /// The iterator to load B from global memory.
  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
      GlobalLoadIteratorB;
  // The default transformer for B.
  typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
                                     GlobalLoadIteratorB>::Transformer GlobalTransformerB;
  /// The iterator to store B to shared memory.
  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
                            typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorB;
  /// The stream to load B from global memory to shared memory.
  typedef GlobalLoadStream<GemmOperand::kB,
                           GlobalLoadIteratorB,
                           SharedStoreIteratorB,
                           GlobalTransformerB>
      GlobalLoadStreamB;

  /// The iterator to load A from shared memory.
  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
                           typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorA;
  /// The stream to load A from shared memory.
  typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
  /// The iterator to load B from shared memory.
  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
                           typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorB;
  /// The stream to load B from shared memory.
  typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;

  /// The functor to do the multiply-add in the main loop.
  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
  /// The object to clear accumulators.
  typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;

  /// The traits class for the epilogue.
  typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_> GemmEpilogueTraits;
  /// The epilogue.
  typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_ = Shape<8, 128, 128>,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_ = LinearScaling<half>,
    /// Tile size for warp-level GEMM (K-by-N-by-M)
    typename ThreadGemmShape_ = Shape<8, 8, 16>,
    /// The number of halfs loaded in one LDG for A.
    int kScalarsPerLdgA_ = 2,
    /// The number of halfs loaded in one LDG for B.
    int kScalarsPerLdgB_ = 2,
    /// The index.
    typename Index_ = int,
    /// The helper class.
    typename Helper_ = HgemmTraitsHelper<kLayoutA_,
                                         kLayoutB_,
                                         OutputTile_,
                                         EpilogueFunctor_,
                                         ThreadGemmShape_,
                                         kScalarsPerLdgA_,
                                         kScalarsPerLdgB_,
                                         Index_> >
struct HgemmTraits : public GemmTraits<
    // The config.
    typename Helper_::GemmConfig,
    // The stream to load A from global memory to shared memory.
    typename Helper_::GlobalLoadStreamA,
    // The stream to load B from global memory to shared memory.
    typename Helper_::GlobalLoadStreamB,
    // The stream to load A from shared memory.
    typename Helper_::SharedLoadStreamA,
    // The stream to load B from shared memory.
    typename Helper_::SharedLoadStreamB,
    // The epilogue.
    typename Helper_::Epilogue,
    // The block swizzle to reorganize the grid.
    IdentityBlockSwizzle,
    // The index.
    Index_,
    // The tool used to clear accumulators.
    typename Helper_::ClearAccumulators> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
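As a rough usage sketch (assuming the CUTLASS 1.x device-level Gemm wrapper and column-major operands; m/n/k, alpha/beta, and the d_a/d_b/d_c pointers with lda/ldb/ldc below are caller-supplied placeholders), HgemmTraits would typically be consumed like this:

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/hgemm_traits.h"

// Traits for a column-major x column-major HGEMM with the default 128x128x8 tile.
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                   cutlass::MatrixLayout::kColumnMajor> Traits;
typedef cutlass::gemm::Gemm<Traits> Gemm;

// Build kernel parameters and launch (host code; error checking omitted).
typename Gemm::Params params;
params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_c, ldc);
Gemm::launch(params);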
@@ -1,318 +0,0 @@
/*! \file
    \brief Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and
    floating-point output matrix formats.
*/
#pragma once

#include "cutlass/convert.h"
#include "cutlass/fragment.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/igemm_global_tile.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kElements_>
struct IgemmFloatToInt8Converter {
  /// The input fragment.
  typedef Fragment<float, kElements_> InputFragment;
  /// The output fragment.
  typedef Fragment<int8_t, kElements_> OutputFragment;

  // We are packing 4 floats into int32 registers so we need kElements to be multiple of 4.
  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");

  /// Ctor.
  CUTLASS_DEVICE IgemmFloatToInt8Converter() {}

  /// Transform a fragment.
  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
    transform(src, 0, dst);
  }

  /// Transform a fragment.
  template <typename Fragment_>
  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
    // The inputs.
    float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
    // The outputs.
    int* dst_int = reinterpret_cast<int*>(&dst[0]);

    // Iterate over the floats and pack them together to produce ints.
    for (int i = 0; i < kElements_ / 4; ++i) {
      // Read the float4.
      float4 f4 = src_f4[i];

      // Clamp the 4 elements of the floats to the [-128, +127] range.
      float x = fmaxf(-128.f, fminf(127.f, f4.x));
      float y = fmaxf(-128.f, fminf(127.f, f4.y));
      float z = fmaxf(-128.f, fminf(127.f, f4.z));
      float w = fmaxf(-128.f, fminf(127.f, f4.w));

      // Convert to integers.
      int ix = (int)x;
      int iy = (int)y;
      int iz = (int)z;
      int iw = (int)w;

      // Extract the lower bytes to build an int32 with 4 int8.
      asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
      asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
      asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));

      // Store the int.
      dst_int[i] = ix;
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename InputScalar_, typename OutputFragment_>
struct IgemmGlobalStoreTransformer {
  typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
};

template <int kElements_>
struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
  typedef IgemmFloatToInt8Converter<kElements_> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kElements_>
struct IgemmInt8ToFloatConverter {
  /// The input fragment.
  typedef Fragment<int8_t, kElements_> InputFragment;
  /// The output fragment.
  typedef Fragment<float, kElements_> OutputFragment;

  // We are unpacking 4 int8s from int32.
  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");

  /// Ctor.
  CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}

  /// Transform a fragment.
  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
    transform(src, 0, dst);
  }

  /// Transform a fragment.
  template <typename Fragment_>
  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
    // The inputs.
    int const* src_int = reinterpret_cast<int const*>(&src[0]);
    // The outputs.
    float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);

    // Iterate over the int8 and unpack them together to produce floats.
    for (int i = 0; i < kElements_ / 4; ++i) {
      // Read the int.
      int ix, iy, iz, iw = src_int[i];

      // Extract the 4 bytes.
      asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
      asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
      asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
      asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));

      // The floats.
      float fx, fy, fz, fw;

      // Convert to floats (make sure we generate I2F.F32.S8).
      asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
      asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
      asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
      asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));

      // Store the float4.
      dst_f4[i] = make_float4(fx, fy, fz, fw);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename InputFragment_, typename OutputScalar_>
struct IgemmGlobalLoadTransformer {
  typedef Convert<InputFragment_, Fragment<OutputScalar_, InputFragment_::kElements> > Transformer;
};

template <int kElements_>
struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
  typedef IgemmInt8ToFloatConverter<kElements_> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename InputScalar_, typename OutputFragment_>
struct IgemmSharedStoreTransformer {
  typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
struct IgemmEpilogueTraitsHelper
    : public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
  /// The base class.
  typedef GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> Base;
  /// The config.
  typedef IgemmConfig_ IgemmConfig;

  /// The scalar type of the epilogue.
  typedef typename Base::Scalar Scalar;
  /// The iterations.
  typedef typename Base::Iterations Iterations;
  /// The iterations strides.
  typedef typename Base::Delta Delta;

  /// The traits class for the iterator.
  typedef typename Base::GlobalLoadTileTraits GlobalLoadTileTraits;
  /// The iterator to store to shared memory.
  typedef GemmGlobalIteratorCd<GlobalLoadTileTraits> GlobalLoadIteratorC;
  /// The fragment that needs to be produced by the load iterator.
  typedef typename GlobalLoadIteratorC::Fragment GlobalFragmentC;
  /// The transformer from loaded data to math fragment.
  typedef
      typename IgemmGlobalLoadTransformer<GlobalFragmentC, Scalar>::Transformer GlobalTransformerC;

  /// The traits class for the iterator.
  typedef typename Base::GlobalStoreTileTraits GlobalStoreTileTraits;
  /// The iterator to store to shared memory.
  typedef GemmGlobalIteratorCd<GlobalStoreTileTraits> GlobalStoreIteratorD;
  /// The fragment that needs to be passed to that store iterator.
  typedef typename GlobalStoreIteratorD::Fragment GlobalFragmentD;
  /// The transformer from accumulators to shared memory fragments.
  typedef
      typename IgemmGlobalStoreTransformer<Scalar, GlobalFragmentD>::Transformer GlobalTransformerD;

  /// The traits class for the shared iterator to store D to shared memory.
  typedef typename Base::SharedStoreTileTraits SharedStoreTileTraits;
  /// The shared iterator to store D to shared memory.
  typedef TileStoreIterator<SharedStoreTileTraits,
                            typename SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kGlobal>
      SharedStoreIteratorD;
  /// The fragment that needs to be passed to that store iterator.
  typedef typename SharedStoreIteratorD::Fragment SharedStoreFragmentD;
  /// The transformer from accumulators to shared memory fragments.
  typedef typename IgemmSharedStoreTransformer<typename IgemmConfig::Accumulators::Element,
                                               SharedStoreFragmentD>::Transformer
      SharedStoreTransformerD;
  /// The traits class for the shared iterator to load D from shared memory.
  typedef typename Base::SharedLoadTileTraits SharedLoadTileTraits;
  /// The shared iterator to load D from shared memory.
  typedef TileLoadIterator<SharedLoadTileTraits,
                           typename SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorD;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The config.
    typename IgemmConfig_,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_,
    /// The index.
    typename Index_ = int,
    /// The helper class to assemble the traits.
    typename Helper_ = IgemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> >
struct IgemmEpilogueTraits : public GemmEpilogueTraits<
    // The output tile.
    typename IgemmConfig_::OutputTile,
    // The accumulators.
    typename IgemmConfig_::Accumulators,
    // The global iterator for C.
    typename Helper_::GlobalLoadIteratorC,
    // The transformer for C.
    typename Helper_::GlobalTransformerC,
    // The transformer for D.
    typename Helper_::GlobalTransformerD,
    // The global iterator for D.
    typename Helper_::GlobalStoreIteratorD,
    // The iterator to store D to shared memory.
    typename Helper_::SharedStoreIteratorD,
    // The shared store transformer for D.
    typename Helper_::SharedStoreTransformerD,
    // The stream to load D from shared memory.
    typename Helper_::SharedLoadStreamD,
    // The iterations.
    typename Helper_::Iterations,
    // The strides between iterations.
    typename Helper_::Delta,
    // The functor to be used in the epilogue.
    EpilogueFunctor_,
    // The index.
    Index_> {
  /// Do we output in int8?
  static bool const kInt8Output =
      platform::is_same<typename IgemmConfig_::ScalarC, int8_t>::value != 0;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
  /// The base class.
  typedef GemmEpilogue<GemmEpilogueTraits_> Base;

  /// Ctor.
  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
                               typename Base::SharedStorage& shared_storage_,
                               Coord<3> const& _problem_size)
      : Base(params_, shared_storage_, _problem_size) {}
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmEpilogueTraits_>
struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
  /// The base class.
  typedef GemmEpilogue<GemmEpilogueTraits_> Base;

  /// Ctor.
  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
                               typename Base::SharedStorage& shared_storage_,
                               Coord<3> const& _problem_size)
      : Base(params_, shared_storage_, _problem_size) {}
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
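The float-to-int8 path above clamps each value to [-128, 127], truncates to an integer, and then uses three prmt.b32 permutations to gather the four low bytes into one register. A plain C++ sketch of the same packing (illustration only, not CUTLASS code; the byte order mirrors the prmt selectors used above, with x in byte 0 and w in byte 3):

#include <algorithm>
#include <cstdint>

// Pack four floats into one 32-bit word of signed 8-bit lanes.
inline uint32_t pack_s8x4(float x, float y, float z, float w) {
  auto to_byte = [](float f) -> uint32_t {
    f = std::max(-128.f, std::min(127.f, f));                     // clamp, as in the converter
    return static_cast<uint32_t>(static_cast<int>(f)) & 0xffu;    // truncate, keep the low byte
  };
  return to_byte(x) | (to_byte(y) << 8) | (to_byte(z) << 16) | (to_byte(w) << 24);
}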
@@ -1,135 +0,0 @@
/*! \file
    \brief Implements tile iterators to partition the thread block tile into 2D subtiles and
    efficiently load each. Applies permute transformation to construct 'interleaved K-strided'
    data layout in which 4-element dot products from the same K index are arranged in consecutive
    locations within shared memory.

    Supports efficient loads from shared memory to target the DP4A instruction.
*/
#pragma once

#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/matrix_traits.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Threads_,
          int kAccessSize_>
struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
    // Which GEMM operand?
    kOperand_,
    // The layout.
    kLayout_,
    // The scalar.
    Scalar_,
    // The tile.
    Tile_,
    // The threads.
    Threads_,
    // The number of scalars per LDG/STG.
    kAccessSize_> {
  /// The base class.
  typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
  /// The threads.
  typedef typename Base::Threads Threads;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<Base::Threads::kH * 4, 1, Base::Threads::kW, Base::kAccessSize> Delta;
  /// The number of iterations needed to load/store the tile.
  typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 4,
                4,
                Base::VectorizedTile::kW / Base::Threads::kW,
                Base::VectorizedTile::kC / Base::kAccessSize>
      Iterations;

  /// Computes the thread offset in (H, W) based on thread ID
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE
    Coord<4> operator()() const {
      int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
      int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;

      return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    }
  };

 public:
  /// The threads strides.
  typedef Shape<1, 4, Base::VectorizedTile::kC> ThreadsDelta;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileTraits_, typename Index_ = int>
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
  /// The base class.
  typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
  /// The functor to compute the thread offset.
  typedef typename TileTraits_::ThreadOffset ThreadOffset;

  /// Constructor.
  CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
                                       const Coord<3>& threadblock_offset,
                                       ThreadOffset thread_offset_func = ThreadOffset())
      : Base(_params, threadblock_offset, thread_offset_func), mask_(0xffffffff) { }

  CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& threadblock_offset) {

    Base::initialize_predicates(bounds, threadblock_offset);
    // The number of elements read in a single iteration.
    int const kBlock = TileTraits_::Tile::kW;
    // The residue.
    int const kResidue = (int)(bounds[1] % kBlock);

    // Compute the number of elements that are valid.
    int const left = kResidue - Base::thread_offset[2];
    if (left > 0 && left < 4) {
      mask_ = (1u << (8 * left)) - 1u;
    }
  }

  CUTLASS_DEVICE void load_element(
      typename Base::AccessType& value, int d, int h, int w, int c) const {
    Base::load_element(value, d, h, w, c);
    reinterpret_cast<uint32_t&>(value) &= mask_;
  }

  /// The mask to clean up the values.
  uint32_t mask_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
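The mask computed in initialize_predicates zeroes the bytes of a packed 4 x int8 access that fall past the K residue, so those lanes contribute nothing to the later dp4a. A small worked example, assuming left = 3 valid elements and a hypothetical loaded value:

uint32_t mask = (1u << (8 * 3)) - 1u;   // 0x00ffffff
uint32_t value = 0xdeadbeef;            // hypothetical packed int8x4 load
value &= mask;                          // 0x00adbeef: byte 3 cleared, its lane multiplies as 0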
@@ -1,103 +0,0 @@
/*! \file
    \brief Implements matrix multiply accumulate operation of 8-bit integer data using DP4A
    instruction.
*/
#pragma once

#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))

#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Template performing matrix multiply-add operation within a thread
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {
  /// The shape of the instruction.
  typedef Shape<4, 1, 1> InstructionShape;
  /// Shape of the thread-level GEMM (K-by-N-by-M)
  typedef ThreadGemmShape_ ThreadGemmShape;

  /// Thread-level GEMM (N-by-M) must be a multiple of 32.
  static_assert((ThreadGemmShape::kH * ThreadGemmShape::kW) % 32 == 0,
                "Thread-level GEMM (N-by-M) must be multiple of 32");

  /// Aliased for compatibility. Will be removed in CUTLASS v2.0
  typedef ThreadGemmShape AccumulatorsPerThread;
  /// The number of threads per warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of accumulators per warp.
  typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
  /// The type for A.
  typedef int8_t ScalarA;
  /// The fragment for A.
  typedef Fragment<ScalarA, AccumulatorsPerThread::kW * 4> FragmentA;
  /// The type for B.
  typedef int8_t ScalarB;
  /// The fragment for B.
  typedef Fragment<ScalarB, AccumulatorsPerThread::kH * 4> FragmentB;
  /// The type for C and D.
  typedef int ScalarC;
  /// The accumulators.
  typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE ThreadMultiplyAdd() {}

  /// Multiply : d = a*b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {

    // The inputs.
    int const* a_int = reinterpret_cast<int const*>(&a[0]);
    int const* b_int = reinterpret_cast<int const*>(&b[0]);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {

        asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
                     : "=r"(d[j * AccumulatorsPerThread::kW + i])
                     : "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
      }
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
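Each dp4a.s32.s32 issued in the loop above computes a four-way signed 8-bit dot product with 32-bit accumulation. A minimal device-side sketch using the __dp4a intrinsic (available from compute capability 6.1, equivalent to the inline PTX; not part of CUTLASS) could look like this:

// Scalar equivalent of one dp4a in the unrolled loop above.
// packed_a and packed_b each hold four signed 8-bit lanes; the result is
// acc + sum over k of a[k] * b[k].
__device__ int dp4a_example(int packed_a, int packed_b, int acc) {
  return __dp4a(packed_a, packed_b, acc);
}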
@@ -1,126 +0,0 @@
/*! \file
    \brief Transposes a fragment of data containing packed 8-bit integer elements.
*/
#pragma once

#include "cutlass/fragment.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GlobalIterator_>
struct IgemmSwizzle {
  /// The global iterator.
  typedef GlobalIterator_ GlobalIterator;
  /// The source fragment.
  typedef typename GlobalIterator::Fragment Fragment;
  /// The shape of the source fragment.
  typedef typename GlobalIterator::FragmentShape FragmentShape;

  /// The source fragment.
  typedef Fragment InputFragment;
  /// The destination fragment.
  typedef Fragment OutputFragment;

  /// The src/dst must be int8 fragments.
  static_assert((platform::is_same<typename Fragment::Element, int8_t>::value), "Works on int8");

  /// The number of elements must be a multiple of 4.
  static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
                "Not multiple of 4");

  /// Ctor.
  CUTLASS_DEVICE IgemmSwizzle() {}

  /// Transform a fragment.
  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {

    // Expose src/dst as int arrays.
    int const* src_int = reinterpret_cast<int const*>(&src[0]);
    int* dst_int = reinterpret_cast<int*>(&dst[0]);

    // Transpose the data.
    for (int d = 0; d < FragmentShape::kD; ++d) {
      for (int h = 0; h < FragmentShape::kH / 4; ++h) {
        for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
          int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
                         (4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
          int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
                         (4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
          int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
                         (4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
          int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
                         (4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;

          int a0 = src_int[i0];
          int a1 = src_int[i1];
          int a2 = src_int[i2];
          int a3 = src_int[i3];

          // // DEBUG.
          // if (threadIdx.x == 0) {
          //   printf("a=0x%08x 0x%08x 0x%08x 0x%08x\n", a0, a1, a2, a3);
          // }

          int b0, b1, b2, b3, c0;
          asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));

          asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));

          asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));

          asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));

          // // DEBUG.
          // if (threadIdx.x == 0) {
          //   printf("b=0x%08x 0x%08x 0x%08x 0x%08x\n", b0, b1, b2, b3);
          // }

          dst_int[i0] = b0;
          dst_int[i1] = b1;
          dst_int[i2] = b2;
          dst_int[i3] = b3;
        }
      }
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -1,553 +0,0 @@
/*! \file
    \brief Defines structural properties of mixed-precision integer GEMM. Multiplicands are assumed
    to be packed 8-bit integers, accumulators are assumed to be 32b signed integers, and output
    formats vary.
*/
#pragma once
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/igemm_epilogue.h"
|
||||
#include "cutlass/gemm/igemm_global_tile.h"
|
||||
#include "cutlass/gemm/igemm_multiply_add.h"
|
||||
#include "cutlass/layout/thread/transform.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarD_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_>
|
||||
struct IgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
ScalarD_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
true,
|
||||
/// kLaunchBounds
|
||||
false>
|
||||
{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile_, typename ThreadGemmShape_>
|
||||
struct IgemmConfig<OutputTile_, int8_t, ThreadGemmShape_>
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
int8_t,
|
||||
/// The scalar type for D.
|
||||
int8_t,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
4,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
4,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// If true, separate mainloop is instantiated from residue
|
||||
false,
|
||||
/// Compute residue in prolog?
|
||||
true,
|
||||
/// Launch bounds?
|
||||
false> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
    : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
  /// The base config.
  typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;

  /// The number of scalars per STS for A.
  static int const kScalarsPerStsA = 16;

  /// The traits class to build the iterator to load data from global memory for A^N.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kA,
      // The layout.
      MatrixLayout::kColumnMajor,
      // The pointer is int8_t const.
      int8_t const,
      // The tile has size KxM in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgA>
      GlobalTileTraits;

  /// The global load iterator.
  typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for A^N.
  typedef GemmSharedStoreTileAbTraits<
      // The pointer is int8_t.
      int8_t,
      // The tile has size KxM in GEMM's terminology.
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      kScalarsPerStsA>
      SharedStoreTileTraits;
};

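// Illustrative note (not in the original source): the shared-memory tile above reshapes the KxM
// int8 tile so that four consecutive K values of a column land next to each other, which is the
// packed operand order the int8 dot-product math expects. With OutputTile = Shape<32, 128, 128>
// (K = 32, M = 128), the 32x128 tile of A is therefore stored per stage as a
// (32 / 4) x (128 * 4) = 8 x 512 arrangement of int8 values.
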
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// The input scalar.
  typedef int8_t Scalar;
  /// The scalar stored in shared memory.
  typedef int8_t MultiplyAddScalar;

  /// The number of scalars per STS for A.
  static int const kScalarsPerStsA = 16;

  /// The traits class to build the iterator to load data from global memory for A^T.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kA,
      // The layout.
      MatrixLayout::kRowMajor,
      // The pointer is int8_t const.
      int8_t const,
      // The tile has size MxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgA>
      GlobalTileTraits;

  /// The global load iterator.
  typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for A^N.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer is int8_t.
      int8_t,
      // The tile has size KxM in GEMM's terminology.
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS.
      kScalarsPerStsA,
      // The skew to avoid bank conflicts added in the tile W dimension.
      16>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for A^N.
  typedef GemmSharedLoadTileATraits<
      // The pointer is int8_t const.
      int8_t const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      16,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// The input scalar.
  typedef int8_t Scalar;
  /// The scalar stored in shared memory.
  typedef int8_t MultiplyAddScalar;

  /// The number of scalars per STS for B.
  static int const kScalarsPerStsB = 16;

  /// The traits class to build the iterator to load data from global memory for B^T.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kB,
      // The layout.
      MatrixLayout::kColumnMajor,
      // The pointer is int8_t const.
      int8_t const,
      // The tile has size NxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  /// The global load iterator.
  typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for B^N.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer is int8_t.
      int8_t,
      // The tile has size KxN in GEMM's terminology.
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS.
      kScalarsPerStsB,
      // The skew to avoid bank conflicts added in the tile W dimension.
      16>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for B^N.
  typedef GemmSharedLoadTileBTraits<
      // The pointer is int8_t const.
      int8_t const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      16,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
    : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
  /// The base config.
  typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;

  /// The number of scalars per STS for B.
  static int const kScalarsPerStsB = 16;

  /// The traits class to build the iterator to load data from global memory for B^T.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kB,
      // The layout.
      MatrixLayout::kRowMajor,
      // The pointer is int8_t const.
      int8_t const,
      // The tile has size KxN in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  /// The global load iterator.
  typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for B^N.
  typedef GemmSharedStoreTileAbTraits<
      // The pointer is int8_t.
      int8_t,
      // The tile has size KxN in GEMM's terminology.
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      kScalarsPerStsB>
      SharedStoreTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct IgemmTransformerA {};

template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
  typedef Copy<typename Iterator_::Fragment> Transformer;
};

template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape,
                                             2,
                                             int8_t,
                                             cutlass::MatrixLayout::RowMajor,
                                             int8_t,
                                             cutlass::MatrixLayout::ColumnMajor>
      Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct IgemmTransformerB {};

template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
  typedef Copy<typename Iterator_::Fragment> Transformer;
};

template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape,
                                             2,
                                             int8_t,
                                             cutlass::MatrixLayout::RowMajor,
                                             int8_t,
                                             cutlass::MatrixLayout::ColumnMajor>
      Transformer;
};

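// Illustrative note (not in the original source): for operand layouts that do not already match
// the packed int8 order, the transformer rearranges each loaded fragment in registers (a
// thread-level row-major to column-major transform) before it is written to shared memory; for
// layouts that already match, a plain Copy is used. A hypothetical use, assuming some global
// load iterator type SomeGlobalLoadIteratorA:
//
//   typedef typename IgemmTransformerA<MatrixLayout::kColumnMajor,
//                                      SomeGlobalLoadIteratorA>::Transformer TransformA;
//   // TransformA::transform(input_fragment, output_fragment) would be invoked by the global
//   // load stream between the LDG and the STS.
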
////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_,
    /// The output type.
    typename ScalarD_,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_,
    /// Tile size for thread-level GEMM (K-by-N-by-M)
    typename ThreadGemmShape_ = Shape<32, 8, 8>,
    /// The index.
    typename Index_ = int>
struct IgemmTraitsHelper {
  /// The IGEMM config.
  typedef IgemmConfig<OutputTile_, ScalarD_, ThreadGemmShape_> GemmConfig;
  /// The GEMM config for A.
  typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
  /// The GEMM config for B.
  typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig, Index_> GemmTileTraitsHelperB;

  /// The iterator to load A from global memory.
  typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
  /// The default transformer for A.
  typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
                                     GlobalLoadIteratorA>::Transformer GlobalTransformerA;
  /// The iterator to store A to shared memory.
  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
                            typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorA;
  /// The stream to load A from global memory to shared memory.
  typedef GlobalLoadStream<GemmOperand::kA,
                           GlobalLoadIteratorA,
                           SharedStoreIteratorA,
                           GlobalTransformerA>
      GlobalLoadStreamA;

  /// The iterator to load B from global memory.
  typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
  /// The default transformer for B.
  typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
                                     GlobalLoadIteratorB>::Transformer GlobalTransformerB;
  /// The iterator to store B to shared memory.
  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
                            typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorB;
  /// The stream to load B from global memory to shared memory.
  typedef GlobalLoadStream<GemmOperand::kB,
                           GlobalLoadIteratorB,
                           SharedStoreIteratorB,
                           GlobalTransformerB>
      GlobalLoadStreamB;

  /// The iterator to load A from shared memory.
  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
                           typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorA;
  /// The stream to load A from shared memory.
  typedef SharedLoadStream<SharedLoadIteratorA, Copy<typename SharedLoadIteratorA::Fragment> >
      SharedLoadStreamA;
  /// The iterator to load B from shared memory.
  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
                           typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorB;
  /// The stream to load B from shared memory.
  typedef SharedLoadStream<SharedLoadIteratorB, Copy<typename SharedLoadIteratorB::Fragment> >
      SharedLoadStreamB;

  /// The multiply-add functor.
  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
  /// The object to clear accumulators.
  typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;

  /// The epilogue.
  typedef IgemmEpilogue<IgemmEpilogueTraits<GemmConfig, EpilogueFunctor_> > Epilogue;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ScalarD_>
struct IgemmEpilogueScalar {
  typedef float Scalar;
};

template <>
struct IgemmEpilogueScalar<int> {
  typedef int Scalar;
};

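// Illustrative note (not in the original source): IgemmEpilogueScalar picks the type of the
// epilogue's alpha/beta scalars from the output type. For example,
// IgemmEpilogueScalar<int8_t>::Scalar and IgemmEpilogueScalar<float>::Scalar are both float,
// while IgemmEpilogueScalar<int>::Scalar is int, matching the 32-bit integer accumulators.
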
////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_ = Shape<32, 128, 128>,
    /// The output type.
    typename ScalarD_ = int,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_ = LinearScaling<typename IgemmEpilogueScalar<ScalarD_>::Scalar>,
    /// Tile size for thread-level GEMM (K-by-N-by-M)
    typename ThreadGemmShape_ = Shape<32, 8, 8>,
    /// The index.
    typename Index_ = int,
    /// The helper class.
    typename Helper_ = IgemmTraitsHelper<kLayoutA_,
                                         kLayoutB_,
                                         OutputTile_,
                                         ScalarD_,
                                         EpilogueFunctor_,
                                         ThreadGemmShape_,
                                         Index_> >
struct IgemmTraits : public GemmTraits<
                         // The config.
                         typename Helper_::GemmConfig,
                         // The stream to load A from global memory to shared memory.
                         typename Helper_::GlobalLoadStreamA,
                         // The stream to load B from global memory to shared memory.
                         typename Helper_::GlobalLoadStreamB,
                         // The stream to load A from shared memory.
                         typename Helper_::SharedLoadStreamA,
                         // The stream to load B from shared memory.
                         typename Helper_::SharedLoadStreamB,
                         // The epilogue.
                         typename Helper_::Epilogue,
                         // The block swizzle to reorganize the grid.
                         IdentityBlockSwizzle,
                         // The index.
                         Index_,
                         // The tool used to clear accumulators.
                         typename Helper_::ClearAccumulators> {};

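// Illustrative sketch (not part of the original source): a concrete instantiation of the int8
// GEMM traits; the alias and its use with the device-level Gemm template are examples only.
//
//   typedef cutlass::gemm::IgemmTraits<
//       cutlass::MatrixLayout::kColumnMajor,  // layout of A
//       cutlass::MatrixLayout::kRowMajor,     // layout of B
//       cutlass::Shape<32, 128, 128>,         // threadblock tile (K, N, M)
//       int>                                  // output type D
//       ExampleIgemmTraits;
//
//   // The traits would then be plugged into the GEMM kernel wrapper, e.g.
//   // cutlass::gemm::Gemm<ExampleIgemmTraits>, whose Params carry the problem size,
//   // alpha/beta, and the A/B/C/D pointers.
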
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -1,169 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once

#include "cutlass/fragment_multiply_add.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
CUTLASS_DEVICE bool is_zero(T x) {
  return x == T(0);
}

#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor to compute linear combination of fragments
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
struct LinearScaling {
  // The scalar.
  typedef Scalar_ Scalar;
  // The accumulator type.
  typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum;
  // The adapter.
  typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;

  /// The parameters.
  struct Params {
    /// The alpha/beta scaling params.
    Scalar alpha, beta;

    //
    // Methods
    //

    // Constructor
    CUTLASS_HOST_DEVICE
    Params(Scalar _alpha = 0.0f, Scalar _beta = 0.0f)
        : alpha(_alpha), beta(_beta) {}

    /// Initialize the parameters
    CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) {
      alpha = _alpha;
      beta = _beta;
      return 0;
    }

    /// Initialize the parameters.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      alpha = desc.alpha;
      beta = desc.beta;
      return 0;
    }
  };

  //
  // Data members
  //

  Params params;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE LinearScaling() {}

  /// Ctor.
  CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {}

  /// Method to determine whether the source accumulator matrix C is ever needed. This method
  /// may always safely return true, though better performance is possible if the source
  /// accumulator matrix is never loaded unnecessarily.
  CUTLASS_DEVICE
  bool source_required() const {
    return !is_zero(params.beta);
  }

  /// Evaluate the functor.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    mad.multiply(params.alpha, accum, output);
  }

  /// Evaluate the functor, without using fragments in the API.
  template <typename ScalarAccum, typename ScalarOutput, int size>
  CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output) {
    Fragment<ScalarAccum, size> FragAccum;
    Fragment<ScalarOutput, size> FragOutput;
#pragma unroll
    for (int i = 0; i < size; i++) {
      FragAccum[i] = accum[i];
      FragOutput[i] = output[i];
    }
    evaluate(FragAccum, FragOutput);
#pragma unroll
    for (int i = 0; i < size; i++) {
      output[i] = FragOutput[i];
    }
  }

  /// Evaluate the functor.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    FragmentB_ tmp;
    mad.multiply(params.beta, old, tmp);
    mad.multiply_add(params.alpha, accum, tmp, output);
  }

  /// Evaluate the functor, without using fragments in the API.
  template <typename ScalarAccum, typename ScalarOutput, int size>
  CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output) {
    Fragment<ScalarAccum, size> FragAccum;
    Fragment<ScalarOutput, size> FragOutput;
    Fragment<ScalarOutput, size> FragOld;
#pragma unroll
    for (int i = 0; i < size; i++) {
      FragAccum[i] = accum[i];
      FragOutput[i] = output[i];
      FragOld[i] = old[i];
    }
    evaluate(FragAccum, FragOld, FragOutput);
#pragma unroll
    for (int i = 0; i < size; i++) {
      output[i] = FragOutput[i];
    }
  }
};

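// Illustrative sketch (not part of the original source): the functor computes
// output = alpha * accum, plus beta * old when the source matrix is required. Typical host-side
// setup, with hypothetical values:
//
//   cutlass::gemm::LinearScaling<float>::Params params;
//   params.initialize(1.0f /*alpha*/, 0.5f /*beta*/);
//   // On the device, the epilogue constructs LinearScaling from these Params, queries
//   // source_required() (false only when beta == 0) and then calls one of the evaluate()
//   // overloads on each fragment.
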
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -1,150 +0,0 @@
/*! \file
    \brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/scalar_or_pointer.h"
#include "cutlass/gemm/linear_scaling.h"

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor to compute linear combination of fragments. This is intended to support passing scalars
/// either by value from the host or by reference to device-side scalar elements. This is inspired
/// by cuBLAS's device pointer mode.
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
struct LinearScalingDevicePtr : public LinearScaling<Scalar_, FragmentMultiplyAdd_> {

  /// Linear scaling class used
  typedef LinearScaling<Scalar_, FragmentMultiplyAdd_> Base;

  // The scalar.
  typedef typename Base::Scalar Scalar;

  /// The parameters.
  class Params {
   private:
    /// Alpha scalar
    detail::ScalarOrPointer<Scalar> alpha_;

    /// Beta scalar
    detail::ScalarOrPointer<Scalar> beta_;

   public:
    //
    // Methods
    //

    // Constructor
    CUTLASS_HOST_DEVICE
    Params() {}

    // Constructor
    CUTLASS_HOST_DEVICE
    Params(
      Scalar alpha,
      Scalar beta
    ):
      alpha_(alpha),
      beta_(beta) {}

    // Constructor
    CUTLASS_HOST_DEVICE
    Params(
      Scalar const *alpha_ptr,
      Scalar const *beta_ptr
    ):
      alpha_(alpha_ptr),
      beta_(beta_ptr) {}

    /// Initialize the parameters
    CUTLASS_HOST_DEVICE int initialize(
      Scalar alpha,
      Scalar beta) {

      alpha_ = alpha;
      beta_ = beta;

      return 0;
    }

    /// Initialize the parameters
    CUTLASS_HOST_DEVICE int initialize(
      Scalar const *alpha,
      Scalar const *beta) {

      alpha_ = alpha;
      beta_ = beta;

      return 0;
    }

    /// Initialize the parameters.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {

      alpha_ = desc.alpha;
      beta_ = desc.beta;

      return 0;
    }

    /// Gets the alpha scalar
    CUTLASS_HOST_DEVICE
    Scalar alpha() const {
      return alpha_;
    }

    /// Gets the beta scalar
    CUTLASS_HOST_DEVICE
    Scalar beta() const {
      return beta_;
    }
  };

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const& _params) {
    this->params.alpha = _params.alpha();
    this->params.beta = _params.beta();
  }
};

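// Illustrative sketch (not part of the original source): device pointer mode mirrors cuBLAS.
// With hypothetical device allocations d_alpha and d_beta each holding one float:
//
//   cutlass::gemm::LinearScalingDevicePtr<float>::Params params(d_alpha, d_beta);
//   // The scalars are then read through the pointers when the functor is constructed on the
//   // device, instead of being baked into the kernel parameters by value.
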
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -1,284 +0,0 @@
/*! \file
    \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
    with the computed matrix product.
*/

#pragma once

// clang-format off

#include "cutlass/coord.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename EpilogueTraits_>
struct MMAEpilogue {
  /// The traits class.
  typedef EpilogueTraits_ Traits;

  /// The params.
  typedef typename Traits::Params Params;

  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// Defines a tiling of the EpilogueTile over the entire threadblock GEMM tile
  typedef typename Traits::Iterations Iterations;

  /// The output tile.
  typedef typename Traits::OutputTile OutputTile;

  /// Accumulators to store in the epilogue
  typedef typename Traits::Accumulators Accumulators;

  /// A functor to copy a slice of accumulators for a given epilogue iteration
  typedef typename Traits::SelectAccumulators SelectAccumulators;

  /// The iterator to load the source matrix from global memory.
  typedef typename Traits::GlobalLoadStreamC GlobalLoadStreamC;

  /// The iterator to store the final GEMM computation to global memory.
  typedef typename Traits::GlobalStoreStreamD GlobalStoreStreamD;

  /// The stream to store the matrix product to shared memory
  typedef typename Traits::SharedStoreStreamD SharedStoreStreamD;

  /// The stream to load the matrix product from shared memory
  typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;

  /// The functor in charge of the math.
  typedef typename Traits::Functor Functor;

  /// The scalar type used by the epilogue functor.
  typedef typename Functor::Scalar Scalar;

  /// The scalar type of the source accumulator matrix.
  typedef typename Traits::ScalarC ScalarC;

  /// The scalar type of the destination accumulator matrix.
  typedef typename Traits::ScalarD ScalarD;

  /// The index type.
  typedef typename Traits::Index Index;

  /// Functor computing the offset from the threadblock origin per iteration of
  /// the epilogue.
  typedef typename Traits::GlobalOffset GlobalOffset;

  /// The functor mapping GEMM coordinates to the layout of the output in global memory.
  typedef typename Traits::GlobalDataLayout GlobalDataLayout;

  //
  // Data members
  //

  /// The params.
  Params const& params;

  /// The shared storage.
  SharedStorage& shared_storage;

  /// The dimensions of the GEMM.
  gemm::GemmCoord problem_size;

  /// Epilogue functor
  Functor functor;

  /// Functor to select a set of accumulators
  SelectAccumulators select_accumulators;

  /// Functor to compute the global offset relative to the threadblock for each iteration
  /// of the epilogue.
  GlobalOffset global_offset;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE MMAEpilogue(
    Params const& params_,
    SharedStorage& shared_storage_,
    Coord<3> const& _problem_size,
    SelectAccumulators _select_accumulators = SelectAccumulators(),
    GlobalOffset _global_offset = GlobalOffset()
  ):
    params(params_),
    shared_storage(shared_storage_),
    problem_size(_problem_size),
    functor(params_.functor),
    select_accumulators(_select_accumulators),
    global_offset(_global_offset) {}

  /// Execute the epilogue, dispatching on whether the source matrix C must be loaded.
  CUTLASS_DEVICE void epilogue(
    Accumulators& accumulators,
    Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
    int batch_id = 0) {

    if (functor.source_required()) {
      epilogue_with_or_without_beta<true>(accumulators, threadblock_offset, batch_id);
    }
    else {
      epilogue_with_or_without_beta<false>(accumulators, threadblock_offset, batch_id);
    }
  }

  /// Execute the epilogue.
  template <bool kSourceRequired>
  CUTLASS_DEVICE void epilogue_with_or_without_beta(
    Accumulators& accumulators,
    Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
    int batch_id = 0) {

    /// Global memory mapping function
    GlobalDataLayout gmem_map_func;

    // Construct shared memory streams
    SharedStoreStreamD shared_store_stream(
      params.shared_store_stream_d,
      shared_storage.reference());

    SharedLoadStreamD shared_load_stream(
      params.shared_load_stream_d,
      shared_storage.reference());

    // Map the GEMM problem dimensions into the coordinate system of the output memory
    Coord<2> gmem_bounds = gmem_map_func(make_Coord(
      problem_size.m(),    // GEMM M - rows
      problem_size.n()));  // GEMM N - columns

    Coord<3> gmem_tile_bounds = make_Coord(
      problem_size.k(),  // GEMM K
      gmem_bounds[0],    // strided
      gmem_bounds[1]);   // contiguous

    // Iterate over the entire threadblock tile
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {

        // Offset in GEMM coordinates
        gemm::GemmCoord offset_in_gemm = threadblock_offset + global_offset(make_Coord(h, w));

        Coord<2> offset_in_memory = gmem_map_func(
          make_Coord(
            offset_in_gemm.m(),    // GEMM M - rows
            offset_in_gemm.n()));  // GEMM N - columns

        // Offset of the tile in the output memory's coordinate system.
        Coord<3> global_tile_offset = make_Coord(
          offset_in_gemm.k(),    // GEMM K
          offset_in_memory[0],   // strided
          offset_in_memory[1]);  // contiguous

        GlobalLoadStreamC global_load_stream(
          params.load_stream_c,
          gmem_tile_bounds,
          global_tile_offset);

        GlobalStoreStreamD global_store_stream(
          params.store_stream_d,
          gmem_tile_bounds,
          global_tile_offset);

        // Update the C pointer offset based on batch_id and the batch stride.
        global_load_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_C);

        // Update the D pointer offset based on batch_id and the batch stride.
        global_store_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_D);

        // Load the C matrix into a fragment.
        if (kSourceRequired) {
          global_load_stream.copy();
        }

        // Make sure we can write to shared memory.
        shared_load_fence();

        // Store the accumulator tile to shared memory.
        shared_store_stream.copy(
          select_accumulators(accumulators, make_Coord(h, w)));

        shared_store_stream.commit();

        // Make sure the data is in shared memory.
        shared_store_fence();

        // Load the accumulators back to registers from shared memory.
        shared_load_stream.copy();
        shared_load_stream.commit();

        // Commit the C matrix fragment.
        if (kSourceRequired) {
          global_load_stream.commit();
        }

        // Apply the epilogue functor.
        if (kSourceRequired) {
          functor.evaluate(shared_load_stream.fragment(),
                           global_load_stream.fragment(),
                           global_store_stream.fragment());
        }
        else {
          functor.evaluate(
            shared_load_stream.fragment(),
            global_store_stream.fragment());
        }

        global_store_stream.copy();
        global_store_stream.commit();
      }
    }
  }

  /// The memory fence for shared loads.
  CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }

  /// The memory fence for shared stores.
  CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
};

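// Illustrative sketch (not part of the original source): inside a kernel, the epilogue would be
// driven roughly as follows; the names of the params/storage objects are hypothetical.
//
//   __shared__ typename MMAEpilogue<SomeEpilogueTraits>::SharedStorage epilogue_storage;
//   MMAEpilogue<SomeEpilogueTraits> epilogue(params.epilogue, epilogue_storage, problem_size);
//   epilogue.epilogue(accumulators, threadblock_offset, batch_id);
//
// The dispatch in epilogue() compiles two variants of the inner loop: one that loads and blends
// the source matrix C, and one that skips those loads entirely when beta == 0.
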
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

// clang-format on
@@ -1,360 +0,0 @@
/*! \file
    \brief Implements efficient loading of the thread block-level tile from global memory and
    storing to shared memory.
*/

#pragma once

// clang-format off

#include "cutlass/convert.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tile_allocation.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
template <
  /// Identifies multiplicand
  GemmOperand::Kind Operand,
  /// Layout of source matrix in global memory
  MatrixLayout::Kind Layout,
  /// Iterator for loading threadblock-scoped tiles
  typename LoadIterator_,
  /// Transformation functor for transforming fragments
  typename Transformer_,
  /// Iterator for storing threadblock-scoped tiles to shared memory
  typename StoreIterator_,
  /// Number of stores before the iterator wraps - zero indicates no wrapping
  int StageCount>
struct MMAGlobalLoadStream {
  //
  // Type definitions
  //

  /// Identifies the operand
  static GemmOperand::Kind const kOperand = Operand;
  /// The layout.
  static MatrixLayout::Kind const kLayout = Layout;
  /// The load iterator.
  typedef LoadIterator_ LoadIterator;
  /// The transformer.
  typedef Transformer_ Transformer;
  /// The store iterator to write to shared memory.
  typedef StoreIterator_ StoreIterator;
  /// Number of stages
  static int const kStageCount = StageCount;

  /// Predicate vector
  typedef typename LoadIterator::PredicateVector PredicateVector;
  /// The fragment that is copied from global memory.
  typedef typename LoadIterator::Fragment FetchedFragment;
  /// The fragment that is obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;
  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "The fetched fragment must match the transformer's input fragment.");
  /// The output fragment.
  typedef TransformedFragment Fragment;
  /// Make sure the transformed fragment is the same as the store fragment.
  static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
                "The transformed fragment must match the store iterator's fragment.");

  /// The scalar type of the iterator.
  typedef typename LoadIterator::Scalar Scalar;
  /// The pointer.
  typedef typename LoadIterator::Pointer Pointer;
  /// The index.
  typedef typename LoadIterator::Index Index;
  /// The long index.
  typedef typename LoadIterator::LongIndex LongIndex;
  /// The tile.
  typedef typename LoadIterator::Tile Tile;

  /// The params.
  struct Params {

    /// Helper
    static int const kElementsPerLdg = LoadIterator::Tile::kC;

    //
    // Data members
    //

    /// The load iterator.
    typename LoadIterator::Params load_iterator;

    /// Stride within a batch of matrix operands
    LongIndex batch_stride;

    // Offset to the residue.
    Index offset_to_residue;

    // Offset to the residue for the last partition.
    Index offset_to_residue_last_partition;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params(): batch_stride(0), offset_to_residue(0), offset_to_residue_last_partition(0) {}

    /// Constructor
    CUTLASS_HOST_DEVICE
    Params(
      TensorRef<half const, 2> const &ref,
      Index _offset_to_residue
    ):
      batch_stride(0),
      offset_to_residue(_offset_to_residue),
      offset_to_residue_last_partition(0),
      load_iterator(
        TensorRef<half const, 4>(
          ref.data(),
          make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
        )
      ) {}

    /// Initializer
    CUTLASS_HOST_DEVICE
    int initialize(
      TensorRef<half const, 2> const &ref,
      LongIndex batch_stride_,
      Index offset_to_residue_,
      Index offset_to_residue_last_partition_) {

      batch_stride = batch_stride_;
      offset_to_residue = offset_to_residue_;
      offset_to_residue_last_partition = offset_to_residue_last_partition_;

      return load_iterator.initialize(
        TensorRef<half const, 4>(
          ref.data(),
          make_Coord(static_cast<int>(batch_stride), ref.stride(0), kElementsPerLdg, 1)
        )
      );
    }

    CUTLASS_HOST_DEVICE
    int initialize(
      TensorRef<half const, 2> const &ref,
      Index offset_to_residue_) {

      offset_to_residue = offset_to_residue_;
      return load_iterator.initialize(
        TensorRef<half const, 4>(
          ref.data(),
          make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
        )
      );
    }

    CUTLASS_HOST_DEVICE int initialize(Index offset_to_residue_) {
      offset_to_residue = offset_to_residue_;
      return 0;
    }

    CUTLASS_DEVICE Index get_offset_to_residue() {
      if (blockIdx.z == gridDim.z - 1) {  // last partition
        return offset_to_residue_last_partition;
      }
      else {
        return offset_to_residue;
      }
    }
  };

  /// Empty shared storage
  struct SharedStorage {};

  /// Shared memory allocation for the tile
  typedef TileAllocation<
    typename StoreIterator::Scalar,
    typename ShapeMul<
      typename StoreIterator::OperandShape,
      Shape<kStageCount, 1, 1, 1>
    >::Shape
  > ThreadblockTileStorage;

  /// ZipTensorRef to threadblock tiles
  typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;

  //
  // Data members
  //

  ///! The parameters
  Params params;

  ///! Dimensions of the global memory tile
  Coord<3> threadblock_offset;

  ///! Dimensions of the multiplicand bounds
  Coord<3> multiplicand_bounds;

  ///! Iterator to load threadblock tiles from global memory
  LoadIterator load_iterator;

  ///! Predicate vector
  PredicateVector predicates;

  ///! The fragment fetched from global memory.
  FetchedFragment fetched_fragment;

  ///! Functor to transform fragments after they have been loaded
  Transformer transformer;

  ///! The fragment holding the converted data after it has been fetched from global memory.
  TransformedFragment transformed_fragment;

  ///! Iterator to store threadblock tiles to shared memory
  StoreIterator store_iterator;

  ///! Counter
  int stage_index;

  //
  // Static member functions
  //

  /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
  CUTLASS_HOST_DEVICE
  static Coord<3> project_coordinate(Coord<3> const &coord, Index d_offset = 0) {
    bool const kKstrided =
        gemm::GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;

    Coord<3> tile_coord = gemm::ProjectOperand<kOperand, kKstrided>::project(coord);

    return make_Coord(
        tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
  }

  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE MMAGlobalLoadStream(Params const &_params,
                                     SharedStorage &shared_storage,
                                     ThreadblockTileRef const &threadblock_tile_ref,
                                     Coord<3> const bounds,
                                     Coord<3> const &block)
      : params(_params),
        threadblock_offset(project_coordinate(block)),
        multiplicand_bounds(project_coordinate(bounds, 1)),
        load_iterator(params.load_iterator, threadblock_offset),
        transformer(),
        store_iterator(threadblock_tile_ref.data()),
        stage_index(0) {
    load_iterator.initialize_predicates(
        predicates.begin(), multiplicand_bounds, threadblock_offset);
  }

  /// Loads the data from global memory
  CUTLASS_DEVICE void copy() {
    load_iterator.load_post_increment(fetched_fragment, predicates.begin());
  }

  /// Transform and commit the data to shared memory
  CUTLASS_DEVICE void commit() {
    transformer.transform(fetched_fragment, transformed_fragment);
    store_iterator.store_post_increment(transformed_fragment);

    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      store_iterator -= kStageCount;
      stage_index = 0;
    }
  }

  /// Computes a predicate mask for loads during the final threadblock tile load iteration
  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
    // That's the residue!
    Coord<3> _block_offset = threadblock_offset;
    if ((kOperand == GemmOperand::kA) ^ (kLayout == MatrixLayout::kRowMajor)) {
      // K-strided
      _block_offset =
          make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
    } else {
      // K-contiguous
      _block_offset = make_Coord(threadblock_offset[0],
                                 threadblock_offset[1],
                                 multiplicand_bounds[2] - k / LoadIterator::Tile::kC);
    }

    load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
    fetched_fragment.clear();
  }

  /// Move to the residue portion.
  CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
    Index kResidue = k % kTileK;
    if (kResidue) {
      residue(kResidue);
      Index this_offset_residue = params.get_offset_to_residue();
      load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
    }
  }

  /// Rollback to the beginning of the first tile
  CUTLASS_DEVICE void rollback(void) {
    load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, threadblock_offset);

    int const kBlock = kOperand == GemmOperand::kA
                           ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
                           : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
    Index this_offset_residue = params.get_offset_to_residue();
    load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
                                     load_iterator.stride_advance());
  }

  /// Adds a Coord<3> to the underlying global load iterator
  CUTLASS_DEVICE MMAGlobalLoadStream &operator+=(Coord<3> const &offset) {
    load_iterator += offset;
    return *this;
  }

  /// Adds an offset based on batch stride
  CUTLASS_DEVICE MMAGlobalLoadStream &add_batch_offset(int batch_id) {
    load_iterator.add_pointer_offset(batch_id * params.batch_stride);
    return *this;
  }
};

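// Illustrative sketch (not part of the original source): in the threadblock mainloop the stream
// is used as a two-phase pipeline; the names below are hypothetical.
//
//   load_stream.copy();    // issue LDGs for the next K-slice of this operand
//   // ... compute on the tile already resident in shared memory ...
//   load_stream.commit();  // transform the fetched fragment and issue the STS
//
// move_to_residue() and rollback() adjust the predicates and pointer offset so the final,
// partial K-slice is handled without reading out of bounds.
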
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

// clang-format on
@@ -1,201 +0,0 @@
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/

#pragma once

#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Iterators used to load multiplicands from global memory specialized for Volta884 access patterns
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterator for loading data for congruous access patterns
template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
struct MMAThreadblockCongruousLoad {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = Operand;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout =
      (Operand == GemmOperand::kA ? MatrixLayout::kColumnMajor : MatrixLayout::kRowMajor);

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  static int const kWarpDelta = WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Projects the threadblock tile
  typedef typename gemm::GemmMultiplicandTraits<Tile_, Operand, kLayout>::Shape OperandShape;

  /// Reshapes the threadblock tile by access size
  typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;

  /// Shape of the tile covered by each warp per store operation
  typedef Shape<1, 4, 8> WarpStoreCoverage;

  /// Shape of tile loaded by each warp per load operation
  typedef Shape<1, 4, 8> WarpLoadShape;

  //
  // Load iterator
  //

  /// Delta between accesses
  typedef Shape<1, WarpLoadShape::kH * kWarpCount, WarpLoadShape::kW> Delta;

  typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;

  /// Rakes warps along contiguous dimensions and strip-mines the strided
  /// dimension.
  typedef Shape<1,
                VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
                VectorizedShape::kW / WarpStoreCoverage::kW,
                1>
      Iterations;

  /// Functor computing the starting offset for each thread
  struct ThreadOffset {
    __device__ Coord<4> operator()() const {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);

      int lane_k = lane_id / WarpLoadShape::kW;
      int lane_outer = lane_id % WarpLoadShape::kW;

      Coord<4> offset = make_Coord(0, warp_id * WarpLoadShape::kH + lane_k, lane_outer, 0);

      return offset;
    }
  };

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> LoadTileTraits;

  /// Load iterator
  typedef TileLoadIterator<LoadTileTraits, half, IteratorAdvance::kH> Iterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Iterator for loading data for crosswise access patterns
|
||||
template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
|
||||
struct MMAThreadblockCrosswiseLoad {
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = Operand;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout =
|
||||
(Operand == GemmOperand::kA ? MatrixLayout::kRowMajor : MatrixLayout::kColumnMajor);
|
||||
|
||||
/// Shape of thread-block multiplicand
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp accumulator tiles along the outer dimension
|
||||
static int const kWarpDelta = WarpDelta;
|
||||
|
||||
/// This implementation is specialized for 128b loads
|
||||
static int const kAccessSize = 8;
|
||||
|
||||
/// Projects the threadblock tile
|
||||
typedef typename gemm::GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
|
||||
|
||||
/// Reshapes the threadblock tile by access size
|
||||
typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;
|
||||
|
||||
/// Shape of the tile covered by a warp in a single store operation
|
||||
typedef Shape<1, 8, 4> WarpStoreCoverage;
|
||||
|
||||
/// Shape of tile loaded by each warp per load operation
|
||||
typedef Shape<1, 8, 4> WarpLoadShape;
|
||||
|
||||
//
|
||||
// Load iterator
|
||||
//
|
||||
|
||||
///
|
||||
typedef Shape<1, WarpLoadShape::kH, WarpLoadShape::kW> Delta;
|
||||
|
||||
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
|
||||
|
||||
/// Rakes warps along contiguous dimensions and strip-mines strided
|
||||
/// dimension.
|
||||
typedef Shape<1,
|
||||
VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
|
||||
VectorizedShape::kW / WarpStoreCoverage::kW,
|
||||
1>
|
||||
Iterations;
|
||||
|
||||
/// Functor computing starting offset for each thread
|
||||
struct ThreadOffset {
|
||||
__device__ Coord<4> operator()() const {
|
||||
|
||||
int warp_id = (threadIdx.x >> 5);
|
||||
int lane_id = (threadIdx.x & 0x1f);
|
||||
|
||||
int lane_k = lane_id % WarpLoadShape::kW;
|
||||
int lane_outer = lane_id / WarpLoadShape::kW;
|
||||
|
||||
Coord<4> offset =
|
||||
make_Coord(0, warp_id * Iterations::kH * WarpLoadShape::kH + lane_outer, lane_k, 0);
|
||||
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
/// Source tile traits
|
||||
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> LoadTileTraits;
|
||||
|
||||
/// Load iterator
|
||||
typedef TileLoadIterator<LoadTileTraits, half, IteratorAdvance::kW> Iterator;
|
||||
};
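The two loaders above differ only in which operand dimension is contiguous in global memory: the congruous loader assumes the contraction (K) dimension is strided, while the crosswise loader assumes it is contiguous. As a hedged illustration (the tile shape and warp parameters below are placeholders, not values taken from this header), a kernel-level traits class might bind them to operand A as follows:

// Illustrative sketch only -- not part of the original header.
// Column-major A keeps K strided, so the congruous loader applies;
// row-major A keeps K contiguous, so the crosswise loader applies.
typedef MMAThreadblockCongruousLoad<GemmOperand::kA, Shape<32, 64, 64>, 4, 1> LoadA_ColumnMajor;
typedef MMAThreadblockCrosswiseLoad<GemmOperand::kA, Shape<32, 64, 64>, 4, 1> LoadA_RowMajor;

// Each traits class exposes the global memory load iterator consumed by the kernel.
typedef LoadA_ColumnMajor::Iterator GlobalLoadIteratorA;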
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,155 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements streaming of multiplicand fragments from shared memory into register fragments
      for warp-level matrix multiply-accumulate operations.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename Iterator_,
|
||||
/// The transformer to be applied after the data has been copied from shared memory.
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment>,
|
||||
/// Number of increments before iterator wraps - zero indicates no wrapping
|
||||
int StageCount = 1>
|
||||
struct MMASharedLoadStream {
|
||||
/// The load iterator.
|
||||
typedef Iterator_ Iterator;
|
||||
/// The transformer.
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// Number of increments before iterator wraps - zero indicates no wrapping
|
||||
static int const kStageCount = StageCount;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename Iterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
|
||||
/// Element type
|
||||
typedef typename Iterator::Scalar Scalar;
|
||||
|
||||
/// Reference type to a tensor
|
||||
typedef TensorRef<half, 4> TensorRef;
|
||||
|
||||
/// Parameters passed from host
|
||||
struct Params {};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator for loading fragments for warp-level matrix multiply-accumulate
|
||||
Iterator iterator;
|
||||
|
||||
/// Fetched fragment
|
||||
FetchedFragment fetched[2];
|
||||
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
|
||||
/// Transformed fragment
|
||||
TransformedFragment transformed[2];
|
||||
|
||||
/// Counts the number of stages
|
||||
int stage_index;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE MMASharedLoadStream() : stage_index(0) {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE MMASharedLoadStream(
|
||||
Params const &_params,
|
||||
TensorRef const &ref,
|
||||
Coord<4> warp_offset = make_Coord(0, 0, 0, 0)
|
||||
):
|
||||
iterator(ref.data(), warp_offset), stage_index(0) {
|
||||
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
iterator.load(
|
||||
fetched[step % 2],
|
||||
make_Coord(step + stage_index * Iterator::VectorizedShape::kD, 0, 0, 0)
|
||||
);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
transformer.transform(fetched[step % 2], transformed[step % 2]);
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE void clear() {
|
||||
fetched[0].clear();
|
||||
fetched[1].clear();
|
||||
transformed[0].clear();
|
||||
transformed[1].clear();
|
||||
}
|
||||
|
||||
/// Gets the transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
TransformedFragment &fragment(int step) { return transformed[step % 2]; }
|
||||
|
||||
/// Gets the transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
|
||||
++stage_index;
|
||||
if (kStageCount && stage_index == StageCount) {
|
||||
stage_index = 0;
|
||||
}
|
||||
}
|
||||
};
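MMASharedLoadStream implements a two-slot software pipeline: copy() issues the shared memory load for one step while commit() transforms the fragment fetched for the preceding step. A minimal sketch of a main loop driving it follows; the Stream, WarpMma, Accumulators and kSteps names are placeholders for this illustration, not CUTLASS symbols.

// Illustrative sketch only -- not part of the original header.
template <typename Stream, typename WarpMma, typename Accumulators, int kSteps>
CUTLASS_DEVICE void consume_shared_stream(Stream &stream, WarpMma warp_mma, Accumulators &accum) {
  stream.copy(0);                       // prefetch the fragment for step 0
  CUTLASS_PRAGMA_UNROLL
  for (int step = 0; step < kSteps; ++step) {
    stream.commit(step);                // transform the fragment fetched for this step
    if (step + 1 < kSteps) {
      stream.copy(step + 1);            // overlap the next shared memory load with the math
    }
    warp_mma(stream.fragment(step), accum);
  }
  stream.inc_stage();                   // advance to the next shared memory stage
}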
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,129 +0,0 @@
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a helper that treats an epilogue scalar as either a host-side value or a pointer
      to a device-side scalar, in support of the BLAS linear scaling function alpha*AB + beta*C
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Helper class defines an object which operates as either a scalar or a pointer. If the pointer
|
||||
/// is non-null, it is dereferenced when the object is accessed.
|
||||
template <typename Scalar_>
|
||||
class ScalarOrPointer {
|
||||
public:
|
||||
/// Underlying scalar type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Scalar value
|
||||
Scalar scalar;
|
||||
|
||||
/// Pointer to use if non null
|
||||
Scalar const *ptr;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer(): scalar(0), ptr(nullptr) {}
|
||||
|
||||
/// Object behaves as a scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer(Scalar const &val): scalar(val), ptr(nullptr) {}
|
||||
|
||||
/// Object behaves as a scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer(Scalar const *ptr_): scalar(0), ptr(ptr_) {}
|
||||
|
||||
/// Returns true if is pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_pointer() const {
|
||||
return bool(ptr);
|
||||
}
|
||||
|
||||
/// Gets the pointer value
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const *get_ptr() const {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
/// Gets the scalar value
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar get_scalar() const {
|
||||
return scalar;
|
||||
}
|
||||
|
||||
/// Assigns to a scalar and sets pointer to nullptr
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer &operator=(Scalar const &scalar_) {
|
||||
scalar = scalar_;
|
||||
ptr = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Assigns to a pointer value
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer &operator=(Scalar const *ptr_) {
|
||||
ptr = ptr_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Access the element
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar get() const {
|
||||
if (ptr) {
|
||||
return *ptr;
|
||||
}
|
||||
return scalar;
|
||||
}
|
||||
|
||||
/// Accesses the element
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator Scalar() const {
|
||||
return get();
|
||||
}
|
||||
};
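A brief usage sketch of the helper, assuming a float instantiation; it shows the two modes the class supports, mirroring cuBLAS's host and device pointer modes. The pointer below is a placeholder for a device-resident scalar.

// Illustrative sketch only -- not part of the original header.
ScalarOrPointer<float> alpha(2.0f);     // behaves as a plain scalar
float a = alpha;                        // implicit conversion calls get() and yields 2.0f

float const *beta_ptr = nullptr;        // placeholder pointer to a device-side scalar
ScalarOrPointer<float> beta(beta_ptr);  // dereferences the pointer once it is non-null;
float b = beta;                         // with a null pointer, get() falls back to the scalar (0)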
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@@ -1,172 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of single-precision GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// Whether to specify launch bounds
|
||||
bool kLaunchBounds = true>
|
||||
struct SgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
float,
|
||||
/// The scalar type for B.
|
||||
float,
|
||||
/// The scalar type for C.
|
||||
float,
|
||||
/// The scalar type for D.
|
||||
float,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, float, float, float>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
4,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
4,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
true,
|
||||
/// kLaunchBounds
|
||||
kLaunchBounds> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<float>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
SgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, false>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct SgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
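For reference, these traits are typically consumed by binding them to the device-level Gemm template declared in cutlass/gemm/gemm.h, following the pattern used by the CUTLASS 1.x examples. The problem sizes, scalars and device pointers in the commented launch sketch are placeholders.

// Illustrative sketch only -- not part of the original header.
typedef SgemmTraits<MatrixLayout::kColumnMajor,
                    MatrixLayout::kColumnMajor,
                    Shape<8, 128, 128> > ColumnMajorSgemmTraits;
typedef Gemm<ColumnMajorSgemmTraits> Sgemm;

// Host-side launch sketch:
//   Sgemm::Params params;
//   params.initialize(m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, d_C, ldc);
//   Sgemm::launch(params);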
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to define SGEMM traits using Launch Bounds
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<float>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
SgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, true>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct SgemmLBTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,257 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
|
||||
with the computed matrix product.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/zip_fragment.h"
|
||||
#include "cutlass/zip_tile_iterator.h"
|
||||
#include "cutlass/util/complex.h"
|
||||
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/scalar_or_pointer.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter for linear scaling
|
||||
template <typename Scalar_>
|
||||
struct SplitComplexLinearScaling {
|
||||
|
||||
/// Underlying real-valued scalar type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Complex data type
|
||||
typedef platform::complex<Scalar> Complex;
|
||||
|
||||
/// Parameters
|
||||
struct Params {
|
||||
|
||||
/// Alpha
|
||||
Complex alpha;
|
||||
|
||||
/// Beta
|
||||
Complex beta;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): alpha(0, 0), beta(0, 0) {}
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Complex const & _alpha, Complex const & _beta) : alpha(_alpha), beta(_beta) {}
|
||||
|
||||
/// Initialize the parameters
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Complex const & _alpha, Complex const & _beta) {
|
||||
alpha = _alpha;
|
||||
beta = _beta;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
SplitComplexLinearScaling() { }
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
SplitComplexLinearScaling(Params const& _params) : params(_params) {}
|
||||
|
||||
/// Method to determine whether the source accumulator matrix C is ever needed.
|
||||
CUTLASS_DEVICE
|
||||
bool source_required() const {
|
||||
return !is_zero(params.beta.real()) || !is_zero(params.beta.imag());
|
||||
}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename FragmentA, typename FragmentC>
|
||||
CUTLASS_HOST_DEVICE
|
||||
void evaluate(FragmentA const& accum, FragmentC & output) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentA::First::kElements; ++i) {
|
||||
|
||||
// Zip together split-complex accumulator representation for complex arithmetic
|
||||
Complex result = params.alpha * Complex(accum.first[i], accum.second[i]);
|
||||
|
||||
output.first[i] = result.real();
|
||||
output.second[i] = result.imag();
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename FragmentA, typename FragmentC>
|
||||
CUTLASS_HOST_DEVICE
|
||||
void evaluate(FragmentA const& accum, FragmentC const& old, FragmentC& output) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentA::First::kElements; ++i) {
|
||||
|
||||
// Zip together split-complex representations for complex arithmetic
|
||||
Complex source(old.first[i], old.second[i]);
|
||||
|
||||
Complex result = params.alpha * Complex(accum.first[i], accum.second[i]) + params.beta * source;
|
||||
|
||||
output.first[i] = result.real();
|
||||
output.second[i] = result.imag();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor to compute linear combination of fragments. This is intended to support passing scalars
|
||||
/// either by value from the host or by reference to device-side scalar elements. This is inspired
|
||||
/// by cuBLAS's device pointer mode.
|
||||
template <typename Scalar_ >
|
||||
struct SplitComplexLinearScalingDevicePtr : public SplitComplexLinearScaling<Scalar_> {
|
||||
|
||||
/// Linear Scaling class used
|
||||
typedef SplitComplexLinearScaling<Scalar_> Base;
|
||||
|
||||
/// Underlying real-valued scalar type
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// Complex data type
|
||||
typedef platform::complex<Scalar> Complex;
|
||||
|
||||
/// The parameters.
|
||||
class Params {
|
||||
private:
|
||||
/// Alpha scalar
|
||||
detail::ScalarOrPointer<Complex> alpha_;
|
||||
|
||||
/// Beta scalar
|
||||
detail::ScalarOrPointer<Complex> beta_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Complex alpha,
|
||||
Complex beta
|
||||
):
|
||||
alpha_(alpha),
|
||||
beta_(beta) {}
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Complex const *alpha_ptr,
|
||||
Complex const *beta_ptr
|
||||
):
|
||||
alpha_(alpha_ptr),
|
||||
beta_(beta_ptr) {}
|
||||
|
||||
/// Initialize the parameters
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Complex alpha,
|
||||
Complex beta) {
|
||||
|
||||
alpha_ = alpha;
|
||||
beta_ = beta;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initialize the parameters
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Complex const *alpha,
|
||||
Complex const *beta) {
|
||||
|
||||
alpha_ = alpha;
|
||||
beta_= beta;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
|
||||
alpha_ = desc.alpha;
|
||||
beta_ = desc.beta;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Gets the alpha scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
Complex alpha() const {
|
||||
return alpha_;
|
||||
}
|
||||
|
||||
/// Gets the beta scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
Complex beta() const {
|
||||
return beta_;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE SplitComplexLinearScalingDevicePtr(Params const& _params) {
|
||||
this->params.alpha = _params.alpha();
|
||||
this->params.beta = _params.beta();
|
||||
}
|
||||
};
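A short sketch of the two parameter modes the device-pointer variant supports. The complex constructor taking real and imaginary parts and the device pointers below are assumptions for this illustration.

// Illustrative sketch only -- not part of the original header.
typedef platform::complex<float> ComplexF;

// Value mode: alpha and beta are captured on the host when Params is constructed.
SplitComplexLinearScalingDevicePtr<float>::Params params_by_value(
    ComplexF(1.0f, 0.0f), ComplexF(0.0f, 0.0f));

// Pointer mode: alpha and beta are read through device pointers inside the kernel, so a
// preceding kernel can produce them without a round trip to the host.
ComplexF const *d_alpha = nullptr;  // placeholder device pointers
ComplexF const *d_beta = nullptr;
SplitComplexLinearScalingDevicePtr<float>::Params params_by_pointer(d_alpha, d_beta);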
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,107 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template implementing matrix multiply-add operations on fragments.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename ThreadGemmShape_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename ScalarA_,
|
||||
typename ScalarB_,
|
||||
typename ScalarC_,
|
||||
MatrixLayout::Kind kLayout_ = MatrixLayout::kColumnMajor>
|
||||
struct ThreadMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The shape of a thread-level matrix multiply-accumulate.
|
||||
typedef ThreadGemmShape_ ThreadGemmShape;
|
||||
/// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
|
||||
typedef ThreadGemmShape AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
|
||||
if(kLayout_ == MatrixLayout::kColumnMajor) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
|
||||
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
|
||||
d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
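The outer-product structure is easiest to see with concrete sizes. A hedged example instantiation, matching the Shape<1, 4, 8> warp arrangement and Shape<8, 8, 8> thread tile used as defaults by SgemmConfig and SgemmTraits earlier in this listing:

// Illustrative sketch only -- not part of the original header.
typedef ThreadMultiplyAdd<Shape<8, 8, 8>, Shape<1, 4, 8>, float, float, float> SgemmThreadMma;

// With ThreadGemmShape = Shape<8, 8, 8> (K-by-N-by-M per thread), FragmentA holds 8 floats (kW),
// FragmentB holds 8 floats (kH), and Accumulators holds 8 x 8 = 64 floats. Each multiply_add()
// call therefore performs the 64 fused multiply-adds of one per-thread rank-1 update d = a*b + c.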
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,447 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines functors for mapping blockIdx to partitions of the GEMM computation.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
struct swizzleDirection {
|
||||
enum Kind { Boustrophedon, OneDirection };
|
||||
};
|
||||
// helper template function
|
||||
template <enum swizzleDirection::Kind>
|
||||
CUTLASS_DEVICE int getLinearIdx(int groups) {
|
||||
// groupCols is not needed for OneDirection Swizzle
|
||||
return blockIdx.y * gridDim.x + blockIdx.x;
|
||||
}
|
||||
template <>
|
||||
CUTLASS_DEVICE int getLinearIdx<swizzleDirection::Boustrophedon>(int groups) {
|
||||
// reverse blockIdx.x for some columns
|
||||
if ((blockIdx.y / groups) % 2 == 1)
|
||||
return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
|
||||
else
|
||||
return blockIdx.y * gridDim.x + blockIdx.x;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup IdentityBlockSwizzle Identity Block Swizzle
|
||||
@{
|
||||
Block Swizzle provides the mapping logic between a block in the physical memory of Matrix C and
|
||||
Thread Block
|
||||
Identity Block Swizzle effectively maps blocks of matrix C in leading dimension order (column
   major) to thread blocks in leading dimension order (blockIdx.x)
|
||||
blockIdx.z is mapped with batch_count for batched GEMM
|
||||
@}
|
||||
*/
|
||||
struct IdentityBlockSwizzle {
|
||||
/// Ctor. aka ColumnMajorBlockSwizzle<1>
|
||||
CUTLASS_HOST_DEVICE IdentityBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
/*OutputTile and problem_size are both in KNM order*/
|
||||
dim3 grid;
|
||||
grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
|
||||
grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
|
||||
grid.z = problem_size.batch();
|
||||
return grid;
|
||||
}
|
||||
|
||||
/// Gets the threadblock offset, without considering the batch dim
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE int get_batch_id() {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
|
||||
/// check if at the last partition
|
||||
CUTLASS_DEVICE bool is_last_partition() {
|
||||
if (get_batch_id() == (gridDim.z - 1))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
|
||||
int partitionK_range) {
|
||||
// every partition except the last one has a smaller range
|
||||
// partitionK_range is the bounds for every partition except the last one
|
||||
// the last partition's bounds is the same with problem size
|
||||
if(is_last_partition())
|
||||
return problem_size.knm();
|
||||
else
|
||||
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
|
||||
}
|
||||
};
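The grid sizing above is a plain ceiling division of the problem size by the output tile, with the batch count mapped to blockIdx.z. A worked example with placeholder dimensions:

// Illustrative sketch only -- for a hypothetical m = 1024, n = 512, batch = 1 problem and an
// OutputTile of (K = 8, N = 128, M = 128) in KNM order, get_grid_layout() yields
//   grid.x = (1024 + 128 - 1) / 128 = 8   // threadblocks along M
//   grid.y = ( 512 + 128 - 1) / 128 = 4   // threadblocks along N
//   grid.z = 1                            // one z-slice per batched GEMM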
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
ColumnMajorBlockSwizzle<1, OneDirection> is equivalent to IdentityBlockSwizzle
|
||||
groupCols has the effect of controlling the scheduling of thread blocks
|
||||
settings with different groupCols can contribute to the overall performance by affecting L2 cache
|
||||
hit rate
|
||||
|
||||
consider a regular thread block mapping between matrix C and different thread blocks
|
||||
note that C is column major, and the leading dimension of thread block id is blockIdx.x
|
||||
|
||||
let's look at an example where gridDim.x = 6, gridDim.y = 7, gridDim.z = 1
|
||||
(blockIdx.x, blockIdx.y)
|
||||
mapping between threadblockID and C matrix:
|
||||
-------------------------------------------------------
|
||||
(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
A ColumnMajorBlockSwizzle<1, OneDirection> will imply the above order, where threadblocks are
launched in column-major order
|
||||
|
||||
A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little,
|
||||
-------------------------------------------------------
|
||||
(0,0) | (3,0) | (0,2) | (3,2) | (0,4) | (3,4) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(0,1) | (3,1) | (0,3) | (3,3) | (0,5) | (3,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (4,0) | (1,2) | (4,2) | (1,4) | (4,4) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(1,1) | (4,1) | (1,3) | (4,3) | (1,5) | (4,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (5,0) | (2,2) | (5,2) | (2,4) | (5,4) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(2,1) | (5,1) | (2,3) | (5,3) | (2,5) | (5,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
so in memory, it would appear that we work on 2 columns at a time rather than 1
|
||||
Note that the indices here really represent how each block maps to memory
|
||||
|
||||
A ColumnMajorBlockSwizzle<1, Boustrophedon> is similar to ColumnMajorBlockSwizzle<1, OneDirection>
|
||||
except that every column flips the ordering against the previous one
|
||||
-------------------------------------------------------
|
||||
(0,0) | (5,1) | (0,2) | (5,3) | (0,4) | (5,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (4,1) | (1,2) | (4,3) | (1,4) | (4,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (3,1) | (2,2) | (3,3) | (2,4) | (3,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(3,0) | (2,1) | (3,2) | (2,3) | (3,4) | (2,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(4,0) | (1,1) | (4,2) | (1,3) | (4,4) | (1,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(5,0) | (0,1) | (5,2) | (0,3) | (5,4) | (0,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
similarly, a ColumnMajorBlockSwizzle<2, Boustrophedon> looks like
|
||||
-------------------------------------------------------
|
||||
(0,0) | (3,0) | (2,3) | (5,3) | (0,4) | (3,4) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
(0,1) | (3,1) | (2,2) | (5,2) | (0,5) | (3,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (4,0) | (1,3) | (4,3) | (1,4) | (4,4) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(1,1) | (4,1) | (1,2) | (4,2) | (1,5) | (4,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (5,0) | (0,3) | (3,3) | (2,4) | (5,4) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,1) | (5,1) | (0,2) | (3,2) | (2,5) | (5,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
*/
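To check the tables above, the ColumnMajorBlockSwizzle index arithmetic can be replayed on the host with plain integers standing in for blockIdx and gridDim. The helper below is an illustration only, not part of CUTLASS, and covers the OneDirection case.

// Illustrative sketch only -- host-side mirror of ColumnMajorBlockSwizzle<groupCols, OneDirection>.
// bx/by stand in for blockIdx.x/y and gx/gy for gridDim.x/y.
inline void column_major_swizzle(int bx, int by, int gx, int gy, int groupCols,
                                 int &out_x, int &out_y) {
  int linearIdx = by * gx + bx;                         // OneDirection linear index
  int currGroupCols = groupCols;
  if ((gy % groupCols != 0) && (by + (gy % groupCols) >= gy)) {
    currGroupCols = gy % groupCols;                     // ragged last group of columns
  }
  out_x = (linearIdx / currGroupCols) % gx;
  out_y = linearIdx % currGroupCols + groupCols * (linearIdx / (groupCols * gx));
}
// Example: with gx = 6, gy = 7 and groupCols = 2, launched block (bx = 0, by = 1) maps to
// (out_x = 3, out_y = 0), the second entry of the first row in the
// ColumnMajorBlockSwizzle<2, OneDirection> table above.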
|
||||
|
||||
template <int groupCols, enum swizzleDirection::Kind swDirection>
|
||||
struct ColumnMajorBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() {
|
||||
assert(gridDim.z == 1);
|
||||
int linearIdx = getLinearIdx<swDirection>(groupCols);
|
||||
dim3 swizzledBlockIdx;
|
||||
int currGroupCols = groupCols;
|
||||
int prevGroupCols = groupCols;
|
||||
|
||||
if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
|
||||
// last columns if gridDim.y is not divisible by groupCols
|
||||
currGroupCols = gridDim.y % groupCols;
|
||||
}
|
||||
|
||||
swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
|
||||
swizzledBlockIdx.y =
|
||||
linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
|
||||
swizzledBlockIdx.z = blockIdx.z;
|
||||
|
||||
return swizzledBlockIdx;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
dim3 grid;
|
||||
grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
|
||||
grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
|
||||
grid.z = problem_size.batch();
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE int get_batch_id() {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
|
||||
/// check if at the last partition
|
||||
CUTLASS_DEVICE bool is_last_partition() {
|
||||
if (get_batch_id() == (gridDim.z - 1))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
|
||||
int partitionK_range) {
|
||||
// every partition except the last one has a smaller range
|
||||
// partitionK_range is the bounds for every partition except the last one
|
||||
// the last partition's bounds is the same with problem size
|
||||
if (is_last_partition())
|
||||
return problem_size.knm();
|
||||
else
|
||||
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
|
||||
consider a regular thread block mapping between matrix C and different thread blocks
|
||||
note that C is column major, and the leading dimension of thread block id is blockIdx.x
|
||||
|
||||
let's look at an example where gridDim.x = 6, gridDim.y = 7, gridDim.z = 1
|
||||
(blockIdx.x, blockIdx.y)
|
||||
mapping between threadblockID and C matrix:
|
||||
-------------------------------------------------------
|
||||
(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
A RowMajorBlockSwizzle<1, OneDirection> will effectively transpose the map
|
||||
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,0) | (2,0) | (3,0) | (4,0) | (5,0) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,1) | (2,1) | (3,1) | (4,1) | (5,1) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,2) | (2,2) | (3,2) | (4,2) | (5,2) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(0,4) | (1,4) | (2,4) | (3,4) | (4,4) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(0,5) | (1,5) | (2,5) | (3,5) | (4,5) | (5,5) |
|
||||
-----------------------------------------------
|
||||
(0,6) | (1,6) | (2,6) | (3,6) | (4,6) | (5,6) |
|
||||
-----------------------------------------------
|
||||
|
||||
It would appear in memory that we are working on 1 row at a time
|
||||
|
||||
A RowMajorBlockSwizzle<2, OneDirection> swizzles things a little bit more
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,3) | (2,0) | (3,3) | (4,0) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(1,0) | (0,4) | (3,0) | (2,4) | (5,0) | (4,4) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,4) | (2,1) | (3,4) | (4,1) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,5) | (2,2) | (3,5) | (4,2) | (5,5) |
|
||||
-----------------------------------------------
|
||||
(1,2) | (0,6) | (3,2) | (2,6) | (5,2) | (4,6) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,6) | (2,3) | (3,6) | (4,3) | (5,6) |
|
||||
-----------------------------------------------
|
||||
|
||||
so in memory, it would appear that we work on 2 rows at a time rather than 1 row
|
||||
Note that the indices here really represent how each block maps to memory
|
||||
|
||||
A RowMajorBlockSwizzle<1, Boustrophedon> is similar to RowMajorBlockSwizzle<1, OneDirection>
|
||||
except that every column flips the ordering against the previous one
|
||||
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,6) | (2,0) | (3,6) | (4,0) | (5,6) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,5) | (2,1) | (3,5) | (4,1) | (5,5) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,4) | (2,2) | (3,4) | (4,2) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(0,4) | (1,2) | (2,4) | (3,2) | (4,4) | (5,2) |
|
||||
-----------------------------------------------
|
||||
(0,5) | (1,1) | (2,5) | (3,1) | (4,5) | (5,1) |
|
||||
-----------------------------------------------
|
||||
(0,6) | (1,0) | (2,6) | (3,0) | (4,6) | (5,0) |
|
||||
-----------------------------------------------
|
||||
|
||||
similarly, a RowMajorBlockSwizzle<2, Boustrophedon> looks like
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,3) | (2,3) | (3,6) | (4,0) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(1,0) | (0,4) | (3,2) | (2,6) | (5,0) | (4,4) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,4) | (2,2) | (3,5) | (4,1) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,5) | (2,1) | (3,4) | (4,2) | (5,5) |
|
||||
-----------------------------------------------
|
||||
(1,2) | (0,6) | (3,0) | (2,4) | (5,2) | (4,6) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,6) | (2,0) | (3,3) | (4,3) | (5,6) |
|
||||
-----------------------------------------------
|
||||
|
||||
*/
|
||||
|
||||
template <int groupRows, enum swizzleDirection::Kind swDirection>
|
||||
struct RowMajorBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() {
|
||||
assert(gridDim.z == 1);
|
||||
int linearIdx = getLinearIdx<swDirection>(groupRows);
|
||||
dim3 swizzledBlockIdx;
|
||||
int currGroupRows = groupRows;
|
||||
int prevGroupRows = groupRows;
|
||||
|
||||
if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
|
||||
// last rows if gridDim.y is not divisible by groupRows
|
||||
currGroupRows = gridDim.y % groupRows;
|
||||
}
|
||||
|
||||
swizzledBlockIdx.x =
|
||||
linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x));
|
||||
swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
|
||||
swizzledBlockIdx.z = blockIdx.z;
|
||||
|
||||
return swizzledBlockIdx;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
dim3 grid;
|
||||
grid.x = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
|
||||
grid.y = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
|
||||
grid.z = problem_size.batch();
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE int get_batch_id() {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
|
||||
/// check if at the last partition
|
||||
CUTLASS_DEVICE bool is_last_partition() {
|
||||
if (get_batch_id() == (gridDim.z - 1) )
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
|
||||
int partitionK_range) {
|
||||
// every partition except the last one has a smaller range
|
||||
// partitionK_range is the bounds for every partition except the last one
|
||||
// the last partition's bounds is the same with problem size
|
||||
if (is_last_partition())
|
||||
return problem_size.knm();
|
||||
else
|
||||
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,348 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
|
||||
with the computed matrix product.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/zip_fragment.h"
|
||||
#include "cutlass/zip_tile_iterator.h"
|
||||
#include "cutlass/util/complex.h"
|
||||
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/split_complex_linear_scaling.h"
|
||||
#include "cutlass/util/pair.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Enables treating the accumulators selection as one object
|
||||
template <typename First_, typename Second_>
|
||||
struct ZipSelectAccumulators {
|
||||
|
||||
/// Underlying selection function
|
||||
typedef First_ First;
|
||||
typedef Second_ Second;
|
||||
|
||||
/// Accumulators
|
||||
typedef ZipFragment<
|
||||
typename First::Accumulators,
|
||||
typename Second::Accumulators> Accumulators;
|
||||
|
||||
/// Fragment
|
||||
typedef ZipFragment<
|
||||
typename First::Fragment,
|
||||
typename Second::Fragment> Fragment;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Selects the accumulators for the first part
|
||||
First first;
|
||||
|
||||
/// Selects the accumulators for the second
|
||||
Second second;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_DEVICE
|
||||
ZipSelectAccumulators() { }
|
||||
|
||||
/// Basic constructor
|
||||
CUTLASS_DEVICE
|
||||
ZipSelectAccumulators(First const &_first, Second const &_second): first(_first), second(_second) { }
|
||||
|
||||
/// Selects accumulators for a given iteration of the epilogue
|
||||
CUTLASS_DEVICE
|
||||
Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
|
||||
return make_ZipFragment(first(accum.first, idx), second(accum.second, idx));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines epilogue traits for complex-valued mma.sync GEMM
|
||||
template <
|
||||
typename GemmConfig_,
|
||||
typename EpilogueFunctor_ = SplitComplexLinearScaling<typename GemmConfig_::MultiplyAdd::ScalarC>,
|
||||
typename Index_ = int>
|
||||
struct Volta884ComplexGemmEpilogueTraits {
|
||||
|
||||
/// GEMM configuration
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
|
||||
/// Epilogue functor
|
||||
typedef EpilogueFunctor_ Functor;
|
||||
|
||||
/// Global memory mapping function
|
||||
typedef MatrixLayout::ColumnMajor GlobalDataLayout;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Long index used for offsets
|
||||
typedef long long LongIndex;
|
||||
|
||||
/// Defines epilogue traits for real-valued Volta884 GEMM epilogue
|
||||
typedef typename Volta884GemmEpilogueTraitsHelper<
|
||||
GemmConfig,
|
||||
Functor,
|
||||
typename GemmConfig::MultiplyAdd::RealMultiplyAdd,
|
||||
Index>::EpilogueTraits RealEpilogueTraits;
|
||||
|
||||
/// The output tile.
|
||||
typedef typename RealEpilogueTraits::OutputTile OutputTile;
|
||||
|
||||
/// The warp-level GEMM tile
|
||||
typedef typename RealEpilogueTraits::WarpGemmTile WarpGemmTile;
|
||||
|
||||
/// Tiling of warp accumulator elements
|
||||
typedef typename RealEpilogueTraits::WarpGemmTile WarpDelta;
|
||||
|
||||
/// Multiply-add operation
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
|
||||
/// The accumulators fragment type.
|
||||
typedef typename MultiplyAdd::Accumulators Accumulators;
|
||||
|
||||
/// Selects a subset of accumulators for a given epilogue iteration
|
||||
typedef ZipSelectAccumulators<
|
||||
typename RealEpilogueTraits::SelectAccumulators,
|
||||
typename RealEpilogueTraits::SelectAccumulators> SelectAccumulators;
|
||||
|
||||
/// The iterator to load source matrix from global memory.
|
||||
typedef cutlass::PredicatedTileLoadStream<
|
||||
ZipTileIterator<
|
||||
typename RealEpilogueTraits::GlobalLoadStreamC::Iterator,
|
||||
typename RealEpilogueTraits::GlobalLoadStreamC::Iterator
|
||||
>,
|
||||
typename RealEpilogueTraits::GlobalLoadStreamC::PredicateFunctor,
|
||||
ZipConvert<
|
||||
typename RealEpilogueTraits::GlobalLoadStreamC::Transformer,
|
||||
typename RealEpilogueTraits::GlobalLoadStreamC::Transformer
|
||||
>
|
||||
> GlobalLoadStreamC;
|
||||
|
||||
/// The iterator to store the final GEMM computation to global memory.
|
||||
typedef cutlass::PredicatedTileStoreStream<
|
||||
ZipTileIterator<
|
||||
typename RealEpilogueTraits::GlobalStoreStreamD::Iterator,
|
||||
typename RealEpilogueTraits::GlobalStoreStreamD::Iterator
|
||||
>,
|
||||
typename RealEpilogueTraits::GlobalStoreStreamD::PredicateFunctor,
|
||||
ZipConvert<
|
||||
typename RealEpilogueTraits::GlobalStoreStreamD::Transformer,
|
||||
typename RealEpilogueTraits::GlobalStoreStreamD::Transformer
|
||||
>
|
||||
> GlobalStoreStreamD;
|
||||
|
||||
/// The stream to store matrix product to shared memory
|
||||
typedef cutlass::TileStoreStream<
|
||||
ZipTileIterator<
|
||||
typename RealEpilogueTraits::SharedStoreStreamD::Iterator,
|
||||
typename RealEpilogueTraits::SharedStoreStreamD::Iterator
|
||||
>,
|
||||
ZipConvert<
|
||||
typename RealEpilogueTraits::SharedStoreStreamD::Transformer,
|
||||
typename RealEpilogueTraits::SharedStoreStreamD::Transformer
|
||||
>
|
||||
> SharedStoreStreamD;
|
||||
|
||||
/// The stream to load the matrix product from shared memory
|
||||
typedef cutlass::TileLoadStream<
|
||||
ZipTileIterator<
|
||||
typename RealEpilogueTraits::SharedLoadStreamD::Iterator,
|
||||
typename RealEpilogueTraits::SharedLoadStreamD::Iterator
|
||||
>,
|
||||
ZipConvert<
|
||||
typename RealEpilogueTraits::SharedLoadStreamD::Transformer,
|
||||
typename RealEpilogueTraits::SharedLoadStreamD::Transformer
|
||||
>
|
||||
> SharedLoadStreamD;
|
||||
|
||||
/// The scalar type of the source accumulator matrix.
|
||||
typedef typename RealEpilogueTraits::ScalarC ScalarC;
|
||||
|
||||
/// The scalar type of the destination accumulator matrix.
|
||||
typedef typename RealEpilogueTraits::ScalarD ScalarD;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Cover an entire warp-level tile
|
||||
typedef typename RealEpilogueTraits::Iterations Iterations;
|
||||
|
||||
/// Parameters structure initialized on the host
|
||||
struct Params {
|
||||
/// The params for the C iterator.
|
||||
typename GlobalLoadStreamC::Params load_stream_c;
|
||||
|
||||
/// The params for the D global iterator.
|
||||
typename GlobalStoreStreamD::Params store_stream_d;
|
||||
|
||||
/// Epilogue functor params
|
||||
typename Functor::Params functor;
|
||||
|
||||
/// The params for the D shared store iterator.
|
||||
typename SharedStoreStreamD::Params shared_store_stream_d;
|
||||
|
||||
/// The params for the D shared load stream.
|
||||
typename SharedLoadStreamD::Params shared_load_stream_d;
|
||||
|
||||
/// Stride for C
|
||||
platform::Pair<LongIndex, LongIndex> batch_stride_C;
|
||||
|
||||
/// Stride for D
|
||||
platform::Pair<LongIndex, LongIndex> batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {
|
||||
batch_stride_C.first = 0;
|
||||
batch_stride_C.second = 0;
|
||||
|
||||
batch_stride_D.first = 0;
|
||||
batch_stride_D.second = 0;
|
||||
}
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
platform::complex<typename Functor::Scalar> alpha,
|
||||
platform::complex<typename Functor::Scalar> beta,
|
||||
ScalarC const* real_C,
|
||||
Index real_ldc,
|
||||
ScalarC const* imag_C,
|
||||
Index imag_ldc,
|
||||
ScalarD* real_D,
|
||||
Index real_ldd,
|
||||
ScalarD* imag_D,
|
||||
Index imag_ldd) {
|
||||
|
||||
int result = functor.initialize(alpha, beta);
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for C.
|
||||
result = load_stream_c.iterator.first.initialize(
|
||||
real_C, real_ldc, real_ldc, 1);
|
||||
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = load_stream_c.iterator.second.initialize(
|
||||
imag_C, imag_ldc, imag_ldc, 1);
|
||||
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for D.
|
||||
result = store_stream_d.iterator.first.initialize(
|
||||
real_D, real_ldd, real_ldd, 1);
|
||||
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = store_stream_d.iterator.second.initialize(
|
||||
imag_D, imag_ldd, imag_ldd, 1);
|
||||
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
platform::complex<typename Functor::Scalar> alpha,
|
||||
platform::complex<typename Functor::Scalar> beta,
|
||||
ScalarC const* real_C,
|
||||
Index real_ldc,
|
||||
LongIndex stride_C_real,
|
||||
ScalarC const* imag_C,
|
||||
Index imag_ldc,
|
||||
LongIndex stride_C_imag,
|
||||
ScalarD* real_D,
|
||||
Index real_ldd,
|
||||
LongIndex stride_D_real,
|
||||
ScalarD* imag_D,
|
||||
Index imag_ldd,
|
||||
LongIndex stride_D_imag) {
|
||||
|
||||
batch_stride_C.first = stride_C_real;
|
||||
batch_stride_C.second = stride_C_imag;
|
||||
|
||||
batch_stride_D.first = stride_D_real;
|
||||
batch_stride_D.second = stride_D_imag;
|
||||
|
||||
return initialize(alpha, beta, real_C, real_ldc, imag_C, imag_ldc, real_D, real_ldd, imag_D, imag_ldd);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory buffer used by epilogue
|
||||
typedef ZipTileAllocation<
|
||||
typename RealEpilogueTraits::SharedStorage,
|
||||
typename RealEpilogueTraits::SharedStorage> SharedStorage;
|
||||
|
||||
/// Functor computing the offset from the threadblock origin per iteration of
|
||||
/// the epilogue.
|
||||
typedef typename RealEpilogueTraits::GlobalOffset GlobalOffset;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
|
||||
namespace platform {
|
||||
|
||||
/// Scales both elements of a Pair by an integer factor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pair<long long, long long> operator*(int s, Pair<long long, long long> _pair) {
|
||||
return Pair<long long, long long>(s * _pair.first, s * _pair.second);
|
||||
}
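// Editor's usage sketch for the operator above (identifiers hypothetical): it lets a batch index
// scale the real- and imaginary-plane strides in a single expression, e.g.
//
//   platform::Pair<long long, long long> batch_stride(stride_real, stride_imag);
//   platform::Pair<long long, long long> offset = batch_id * batch_stride;
//
// after which offset.first and offset.second are the element offsets of this batch's real and
// imaginary planes.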
|
||||
|
||||
}
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
// clang-format on
|
||||
@@ -1,558 +0,0 @@
/*! \file
|
||||
\brief Defines structural properties for complex-valued GEMM targeting Volta's mma.sync
|
||||
instruction.
|
||||
|
||||
    At present, it expects a split complex representation in global memory in which the real and
    imaginary parts of a complex-valued matrix are stored disjointly (a structure of arrays). This is
    in contrast with an interleaved complex representation, which is an array of structures.
|
||||
*/
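// Editor's illustration of the two layouts described above (stand-in types, not part of the
// original header). These traits expect the split form; the interleaved form is shown only for
// contrast.
struct SplitComplexMatrixSketch {
  float const *real;   // structure of arrays: the real plane has its own pointer and leading dimension
  float const *imag;   // the imaginary plane is stored disjointly, possibly with a different stride
};

struct InterleavedComplexElementSketch {
  float real;
  float imag;          // array of structures: (real, imag) pairs stored contiguously
};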
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/gemm/clear_accumulators.h"
|
||||
#include "cutlass/gemm/gemm_config.h"
|
||||
#include "cutlass/gemm/gemm_stream_pair.h"
|
||||
#include "cutlass/gemm/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
#include "cutlass/kernel_launch.h"
|
||||
#include "cutlass/tensor_ref_collection.h"
|
||||
|
||||
#include "cutlass/gemm/gemm_desc.h"
|
||||
|
||||
#include "cutlass/gemm/volta884_multiplicand.h"
|
||||
#include "cutlass/gemm/mma_shared_stream.h"
|
||||
#include "cutlass/gemm/volta884_gemm_traits.h"
|
||||
|
||||
#include "cutlass/gemm/volta884_complex_multiply_add.h"
|
||||
#include "cutlass/gemm/volta884_complex_global_stream.h"
|
||||
#include "cutlass/gemm/volta884_complex_shared_stream.h"
|
||||
#include "cutlass/gemm/volta884_complex_gemm_epilogue_traits.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines configuration for Volta884 GEMM
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// Indicates matrix transform on multiplicand A
|
||||
MatrixTransform::Kind TransformA,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// Indicates matrix transform on multiplicand B
|
||||
MatrixTransform::Kind TransformB,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The source matrix type.
|
||||
typename ScalarC_,
|
||||
/// The destination matrix type
|
||||
typename ScalarD_,
|
||||
/// Number of stages in shared memory
|
||||
int StageCount,
|
||||
/// Enables or disables launch bounds
|
||||
bool LaunchBounds>
|
||||
struct Volta884ComplexGemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The threadblock tile size
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
Volta884ComplexMultiplyAdd<WarpGemmShape_,
|
||||
LayoutA,
|
||||
TransformA,
|
||||
half,
|
||||
LayoutB,
|
||||
TransformB,
|
||||
half,
|
||||
Accumulator_>,
|
||||
/// The number of scalars per LDG for A.
|
||||
8,
|
||||
/// The number of scalars per STS for A.
|
||||
8,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
8,
|
||||
/// The number of scalars per STS for B.
|
||||
8,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
16 / int(sizeof(ScalarD_)),
|
||||
/// The number of scalars per STS for D.
|
||||
16 / int(sizeof(ScalarD_)),
|
||||
/// The number of scalars per LDS for D.
|
||||
16 / int(sizeof(ScalarD_)),
|
||||
/// The number of stages in shared memory.
|
||||
StageCount,
|
||||
/// If true, separate mainloop is instantiated
|
||||
true,
|
||||
/// If true, compute residue in prolog
|
||||
false,
|
||||
/// Enables or disables launch bounds
|
||||
LaunchBounds> {};
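// Editor's worked example of the access widths chosen above: C and D are moved with 16-byte
// memory operations, so the number of scalars per access is 16 / sizeof(ScalarD_). For float
// accumulators that is 4 scalars per LDG/STG; for a 16-bit output type it would be 8.
static_assert(16 / sizeof(float) == 4, "a 16-byte access holds four 32-bit elements");
static_assert(16 / sizeof(unsigned short) == 8, "a 16-byte access holds eight 16-bit elements");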
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines components of Volta884 GEMM
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// Indicates matrix transform on multiplicand A
|
||||
MatrixTransform::Kind TransformA,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// Indicates matrix transform on multiplicand B
|
||||
MatrixTransform::Kind TransformB,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The input matrix type.
|
||||
typename ScalarC_,
|
||||
/// The output matrix type.
|
||||
typename ScalarD_,
|
||||
/// Number of buffers in shared memory to use
|
||||
int StageCount,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = SplitComplexLinearScaling<Accumulator_>,
|
||||
/// Enables or disables launch bounds
|
||||
bool LaunchBounds = false
|
||||
>
|
||||
struct Volta884ComplexGemmTraits {
|
||||
|
||||
/// Self-referential typedef naming this traits specialization (consumed by the kernel)
|
||||
typedef Volta884ComplexGemmTraits<
|
||||
LayoutA,
|
||||
TransformA,
|
||||
LayoutB,
|
||||
TransformB,
|
||||
OutputTile_,
|
||||
WarpGemmShape_,
|
||||
Accumulator_,
|
||||
ScalarC_,
|
||||
ScalarD_,
|
||||
StageCount,
|
||||
EpilogueFunctor_,
|
||||
LaunchBounds> This;
|
||||
|
||||
/// The actual device-side GEMM
|
||||
typedef GemmMainloop<This> KernelClass;
|
||||
|
||||
/// Layout of multiplicand A matrix
|
||||
static MatrixLayout::Kind const kLayoutA = LayoutA;
|
||||
|
||||
/// If true, A operand is conjugated
|
||||
static MatrixTransform::Kind const kTransformA = TransformA;
|
||||
|
||||
/// Layout of multiplicand B matrix
|
||||
static MatrixLayout::Kind const kLayoutB = LayoutB;
|
||||
|
||||
/// If true, B operand is conjugated
|
||||
static MatrixTransform::Kind const kTransformB = TransformB;
|
||||
|
||||
/// Dimensions of threadblock tile (concept Shape)
|
||||
typedef OutputTile_ OutputTile;
|
||||
|
||||
/// Shape of warp-level accumulators
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
|
||||
/// Multiplicand A scalar type
|
||||
typedef half ScalarA;
|
||||
|
||||
/// Multiplicand B scalar type
|
||||
typedef half ScalarB;
|
||||
|
||||
/// Data type of internal accumulator
|
||||
typedef Accumulator_ Accumulator;
|
||||
|
||||
/// Data type of input accumulator matrix operand
|
||||
typedef ScalarC_ ScalarC;
|
||||
|
||||
/// Data type of output accumulator matrix operand
|
||||
typedef ScalarD_ ScalarD;
|
||||
|
||||
/// Shape of individual mma.sync instruction
|
||||
typedef Shape<4, 16, 16> InstructionShape;
|
||||
|
||||
/// Tile size for an individual warp-level multiply-add
|
||||
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
|
||||
|
||||
/// Defines properties about GEMM needed by host code
|
||||
typedef Volta884ComplexGemmConfig<
|
||||
kLayoutA,
|
||||
kTransformA,
|
||||
kLayoutB,
|
||||
kTransformB,
|
||||
OutputTile,
|
||||
WarpGemmShape,
|
||||
Accumulator,
|
||||
ScalarC,
|
||||
ScalarD,
|
||||
StageCount,
|
||||
LaunchBounds>
|
||||
GemmConfig;
|
||||
|
||||
//
|
||||
// Derived types
|
||||
//
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Long index type
|
||||
typedef long long LongIndex;
|
||||
|
||||
/// Partitioning of threadblock into warps
|
||||
typedef typename ShapeDiv<OutputTile, WarpGemmShape>::Shape WarpDelta;
|
||||
|
||||
/// Number of warps per threadblock
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
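// Editor's worked example (hypothetical shapes): WarpDelta is the component-wise ratio
// OutputTile / WarpGemmShape, and kWarpCount is the product of its components. For instance, a
// 32x128x128 (KxNxM) threadblock tile with 32x64x64 warp tiles gives a 1x2x2 arrangement, i.e.
// four warps per threadblock.
static_assert((32 / 32) * (128 / 64) * (128 / 64) == 4, "example: four warps per threadblock");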
|
||||
|
||||
/// Defines iterators for A matrix
|
||||
typedef Volta884Multiplicand<GemmOperand::kA, kLayoutA, OutputTile, WarpTile, kWarpCount, WarpDelta>
|
||||
MultiplicandA;
|
||||
|
||||
/// Defines iterators for B matrix
|
||||
typedef Volta884Multiplicand<GemmOperand::kB, kLayoutB, OutputTile, WarpTile, kWarpCount, WarpDelta>
|
||||
MultiplicandB;
|
||||
|
||||
//
|
||||
// GemmTraits mandatory type definitions
|
||||
//
|
||||
|
||||
/// Maps hardware threadblocks to logical partitions of the GEMM
|
||||
typedef IdentityBlockSwizzle BlockSwizzle;
|
||||
|
||||
/// Clears accumulators
|
||||
typedef ClearAccumulators<ScalarC> ClearAccumulators;
|
||||
|
||||
/// Loads multiplicands from global memory
|
||||
typedef GlobalLoadStreamPair<
|
||||
Volta884ComplexGlobalLoadStream<GemmOperand::kA,
|
||||
kLayoutA,
|
||||
typename MultiplicandA::LoadIterator,
|
||||
Copy<typename MultiplicandA::LoadIterator::Fragment>,
|
||||
typename MultiplicandA::StoreIterator,
|
||||
StageCount>,
|
||||
Volta884ComplexGlobalLoadStream<GemmOperand::kB,
|
||||
kLayoutB,
|
||||
typename MultiplicandB::LoadIterator,
|
||||
Copy<typename MultiplicandB::LoadIterator::Fragment>,
|
||||
typename MultiplicandB::StoreIterator,
|
||||
StageCount>,
|
||||
GemmConfig::kResidueInProlog >
|
||||
GlobalLoadStream;
|
||||
|
||||
/// Memory needed to store the threadblock-scoped GEMM tile
|
||||
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
|
||||
|
||||
/// Shared memory storage for mainloop phase
|
||||
union MainLoopStorage {
|
||||
|
||||
/// Stores the threadblock tile
|
||||
ThreadblockTileStorage threadblock_tile;
|
||||
|
||||
/// Storage for GEMM global stream
|
||||
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
|
||||
};
|
||||
|
||||
/// Loads multiplicands from shared memory
|
||||
typedef SharedStreamPair<
|
||||
Volta884ComplexSharedLoadStream<typename MultiplicandA::WarpLoadIterator,
|
||||
Copy<typename MultiplicandA::WarpLoadIterator::Fragment>,
|
||||
StageCount>,
|
||||
Volta884ComplexSharedLoadStream<typename MultiplicandB::WarpLoadIterator,
|
||||
Copy<typename MultiplicandB::WarpLoadIterator::Fragment>,
|
||||
StageCount> >
|
||||
SharedStream;
|
||||
|
||||
// Multiply-add object specialized for Volta mma.sync
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
|
||||
#if 0
|
||||
/// Naive epilogue for updating the output matrix
|
||||
typedef Volta884ComplexNaiveEpilogue<ScalarC,
|
||||
typename MultiplicandA::WarpDelta,
|
||||
typename MultiplyAdd::Iterations>
|
||||
Epilogue;
|
||||
|
||||
#else
|
||||
|
||||
/// Efficient epilogue
|
||||
typedef MMAEpilogue<
|
||||
Volta884ComplexGemmEpilogueTraits<GemmConfig, EpilogueFunctor_>
|
||||
> Epilogue;
|
||||
|
||||
#endif
|
||||
|
||||
/// Tensor reference to A multiplicand
|
||||
typedef ZipTensorRef<
|
||||
TensorRef<ScalarA, 2>,
|
||||
TensorRef<ScalarA, 2>
|
||||
> TensorRefA;
|
||||
|
||||
/// Tensor reference to B multiplicand
|
||||
typedef ZipTensorRef<
|
||||
TensorRef<ScalarB, 2>,
|
||||
TensorRef<ScalarB, 2>
|
||||
> TensorRefB;
|
||||
|
||||
/// Tensor reference to C multiplicand
|
||||
typedef ZipTensorRef<
|
||||
TensorRef<ScalarC, 2>,
|
||||
TensorRef<ScalarC, 2>
|
||||
> TensorRefC;
|
||||
|
||||
/// Tensor reference to D multiplicand
|
||||
typedef ZipTensorRef<
|
||||
TensorRef<ScalarD, 2>,
|
||||
TensorRef<ScalarD, 2>
|
||||
> TensorRefD;
|
||||
|
||||
/// gemm::ProblemDesc<>
|
||||
typedef GemmDesc<
|
||||
TensorRefA,
|
||||
TensorRefB,
|
||||
TensorRefC,
|
||||
TensorRefD,
|
||||
float
|
||||
> GemmDesc;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params : public KernelLaunchConfiguration {
|
||||
/// The dimensions of the GEMM.
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// PartitionK_range
|
||||
int partitionK_range;
|
||||
|
||||
/// The params for the global load stream
|
||||
typename GlobalLoadStream::Params global_to_shared_stream;
|
||||
|
||||
/// The params for the shared load stream
|
||||
typename SharedStream::Params shared_stream;
|
||||
|
||||
/// The params for the epilogue.
|
||||
typename Epilogue::Params epilogue;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Initialize the Params struct
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
platform::complex<typename Epilogue::Scalar> alpha,
|
||||
ScalarA const* real_A,
|
||||
Index real_lda,
|
||||
ScalarA const* imag_A,
|
||||
Index imag_lda,
|
||||
ScalarB const* real_B,
|
||||
Index real_ldb,
|
||||
ScalarB const* imag_B,
|
||||
Index imag_ldb,
|
||||
platform::complex<typename Epilogue::Scalar> beta,
|
||||
ScalarC const* real_C,
|
||||
Index real_ldc,
|
||||
ScalarC const* imag_C,
|
||||
Index imag_ldc,
|
||||
ScalarD* real_D,
|
||||
Index real_ldd,
|
||||
ScalarD* imag_D,
|
||||
Index imag_ldd) {
|
||||
|
||||
problem_size = make_Coord(k, n, m, 1);
|
||||
|
||||
partitionK_range = problem_size.k();
|
||||
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Initialize global load streams
|
||||
global_to_shared_stream.stream_a.initialize(
|
||||
make_ZipTensorRef(
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_A, real_lda), 0),
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_A, imag_lda), 0)
|
||||
),
|
||||
0
|
||||
);
|
||||
|
||||
global_to_shared_stream.stream_b.initialize(
|
||||
make_ZipTensorRef(
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_B, real_ldb), 0),
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_B, imag_ldb), 0)
|
||||
),
|
||||
0
|
||||
);
|
||||
|
||||
return epilogue.initialize(
|
||||
alpha,
|
||||
beta,
|
||||
real_C,
|
||||
real_ldc,
|
||||
imag_C,
|
||||
imag_ldc,
|
||||
real_D,
|
||||
real_ldd,
|
||||
imag_D,
|
||||
imag_ldd
|
||||
);
|
||||
}
|
||||
|
||||
/// Initialize the Params struct
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
platform::complex<typename Epilogue::Scalar> alpha,
|
||||
ScalarA const* real_A,
|
||||
Index real_lda,
|
||||
LongIndex batch_stride_A_real,
|
||||
ScalarA const* imag_A,
|
||||
Index imag_lda,
|
||||
LongIndex batch_stride_A_imag,
|
||||
ScalarB const* real_B,
|
||||
Index real_ldb,
|
||||
LongIndex batch_stride_B_real,
|
||||
ScalarB const* imag_B,
|
||||
Index imag_ldb,
|
||||
LongIndex batch_stride_B_imag,
|
||||
platform::complex<typename Epilogue::Scalar> beta,
|
||||
ScalarC const* real_C,
|
||||
Index real_ldc,
|
||||
LongIndex batch_stride_C_real,
|
||||
ScalarC const* imag_C,
|
||||
Index imag_ldc,
|
||||
LongIndex batch_stride_C_imag,
|
||||
ScalarD* real_D,
|
||||
Index real_ldd,
|
||||
LongIndex batch_stride_D_real,
|
||||
ScalarD* imag_D,
|
||||
Index imag_ldd,
|
||||
LongIndex batch_stride_D_imag,
|
||||
int batch_count) {
|
||||
|
||||
problem_size = make_Coord(k, n, m, batch_count);
|
||||
partitionK_range = problem_size.k();
|
||||
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Initialize global load streams
|
||||
global_to_shared_stream.stream_a.initialize(
|
||||
make_ZipTensorRef(
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_A, real_lda), batch_stride_A_real),
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_A, imag_lda), batch_stride_A_imag)
|
||||
),
|
||||
0
|
||||
);
|
||||
|
||||
global_to_shared_stream.stream_b.initialize(
|
||||
make_ZipTensorRef(
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_B, real_ldb), batch_stride_B_real),
|
||||
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_B, imag_ldb), batch_stride_B_imag)
|
||||
),
|
||||
0
|
||||
);
|
||||
|
||||
return epilogue.initialize(
|
||||
alpha,
|
||||
beta,
|
||||
real_C,
|
||||
real_ldc,
|
||||
batch_stride_C_real,
|
||||
imag_C,
|
||||
imag_ldc,
|
||||
batch_stride_C_imag,
|
||||
real_D,
|
||||
real_ldd,
|
||||
batch_stride_D_real,
|
||||
imag_D,
|
||||
imag_ldd,
|
||||
batch_stride_D_imag
|
||||
);
|
||||
}
|
||||
};
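// Editor's sketch of host-side setup for the batched overload above (the traits instantiation,
// pointers, leading dimensions, and strides are all hypothetical). Each operand is supplied as
// two planes: a (pointer, leading dimension, batch stride) triple for the real part and another
// for the imaginary part.
//
//   typedef cutlass::gemm::Volta884ComplexGemmTraits< /* ...some valid instantiation... */ > Traits;
//   typename Traits::Params params;
//
//   int result = params.initialize(
//       m, n, k, alpha,
//       real_A, lda_real, batch_stride_A_real, imag_A, lda_imag, batch_stride_A_imag,
//       real_B, ldb_real, batch_stride_B_real, imag_B, ldb_imag, batch_stride_B_imag,
//       beta,
//       real_C, ldc_real, batch_stride_C_real, imag_C, ldc_imag, batch_stride_C_imag,
//       real_D, ldd_real, batch_stride_D_real, imag_D, ldd_imag, batch_stride_D_imag,
//       batch_count);
//   // A non-zero result indicates that initialization failed, mirroring the convention used by
//   // the epilogue's initialize() above.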
|
||||
|
||||
/// Shared memory storage
|
||||
union SharedStorage {
|
||||
/// Storage required during mainloop phase
|
||||
MainLoopStorage main_loop;
|
||||
|
||||
/// Shared storage needed for epilogue
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
|
||||
if (StageCount < 2) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
|
||||
};
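// Editor's note on the fences above (inferred from the StageCount check, not from the mainloop
// itself): with two or more shared-memory stages the streams double-buffer, so the stage being
// read is never the stage being written and the pre-load __syncthreads() can be skipped; with a
// single stage both fences must synchronize the threadblock.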
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
// clang-format on
|
||||
@@ -1,315 +0,0 @@
/*! \file
|
||||
    \brief Implements efficient loading of the thread block-level tile from global memory and
           storing to shared memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/zip_tile_iterator.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/util/pair.h"
|
||||
|
||||
#include "cutlass/gemm/mma_global_stream.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
|
||||
template <
|
||||
/// Identifies multiplicand
|
||||
GemmOperand::Kind Operand,
|
||||
/// Layout of source matrix in global memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Iterator for loading threadblock-scoped tiles
|
||||
typename LoadIterator_,
|
||||
/// Transformation functor for transforming fragments
|
||||
typename Transformer_,
|
||||
/// Iterator for storing threadblock-scoped tiles to shared memory
|
||||
typename StoreIterator_,
|
||||
/// Number of stores before iterator wraps - zero indicates no wrapping
|
||||
int StageCount>
|
||||
struct Volta884ComplexGlobalLoadStream {
|
||||
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Identifies the operand
|
||||
static GemmOperand::Kind const kOperand = Operand;
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = Layout;
|
||||
|
||||
/// Load-store stream for real-valued matrices
|
||||
typedef MMAGlobalLoadStream<Operand, Layout, LoadIterator_, Transformer_, StoreIterator_, StageCount> RealLoadStoreStream;
|
||||
|
||||
/// Loads a pair of real-valued fragments
|
||||
typedef ZipTileIterator<LoadIterator_, LoadIterator_> LoadIterator;
|
||||
|
||||
/// Zips a pair of transformers
|
||||
typedef ZipConvert<Transformer_, Transformer_> Transformer;
|
||||
|
||||
/// Stores a pair of real-valued fragments
|
||||
typedef ZipTileIterator<StoreIterator_, StoreIterator_> StoreIterator;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStageCount = StageCount;
|
||||
|
||||
/// Predicate vector
|
||||
typedef typename RealLoadStoreStream::PredicateVector PredicateVector;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename LoadIterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
/// Make sure the transformed fragment is the same as the store fragment.
|
||||
static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
|
||||
"");
|
||||
|
||||
/// Index type
|
||||
typedef typename RealLoadStoreStream::Index Index;
|
||||
|
||||
/// Long index type
|
||||
typedef typename RealLoadStoreStream::LongIndex LongIndex;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Matrix reference
|
||||
typedef ZipTensorRef<
|
||||
TensorRefBatchStrided<half const, 2>,
|
||||
TensorRefBatchStrided<half const, 2> > SourceTensorRef;
|
||||
|
||||
/// Helper
|
||||
static int const kElementsPerLdg = LoadIterator::First::Tile::kC;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Source tensor reference
|
||||
platform::Pair<LongIndex, LongIndex> batch_stride;
|
||||
|
||||
// The load iterator.
|
||||
typename LoadIterator::Params load_iterator;
|
||||
|
||||
// Offset to residue.
|
||||
Index offset_to_residue;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(SourceTensorRef const &ref, Index _offset_to_residue) {
|
||||
initialize(ref, _offset_to_residue);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SourceTensorRef const &ref, Index _offset_to_residue) {
|
||||
|
||||
batch_stride.first = ref.first.tensor_stride;
|
||||
batch_stride.second = ref.second.tensor_stride;
|
||||
|
||||
offset_to_residue = _offset_to_residue;
|
||||
load_iterator.first.initialize(
|
||||
TensorRef<half const, 4>(
|
||||
ref.first.at().data(),
|
||||
make_Coord(ref.first.at().stride(0) * kElementsPerLdg, ref.first.at().stride(0), kElementsPerLdg)
|
||||
)
|
||||
);
|
||||
load_iterator.second.initialize(
|
||||
TensorRef<half const, 4>(
|
||||
ref.second.at().data(),
|
||||
make_Coord(ref.second.at().stride(0) * kElementsPerLdg, ref.second.at().stride(0), kElementsPerLdg)
|
||||
)
|
||||
);
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Empty shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Shared memory allocation for the tile
|
||||
typedef TileAllocation<
|
||||
typename RealLoadStoreStream::StoreIterator::Scalar,
|
||||
typename ShapeMul<
|
||||
typename RealLoadStoreStream::StoreIterator::OperandShape,
|
||||
Shape<kStageCount, 1, 1, 1>
|
||||
>::Shape
|
||||
> RealThreadblockTileStorage;
|
||||
|
||||
/// Threadblock tile allocation
|
||||
typedef ZipTileAllocation<
|
||||
RealThreadblockTileStorage,
|
||||
RealThreadblockTileStorage
|
||||
> ThreadblockTileStorage;
|
||||
|
||||
/// Reference to ThreadblockTileStorage
|
||||
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
///! The parameters
|
||||
Params params;
|
||||
|
||||
///! Dimensions of global memory tile
|
||||
Coord<3> threadblock_offset;
|
||||
|
||||
///! Multiplicand bounds
|
||||
Coord<3> multiplicand_bounds;
|
||||
|
||||
///! Iterator to load threadblock tiles from global memory
|
||||
LoadIterator load_iterator;
|
||||
|
||||
///! Predicate vector
|
||||
PredicateVector predicates;
|
||||
|
||||
///! The fragment to fetch from shared memory.
|
||||
FetchedFragment fetched_fragment;
|
||||
|
||||
///! Functor to transform fragments after they have been loaded
|
||||
Transformer transformer;
|
||||
|
||||
///! The fragment to convert the data after it has been fetched from shared memory.
|
||||
TransformedFragment transformed_fragment;
|
||||
|
||||
///! Iterator to store threadblock tiles to shared memory
|
||||
StoreIterator store_iterator;
|
||||
|
||||
///! Counter
|
||||
int stage_index;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE Volta884ComplexGlobalLoadStream(Params const &_params,
|
||||
SharedStorage &shared_storage,
|
||||
ThreadblockTileRef const &threadblock_tile_ref,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const &block)
|
||||
: params(_params),
|
||||
threadblock_offset(RealLoadStoreStream::project_coordinate(block)),
|
||||
multiplicand_bounds(RealLoadStoreStream::project_coordinate(bounds, 1)),
|
||||
load_iterator(params.load_iterator, threadblock_offset),
|
||||
transformer(),
|
||||
store_iterator(threadblock_tile_ref),
|
||||
stage_index(0) {
|
||||
|
||||
// initialize predicates used to guard loads
|
||||
load_iterator.initialize_predicates(
|
||||
predicates.begin(), multiplicand_bounds, threadblock_offset);
|
||||
}
|
||||
|
||||
/// Loads the data from global memory
|
||||
CUTLASS_DEVICE void copy() {
|
||||
load_iterator.load_post_increment(fetched_fragment, predicates.begin());
|
||||
}
|
||||
|
||||
/// Transform and commit the data to shared memory
|
||||
CUTLASS_DEVICE void commit() {
|
||||
|
||||
transformer.transform(fetched_fragment, transformed_fragment);
|
||||
store_iterator.store_post_increment(transformed_fragment);
|
||||
|
||||
++stage_index;
|
||||
if (kStageCount && stage_index == kStageCount) {
|
||||
store_iterator -= kStageCount;
|
||||
stage_index = 0;
|
||||
}
|
||||
}
|
||||
|
||||
  /// Computes a predicate mask for loads during the final threadblock tile load iteration
  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
    // Compute the origin of the residue tile
    Coord<3> _block_offset = threadblock_offset;
    if ((kOperand == GemmOperand::kA) ^ (kLayout == MatrixLayout::kRowMajor)) {
      // K-strided
      _block_offset =
          make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
    } else {
      // K-contiguous
      _block_offset = make_Coord(threadblock_offset[0],
                                 threadblock_offset[1],
                                 multiplicand_bounds[2] - k / LoadIterator::First::Tile::kC);
    }

    load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
    fetched_fragment.clear();
  }
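// Editor's worked example (hypothetical numbers) of the residue offset above: for a K-strided
// operand with multiplicand_bounds = (0, 4096, 128) and k = 24 residue elements, the predicate
// origin becomes (0, 4096 - 24, threadblock_offset[2]) = (0, 4072, ...), so only the last 24
// rows along K remain enabled. For a K-contiguous operand the shift is applied to the third
// coordinate instead, scaled down by the vector width LoadIterator::First::Tile::kC.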
|
||||
|
||||
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {}
|
||||
|
||||
CUTLASS_DEVICE void rollback() {}
|
||||
|
||||
/// Adds a Coord<3> to the underlying global load iterator
|
||||
CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &operator+=(Coord<3> const &offset) {
|
||||
load_iterator += offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Adds an offset based on batch stride
|
||||
CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &add_batch_offset(int batch_id) {
|
||||
load_iterator.first.add_pointer_offset(params.batch_stride.first * batch_id);
|
||||
load_iterator.second.add_pointer_offset(params.batch_stride.second * batch_id);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
// clang-format on
|
||||
@@ -1,319 +0,0 @@
/*! \file
|
||||
\brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
|
||||
for complex-valued data types.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/util/complex.h"
|
||||
#include "cutlass/zip_fragment.h"
|
||||
#include "cutlass/gemm/volta884_multiply_add.h"
|
||||
#include "cutlass/zip_fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Shape of a warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// Layout of multiplicand A
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// Indicates matrix transform on multiplicand A
|
||||
MatrixTransform::Kind TransformA,
|
||||
/// Data type of multiplicand A
|
||||
typename ScalarA_,
|
||||
/// Layout of multiplicand B
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// Indicates matrix transform on multiplicand B
|
||||
MatrixTransform::Kind TransformB,
|
||||
/// Data type of multiplicand B
|
||||
typename ScalarB_,
|
||||
/// Data type of accumulators
|
||||
typename ScalarC_,
|
||||
/// If true, A operand is conjugated
|
||||
bool ConjugateA = false,
|
||||
/// If true, B operand is conjugated
|
||||
bool ConjugateB = false,
|
||||
/// If true, infinite results are saturated to +-MAX_FLOAT
|
||||
bool SatFinite = false>
|
||||
struct Volta884ComplexMultiplyAdd {
|
||||
//
|
||||
// Constant and type definitions
|
||||
//
|
||||
|
||||
/// Shape of a warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
|
||||
/// Shape of a warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
|
||||
/// Most of the Volta884 code assumes interleaved 32x32 tiles
|
||||
typedef Shape<4, 32, 32> InterleavedTileShape;
|
||||
|
||||
/// Shape of an individual warp-wide mma.sync instruction
|
||||
typedef Shape<4, 16, 16> InstructionShape;
|
||||
|
||||
/// Shape of a warp-level matrix multiply operation
|
||||
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
|
||||
|
||||
/// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
|
||||
static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
|
||||
!(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
|
||||
"WarpTile must be a multiple of InterleavedTileShape.");
|
||||
|
||||
/// Layout of A multiplicand
|
||||
static MatrixLayout::Kind const kLayoutA = LayoutA;
|
||||
|
||||
/// Indicates matrix transform on multiplicand B
|
||||
static MatrixTransform::Kind const kTransformA = TransformA;
|
||||
|
||||
/// Layout of B multiplicand
|
||||
static MatrixLayout::Kind const kLayoutB = LayoutB;
|
||||
|
||||
/// Indicates matrix transform on multiplicand B
|
||||
static MatrixTransform::Kind const kTransformB = TransformB;
|
||||
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
|
||||
/// If true, infinite results are saturated to +-MAX_FLOAT
|
||||
static bool const kSatFinite = SatFinite;
|
||||
|
||||
/// Hard-coded compute type supported on Volta
|
||||
static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;
|
||||
|
||||
/// Underlying matrix multiply-add operator
|
||||
typedef Volta884MultiplyAdd<WarpGemmShape,
|
||||
kLayoutA,
|
||||
ScalarA,
|
||||
kLayoutB,
|
||||
ScalarB,
|
||||
ScalarC>
|
||||
RealMultiplyAdd;
|
||||
|
||||
/// Fragment definition for A multiplicand
|
||||
typedef ZipFragment<typename RealMultiplyAdd::FragmentA, typename RealMultiplyAdd::FragmentA>
|
||||
FragmentA;
|
||||
|
||||
/// Fragment definition for B multiplicand
|
||||
typedef ZipFragment<typename RealMultiplyAdd::FragmentB, typename RealMultiplyAdd::FragmentB>
|
||||
FragmentB;
|
||||
|
||||
/// Fragment definition for accumulators
|
||||
typedef ZipFragment<typename RealMultiplyAdd::Accumulators,
|
||||
typename RealMultiplyAdd::Accumulators>
|
||||
Accumulators;
|
||||
|
||||
/// Number of mma.sync operations performed. See Volta884MultiplyAdd::Iterations for details.
|
||||
typedef typename RealMultiplyAdd::Iterations Iterations;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Volta884ComplexMultiplyAdd() {}
|
||||
|
||||
  /// Complex-valued multiply-accumulate: D = A * B + C
  CUTLASS_DEVICE void multiply_add(FragmentA const& A,
                                   FragmentB const& B,
                                   Accumulators const& C,
                                   Accumulators& D) {
    RealMultiplyAdd op;

    // Complex multiplication expands into four real-valued multiply-adds; conjugating an operand
    // negates the contribution of its imaginary part.
    op.multiply_add(A.first, B.first, C.first, D.first);
    op.multiply_add(A.first, B.second, C.second, D.second, kTransformB == MatrixTransform::kConjugate);
    op.multiply_add(A.second, B.first, C.second, D.second, kTransformA == MatrixTransform::kConjugate);
    op.multiply_add(A.second, B.second, C.first, D.first,
        !((kTransformA == MatrixTransform::kConjugate) ^ (kTransformB == MatrixTransform::kConjugate)));
  }
};
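// Editor's scalar sketch (not part of the original header) of the four real multiply-adds issued
// above for D = A * B + C with A = ar + i*ai and B = br + i*bi. Conjugating an operand flips the
// sign of its imaginary part, which is why the final product is negated unless exactly one
// operand is conjugated.
inline void complex_mac_sketch(float ar, float ai, float br, float bi,
                               float &d_real, float &d_imag,
                               bool conj_a = false, bool conj_b = false) {
  if (conj_a) ai = -ai;
  if (conj_b) bi = -bi;
  d_real += ar * br;  // product 1: real x real accumulates into the real part
  d_imag += ar * bi;  // product 2: real x imag accumulates into the imaginary part
  d_imag += ai * br;  // product 3: imag x real accumulates into the imaginary part
  d_real -= ai * bi;  // product 4: imag x imag accumulates negatively into the real part
}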
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Complex-valued epilogue
|
||||
template <typename Accumulator, typename WarpDelta, typename Iterations>
|
||||
struct Volta884ComplexNaiveEpilogue {
|
||||
/// Accumulator data type
|
||||
typedef Accumulator ScalarC;
|
||||
|
||||
/// Output accumulator type
|
||||
typedef Accumulator ScalarD;
|
||||
|
||||
/// BLAS Scalar type
|
||||
typedef Accumulator Scalar;
|
||||
|
||||
/// Real-valued epilogue
|
||||
typedef Volta884NaiveEpilogue<Accumulator, WarpDelta, Iterations> RealEpilogue;
|
||||
|
||||
/// Params object
|
||||
struct Params {
|
||||
/// Parameters for the real-valued part
|
||||
typename RealEpilogue::Params real;
|
||||
|
||||
/// Parameters for the imaginary-valued part
|
||||
typename RealEpilogue::Params imag;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE Params() {}
|
||||
|
||||
/// Constructs from params object
|
||||
CUTLASS_HOST_DEVICE Params(typename RealEpilogue::Params const& _real,
|
||||
typename RealEpilogue::Params const& _imag)
|
||||
: real(_real), imag(_imag) {}
|
||||
|
||||
/// Construct from pointers
|
||||
CUTLASS_HOST_DEVICE Params(ScalarC* _real, int _ldr, ScalarC* _imag, int _ldi)
|
||||
: real(_real, _ldr), imag(_imag, _ldi) {}
|
||||
|
||||
/// Construct from pointers
|
||||
CUTLASS_HOST_DEVICE Params(
|
||||
platform::complex<Scalar> const &alpha,
|
||||
platform::complex<Scalar> const &beta,
|
||||
ScalarC const *real_C,
|
||||
int real_ldc,
|
||||
ScalarC const *imag_C,
|
||||
int imag_ldc,
|
||||
ScalarD *real_D,
|
||||
int real_ldd,
|
||||
ScalarD *imag_D,
|
||||
int imag_ldd
|
||||
):
|
||||
real(real_D, real_ldd, alpha.real(), beta.real()),
|
||||
imag(imag_D, imag_ldd, alpha.real(), beta.real()) { }
|
||||
|
||||
/// Initializer method
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(
|
||||
platform::complex<Scalar> const &alpha,
|
||||
platform::complex<Scalar> const &beta,
|
||||
ScalarC const *real_C,
|
||||
int real_ldc,
|
||||
ScalarC const *imag_C,
|
||||
int imag_ldc,
|
||||
ScalarD *real_D,
|
||||
int real_ldd,
|
||||
ScalarD *imag_D,
|
||||
int imag_ldd
|
||||
) {
|
||||
|
||||
real = typename RealEpilogue::Params(real_D, real_ldd, alpha.real(), beta.real());
|
||||
imag = typename RealEpilogue::Params(imag_D, imag_ldd, alpha.real(), beta.real());
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Accumulator fragment definition
|
||||
typedef ZipFragment<
|
||||
typename RealEpilogue::Accumulators,
|
||||
typename RealEpilogue::Accumulators> Accumulators;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Epilogue for real part
|
||||
RealEpilogue real;
|
||||
|
||||
/// Epilogue for imaginary part
|
||||
RealEpilogue imag;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a complex-valued epilogue
|
||||
CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(
|
||||
Params const& _params, Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
|
||||
: real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}
|
||||
|
||||
/// Constructs a complex-valued epilogue
|
||||
CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(ScalarC* _real,
|
||||
int _ldr,
|
||||
ScalarC* _imag,
|
||||
int _ldi,
|
||||
Coord<3> const& _problem_size = make_Coord(1024,
|
||||
1024,
|
||||
1024))
|
||||
: real(_real, _ldr, _problem_size), imag(_imag, _ldi, _problem_size) {}
|
||||
|
||||
/// Constructs a complex-valued epilogue
|
||||
CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(Params const& _params,
|
||||
SharedStorage& shared_storage,
|
||||
Coord<3> const& _problem_size = make_Coord(1024,
|
||||
1024,
|
||||
1024))
|
||||
: real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}
|
||||
|
||||
/// Sets accumulators to zero
|
||||
CUTLASS_DEVICE void clear(Accumulators& C) {
|
||||
C.first.clear();
|
||||
C.second.clear();
|
||||
}
|
||||
|
||||
/// Naive load operation for debugging
|
||||
CUTLASS_DEVICE void load(Accumulators& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
real.load(C.first, threadblock_offset);
|
||||
imag.load(C.second, threadblock_offset);
|
||||
}
|
||||
|
||||
/// Naive store operation for debugging
|
||||
CUTLASS_DEVICE void store(Accumulators const& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
real.store(C.first, threadblock_offset);
|
||||
imag.store(C.second, threadblock_offset);
|
||||
}
|
||||
|
||||
/// CUTLASS Epilogue interface
|
||||
CUTLASS_DEVICE void epilogue(Accumulators const& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
|
||||
int batch_id = 0) {
|
||||
real.store(C.first, threadblock_offset);
|
||||
imag.store(C.second, threadblock_offset);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,152 +0,0 @@
/*! \file
|
||||
\brief Implements efficient loading of the thread block-level tile from global memory and
|
||||
storing to shared memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/zip_fragment.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
#include "cutlass/zip_tile_iterator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename Iterator_,
|
||||
/// The transformer to be applied after the data has been copied from shared memory.
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment>,
|
||||
/// Number of increments before iterator wraps - zero indicates no wrapping
|
||||
int StageCount = 1>
|
||||
struct Volta884ComplexSharedLoadStream {
|
||||
/// The load iterator.
|
||||
typedef Iterator_ RealIterator;
|
||||
|
||||
/// Zips two real-valued iterators together
|
||||
typedef ZipTileIterator<RealIterator, RealIterator> Iterator;
|
||||
|
||||
/// The transformer.
|
||||
typedef Transformer_ RealTransformer;
|
||||
|
||||
/// Zips two transformers
|
||||
typedef ZipConvert<RealTransformer, RealTransformer> Transformer;
|
||||
|
||||
/// Number of increments before iterator wraps - zero indicates no wrapping
|
||||
static int const kStageCount = StageCount;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename Iterator::Fragment FetchedFragment;
|
||||
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
|
||||
/// Reference type
|
||||
typedef ZipTensorRef<
|
||||
TensorRef<half, 4>,
|
||||
TensorRef<half, 4>
|
||||
> TensorRef;
|
||||
|
||||
/// Parameters passed from host
|
||||
struct Params { };
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator for loading fragments for warp-level matrix multiply-accumulate
|
||||
Iterator iterator;
|
||||
|
||||
/// Fetched fragment
|
||||
FetchedFragment fetched[2];
|
||||
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
|
||||
/// Transformed fragment
|
||||
TransformedFragment transformed[2];
|
||||
|
||||
/// Counts the number of stages
|
||||
int stage_index;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Volta884ComplexSharedLoadStream() : stage_index(0) {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Volta884ComplexSharedLoadStream(Params const &_params,
|
||||
TensorRef const &ref)
|
||||
: iterator(ref), stage_index(0) {}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
iterator.load(fetched[step % 2],
|
||||
make_Coord(step + stage_index * Iterator::First::VectorizedShape::kD, 0, 0, 0));
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
transformer.transform(fetched[step % 2], transformed[step % 2]);
|
||||
}
|
||||
|
||||
/// Gets the transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
TransformedFragment &fragment(int step) { return transformed[step % 2]; }
|
||||
|
||||
/// Gets the transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
++stage_index;
|
||||
if (kStageCount && stage_index == StageCount) {
|
||||
stage_index = 0;
|
||||
}
|
||||
}
|
||||
};
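
The copy/commit/fragment methods above form a two-stage software pipeline over shared memory: the fragment for step N+1 is fetched while the fragment for step N is transformed and consumed. A minimal sketch of that calling pattern follows; it assumes the CUTLASS 1.x headers, an already-constructed shared-memory TensorRef, and a hypothetical step count `kWarpGemmSteps`.

```cpp
#include "cutlass/cutlass.h"

// Sketch only: how a mainloop might drive the double-buffered shared-load stream.
// SharedLoadStream stands for a concrete Volta884ComplexSharedLoadStream instantiation.
template <typename SharedLoadStream>
CUTLASS_DEVICE void consume_shared_tile(typename SharedLoadStream::Params const &params,
                                        typename SharedLoadStream::TensorRef const &ref,
                                        int kWarpGemmSteps) {
  SharedLoadStream stream(params, ref);

  // Prime the pipeline with the fragment for step 0.
  stream.copy(0);

  for (int step = 0; step < kWarpGemmSteps; ++step) {
    // Issue the fetch for the next step while the current one is finished.
    stream.copy(step + 1);

    // Transform the fragment fetched for this step (fetched[step % 2]).
    stream.commit(step);

    // The transformed operands would feed the warp-level multiply-accumulate here.
    typename SharedLoadStream::Fragment const &frag = stream.fragment(step);
    (void)frag;
  }

  // Advance to the next shared-memory stage once the tile has been consumed.
  stream.inc_stage();
}
```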
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,771 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
|
||||
with the computed matrix product.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/tile_stream.h"
|
||||
#include "cutlass/tile_allocation.h"
|
||||
|
||||
#include "cutlass/gemm/mma_shared_stream.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Abstraction to select accumulators from an accumulator tile for each iteration of the epilogue
|
||||
template <typename WarpGemmShape, typename WarpDelta, typename Scalar>
|
||||
struct Volta884SelectAccumulators;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Selects accumulators from Volta mma.sync.F32 layout
|
||||
template <typename WarpGemmShape_, typename WarpDelta_>
|
||||
struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, float> {
|
||||
/// Shape of the warp-level matrix multiply operation
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
|
||||
/// Describes tiling of warp elements
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Data type of scalar
|
||||
typedef float Scalar;
|
||||
|
||||
//
|
||||
// Derived types and constants
|
||||
//
|
||||
|
||||
/// (Actual) number of accumulators held by each individual thread
|
||||
static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;
|
||||
|
||||
/// Accumulators fragment
|
||||
typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;
|
||||
|
||||
/// Number of warps
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Interleaved mma.sync shape
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// Hard-coded for FP32 layouts
|
||||
typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 4> Elements;
|
||||
|
||||
/// Number of elements
|
||||
static int const kElements = ShapeCount<Elements>::kCount;
|
||||
|
||||
/// Slice of accumulators
|
||||
typedef Fragment<Scalar, kElements> Fragment;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Selects accumulators for a given iteration of the epilogue
|
||||
CUTLASS_DEVICE
|
||||
Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
|
||||
Fragment frag;
|
||||
|
||||
static int const kAccumPerOp = 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Elements::kH; ++j) {
|
||||
|
||||
// selects the 32x32 tile
|
||||
Coord<2> tile_32x32 = make_Coord(idx[0] / 8, j);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Elements::kW; ++i) {
|
||||
Coord<2> mma_op = make_Coord(((idx[0] >> 1) & 1), i / 2);
|
||||
|
||||
int element = ((i & 1) << 1) | (idx[0] & 1) | (idx[0] & 4);
|
||||
|
||||
int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
|
||||
|
||||
frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
|
||||
}
|
||||
}
|
||||
|
||||
return frag;
|
||||
}
|
||||
};
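
The nested loops above index into the accumulator array with a fixed amount of bit arithmetic. To make the mapping easier to inspect, the following standalone host-side sketch reproduces the same arithmetic; the Elements extents are assumptions chosen only for illustration (ElementsH stands for WarpGemmShape::kW / 32 with a 64-wide warp tile, ElementsW is the hard-coded 4 of the FP32 layout).

```cpp
#include <cstdio>

// Host-side sketch reproducing the FP32 accumulator-selection arithmetic above.
// ElementsH / ElementsW stand in for Elements::kH / Elements::kW and are assumptions.
int main() {
  int const kAccumPerOp = 8;
  int const ElementsH = 2;   // hypothetical: WarpGemmShape::kW / 32
  int const ElementsW = 4;   // hard-coded for the FP32 layout

  for (int idx0 = 0; idx0 < 16; ++idx0) {
    for (int j = 0; j < ElementsH; ++j) {
      int tile_32x32[2] = {idx0 / 8, j};               // selects the 32x32 tile
      for (int i = 0; i < ElementsW; ++i) {
        int mma_op[2] = {(idx0 >> 1) & 1, i / 2};
        int element = ((i & 1) << 1) | (idx0 & 1) | (idx0 & 4);
        int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
        std::printf("idx=%2d j=%d i=%d -> accum[%d]\n",
                    idx0, j, i, element + kAccumPerOp * mma_op_idx);
      }
    }
  }
  return 0;
}
```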
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Selects accumulators from Volta mma.sync.F16 layout
|
||||
template <typename WarpGemmShape_, typename WarpDelta_>
|
||||
struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, half> {
|
||||
/// Shape of the warp-level matrix multiply operation
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
|
||||
/// Describes tiling of warp elements
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Data type of accumulator elements
|
||||
typedef half Scalar;
|
||||
|
||||
//
|
||||
// Derived types and constants
|
||||
//
|
||||
|
||||
/// (Actual) number of accumulators held by each individual thread
|
||||
static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;
|
||||
|
||||
/// Accumulators fragment
|
||||
typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;
|
||||
|
||||
/// Number of warps
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Interleaved mma.sync shape
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// Hard-coded for FP16 layouts
|
||||
typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 2> Elements;
|
||||
|
||||
/// Number of elements
|
||||
static int const kElements = ShapeCount<Elements>::kCount;
|
||||
|
||||
/// Slice of accumulators
|
||||
typedef Fragment<Scalar, kElements> Fragment;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Selects accumulators for a given iteration of the epilogue
|
||||
CUTLASS_DEVICE
|
||||
Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
|
||||
Fragment frag;
|
||||
|
||||
static int const kAccumPerOp = 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Elements::kH; ++j) {
|
||||
|
||||
// selects the 32x32 tile
|
||||
Coord<2> tile_32x32 = make_Coord(idx[0] / 16, j);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Elements::kW; ++i) {
|
||||
|
||||
Coord<2> mma_op = make_Coord(((idx[0] >> 2) & 1), i & 1);
|
||||
|
||||
int element = (idx[0] & 3) | ((idx[0] >> 1) & 4);
|
||||
|
||||
int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
|
||||
|
||||
frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
|
||||
}
|
||||
}
|
||||
|
||||
return frag;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
//
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The warp-level GEMM tile
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Size of vector to load or store
|
||||
int AccessSize,
|
||||
/// The accumulators fragment type - implies accumulator layout
|
||||
typename Accumulators_>
|
||||
struct Volta884EpilogueGlobalTileTraits;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Global tile traits specialized for Volta mma.sync.F32 layout
|
||||
template <
|
||||
/// The warp-level GEMM tile
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Size of vector to load or store
|
||||
int AccessSize>
|
||||
struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, float> {
|
||||
/// Shape of warp-scoped GEMM tile
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Structure of MMA
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Access size of input/output elements
|
||||
static int const kAccessSize = AccessSize;
|
||||
|
||||
/// Scalar type of accumulators - used to imply accumulator layout, not the data
|
||||
typedef float Accumulators;
|
||||
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
|
||||
|
||||
//typedef Shape<2, 2, 1, 1> Iterations;
|
||||
|
||||
/// Hard-coded pitch between Volta mma.sync Quad Pair tiles
|
||||
static int const kMmaQuadPairWidth = 16;
|
||||
|
||||
/// Hard-coded pitch between warp tiles
|
||||
static int const kInterleavedTileWidth = 32;
|
||||
|
||||
/// Number of actual threads
|
||||
static int const kThreadCount = (WarpDelta::kH * WarpDelta::kW) * kWarpSize;
|
||||
|
||||
/// Shape of the tile
|
||||
typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<2 * WarpDelta::kH,
|
||||
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
|
||||
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
|
||||
1> Iterations;
|
||||
|
||||
/// Delta between accesses
|
||||
typedef Shape<kMmaQuadPairWidth, 2, WarpDelta::kW * kWarpSize, 1> Delta;
|
||||
|
||||
/// Number of warps in threadblock
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Custom thread-offset function
|
||||
struct ThreadOffset {
|
||||
CUTLASS_DEVICE
|
||||
Coord<4> operator()() {
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
int residual_w = (tid / (Tile::kW));
|
||||
int offset_w = (tid % (Tile::kW));
|
||||
|
||||
int offset_h = (residual_w % Tile::kH);
|
||||
int offset_d = (residual_w / Tile::kH);
|
||||
|
||||
Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * Delta::kH, offset_w, 0);
|
||||
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
};
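
The ThreadOffset functor above lays threads out row-first across the tile: a thread first picks a column within Tile::kW, and any residual index wraps into the H and then D dimensions. The host-side sketch below reproduces that decomposition with assumed Tile and Delta extents.

```cpp
#include <cstdio>

// Host-side sketch of the ThreadOffset decomposition above. TileW, TileH and the
// Delta pitches are assumptions for illustration (Tile::kW = 128, Tile::kH = 2,
// Delta::kD = 16, Delta::kH = 2).
int main() {
  int const TileW = 128, TileH = 2;
  int const DeltaD = 16, DeltaH = 2;
  int const sample_tids[] = {0, 1, 127, 128, 200, 255};

  for (int tid : sample_tids) {
    int residual_w = tid / TileW;        // threads beyond one row wrap upward
    int offset_w   = tid % TileW;        // column within the tile
    int offset_h   = residual_w % TileH;
    int offset_d   = residual_w / TileH;
    std::printf("tid=%3d -> (d=%d, h=%d, w=%d)\n",
                tid, offset_d * DeltaD, offset_h * DeltaH, offset_w);
  }
  return 0;
}
```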
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Global tile traits specialized for Volta mma.sync.F16 layout
|
||||
template <
|
||||
/// The warp-level GEMM tile
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Size of vector to load or store
|
||||
int AccessSize>
|
||||
struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, half> {
|
||||
/// Shape of warp-scoped GEMM tile
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Structure of MMA tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Access size of input/output elements
|
||||
static int const kAccessSize = AccessSize;
|
||||
|
||||
/// Scalar type of accumulators - used to imply accumulator layout, not the data
|
||||
typedef half Accumulators;
|
||||
|
||||
/// Hard-coded pitch between Volta mma.sync Quad Pair tiles
|
||||
static int const kMmaQuadPairWidth = 16;
|
||||
|
||||
/// Hard-coded pitch between warp tiles
|
||||
static int const kInterleavedTileWidth = 32;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreadCount = kWarpSize * WarpDelta::kH * WarpDelta::kW;
|
||||
|
||||
/// Shape of the tile
|
||||
typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
|
||||
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<
|
||||
1,
|
||||
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
|
||||
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
|
||||
1> Iterations;
|
||||
|
||||
|
||||
/// Delta between thread-level accesses
|
||||
typedef typename platform::conditional<
|
||||
kThreadCount >= Tile::kW,
|
||||
Shape<1, kMmaQuadPairWidth * (kThreadCount / Tile::kW), 1, 1>,
|
||||
Shape<1, kMmaQuadPairWidth, kThreadCount, 1>
|
||||
>::type Delta;
|
||||
|
||||
/// Number of warps in threadblock
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Custom thread-offset function
|
||||
struct ThreadOffset {
|
||||
CUTLASS_DEVICE
|
||||
Coord<4> operator()() {
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
int residual_w = (tid / (Tile::kW));
|
||||
int offset_w = (tid % (Tile::kW));
|
||||
|
||||
int offset_h = (residual_w % Tile::kH);
|
||||
int offset_d = (residual_w / Tile::kH);
|
||||
|
||||
Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * kMmaQuadPairWidth, offset_w, 0);
|
||||
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Global offset functor for Volta884 epilogues
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename WarpDelta, typename AccumulatorType>
|
||||
struct Volta884EpilogueGlobalOffset;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor computing the offset from the threadblock origin per iteration of
|
||||
/// the epilogue. Specialized for Volta mma.sync.F32
|
||||
template <typename WarpDelta>
|
||||
struct Volta884EpilogueGlobalOffset<WarpDelta, float> {
|
||||
|
||||
/// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Coord<3> operator()(Coord<2> const &iteration) const {
|
||||
|
||||
int h = iteration[0];
|
||||
|
||||
// C++ needs a better way to express bit swizzling
|
||||
int h_offset = ((h & 1) | ((h & 2) << 1) | (((h & 4) >> 2) * 8) |
|
||||
(((h & 8) >> 3) * WarpDelta::kH * MmaTileShape::kH));
|
||||
|
||||
return make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW);
|
||||
}
|
||||
};
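
The h_offset expression above is a bit swizzle that maps the epilogue's logical row iteration onto the interleaved rows produced by the FP32 accumulator layout. As a quick sanity check, the snippet below evaluates it on the host for h = 0..15, assuming WarpDelta::kH == 1 (MmaTileShape::kH is fixed at 32 by the specialization).

```cpp
#include <cstdio>

// Host-side sketch of the FP32 epilogue row swizzle above, assuming WarpDelta::kH == 1
// and MmaTileShape::kH == 32 (the latter is fixed by the specialization).
int main() {
  int const kWarpDeltaH = 1;   // assumption for illustration
  int const kMmaTileH   = 32;

  for (int h = 0; h < 16; ++h) {
    int h_offset = (h & 1) | ((h & 2) << 1) | (((h & 4) >> 2) * 8) |
                   (((h & 8) >> 3) * kWarpDeltaH * kMmaTileH);
    std::printf("h=%2d -> row offset %d\n", h, h_offset);
  }
  return 0;
}
```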
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor computing the offset from the threadblock origin per iteration of
|
||||
/// the epilogue. Specialized for Volta mma.sync.F16
|
||||
template <typename WarpDelta>
|
||||
struct Volta884EpilogueGlobalOffset<WarpDelta, half> {
|
||||
|
||||
/// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Coord<3> operator()(Coord<2> const &iteration) const {
|
||||
|
||||
int h = iteration[0];
|
||||
|
||||
// C++ needs a better way to express bit swizzling
|
||||
int h_offset = (h & 15) | (h & 16) * 2 * WarpDelta::kH;
|
||||
|
||||
Coord<3> offset = make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW);
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Epilogue traits for Volta884 epilogue
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue traits for Volta884 GEMMs
|
||||
template <
|
||||
/// The threadblock GEMM tile
|
||||
typename OutputTile_,
|
||||
/// The warp-level GEMM tile
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// The accumulators fragment type.
|
||||
typename Accumulators_,
|
||||
/// Selects a slice of accumulators
|
||||
typename SelectAccumulators_,
|
||||
/// The iterator to load source matrix from global memory.
|
||||
typename GlobalLoadStreamC_,
|
||||
/// The iterator to store the final GEMM computation to global memory.
|
||||
typename GlobalStoreStreamD_,
|
||||
/// The stream to store matrix product to shared memory
|
||||
typename SharedStoreStreamD_,
|
||||
/// The stream to load the matrix product from shared memory
|
||||
typename SharedLoadStreamD_,
|
||||
/// The functor computing an element-wise operation on the matrix product
|
||||
typename Functor_,
|
||||
/// Global memory mapping function
|
||||
typename GlobalDataLayout_ = MatrixLayout::ColumnMajor,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct Volta884EpilogueTraits {
|
||||
/// The output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
|
||||
/// The warp-level GEMM tile
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Tiling of warp accumulator elements
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// The accumulators fragment type.
|
||||
typedef Accumulators_ Accumulators;
|
||||
|
||||
/// Selects a subset of accumulators for a given epilogue iteration
|
||||
typedef SelectAccumulators_ SelectAccumulators;
|
||||
|
||||
/// The iterator to load source matrix from global memory.
|
||||
typedef GlobalLoadStreamC_ GlobalLoadStreamC;
|
||||
|
||||
/// The iterator to store the final GEMM computation to global memory.
|
||||
typedef GlobalStoreStreamD_ GlobalStoreStreamD;
|
||||
|
||||
/// The stream to store matrix product to shared memory
|
||||
typedef SharedStoreStreamD_ SharedStoreStreamD;
|
||||
|
||||
/// The stream to load the matrix product from shared memory
|
||||
typedef SharedLoadStreamD_ SharedLoadStreamD;
|
||||
|
||||
/// The functor computing an element-wise operation on the matrix product
|
||||
typedef Functor_ Functor;
|
||||
|
||||
/// Global memory mapping function
|
||||
typedef GlobalDataLayout_ GlobalDataLayout;
|
||||
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
|
||||
/// The scalar type of the source accumulator matrix.
|
||||
typedef typename GlobalLoadStreamC::Iterator::Scalar ScalarC;
|
||||
|
||||
/// The scalar type of the destination accumulator matrix.
|
||||
typedef typename GlobalStoreStreamD::Iterator::Scalar ScalarD;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
static bool const kFp32Arrangement = sizeof(typename SelectAccumulators::Scalar) == 4;
|
||||
|
||||
/// Skew elements
|
||||
static int const kSkew = 2;
|
||||
|
||||
/// Number of columns of accumulators stored/loaded depends on the accumulator arrangement
|
||||
static int const kColumnsPerWarp = (kFp32Arrangement ? 4 : 2);
|
||||
|
||||
/// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// Cover an entire warp-level tile
|
||||
typedef Shape<1,
|
||||
WarpGemmTile::kH / kColumnsPerWarp, // iterates over 32x32 accumulator tiles along N dimension
|
||||
1, // iterates over 32x32 accumulator tiles along M dimension
|
||||
1>
|
||||
Iterations;
|
||||
|
||||
/// Skew is needed to reduce bank conflicts to SMEM - this shape depends on accumulator layout
|
||||
typedef Shape<1,
|
||||
WarpDelta::kH * kColumnsPerWarp, // multiple columns in the gemm N dimension
|
||||
WarpDelta::kW * WarpGemmTile::kW + kSkew, // rows in the gemm M dimension
|
||||
1
|
||||
> EpilogueTileAllocation;
|
||||
|
||||
/// Parameters structure initialized on the host
|
||||
struct Params {
|
||||
/// The params for the C iterator.
|
||||
typename GlobalLoadStreamC::Params load_stream_c;
|
||||
|
||||
/// The params for the D global iterator.
|
||||
typename GlobalStoreStreamD::Params store_stream_d;
|
||||
|
||||
/// Epilogue functor params
|
||||
typename Functor::Params functor;
|
||||
|
||||
/// The params for the D shared store iterator.
|
||||
typename SharedStoreStreamD::Params shared_store_stream_d;
|
||||
|
||||
/// The params for the D shared load stream.
|
||||
typename SharedLoadStreamD::Params shared_load_stream_d;
|
||||
|
||||
/// Stride between batches of the C matrix
|
||||
long long int batch_stride_C;
|
||||
|
||||
/// Stride between batches of the D matrix
|
||||
long long int batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Helper constructor taking pointer, stride for source and destination matrices and functor
|
||||
/// params
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(ScalarD *ptr_D,
|
||||
int ldd,
|
||||
ScalarC const *ptr_C,
|
||||
int ldc,
|
||||
typename Functor::Params _functor = Functor::Params())
|
||||
: load_stream_c(), store_stream_d(), functor(_functor) {}
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
batch_stride_C = desc.batch_stride_C;
|
||||
batch_stride_D = desc.batch_stride_D;
|
||||
|
||||
// The parameters for the functor.
|
||||
int error_code = functor.initialize(desc);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for C.
|
||||
error_code = load_stream_c.iterator.initialize(
|
||||
desc.C.data(), desc.C.leading_dim(), desc.C.leading_dim(), 1
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for D.
|
||||
return store_stream_d.iterator.initialize(
|
||||
desc.D.data(), desc.D.leading_dim(), desc.D.leading_dim(), 1
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory buffer used by epilogue
|
||||
typedef TileAllocation<
|
||||
typename SharedStoreStreamD::Iterator::Scalar,
|
||||
EpilogueTileAllocation> SharedStorage;
|
||||
|
||||
/// Functor computing the offset from the threadblock origin per iteration of
|
||||
/// the epilogue.
|
||||
typedef Volta884EpilogueGlobalOffset<WarpDelta, typename SelectAccumulators::Scalar> GlobalOffset;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Volta884 Epilogue helper
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits, typename AccumulatorType>
|
||||
struct Volta884EpiloguePredicateFunctor;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor specialized for the predicate arrangement in the Volta884 epilogue
|
||||
template <typename TileTraits>
|
||||
struct Volta884EpiloguePredicateFunctor<TileTraits, float> {
|
||||
/// Dimensions of the bounding volume
|
||||
Coord<3> bounds;
|
||||
|
||||
/// Constructs a predicate functor given the bounds of a tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}
|
||||
|
||||
/// Computes the predicate given the logical position of an access
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const {
|
||||
return
|
||||
(iteration[0] * TileTraits::Delta::kD + iteration[1] * TileTraits::Delta::kH +
|
||||
offset[1] < bounds[1]) &&
|
||||
(iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor specialized for the predicate arrangement in the Volta884 epilogue
|
||||
template <typename TileTraits>
|
||||
struct Volta884EpiloguePredicateFunctor<TileTraits, half> {
|
||||
/// Dimensions of the bounding volume
|
||||
Coord<3> bounds;
|
||||
|
||||
/// Constructs a predicate functor given the bounds of a tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}
|
||||
|
||||
/// Computes the predicate given the logical position of an access
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const {
|
||||
return iteration[1] * TileTraits::Delta::kH + offset[1] < bounds[1] &&
|
||||
iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2];
|
||||
}
|
||||
};
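
Both specializations above answer one question per access: does the guarded element still fall inside the problem bounds once the iteration and thread offsets are applied? The standalone sketch below reproduces the FP32-layout check with Delta values chosen purely for illustration.

```cpp
#include <cstdio>

// Standalone sketch of the FP32-layout predicate logic above; the Delta pitches
// are assumptions, not the real traits.
struct Delta { static int const kD = 16, kH = 2, kW = 64; };

bool in_bounds(int it0, int it1, int it2, int off1, int off2, int bound1, int bound2) {
  return (it0 * Delta::kD + it1 * Delta::kH + off1 < bound1) &&
         (it2 * Delta::kW + off2 < bound2);
}

int main() {
  // A 100x100 problem: the first access would land on row 112 and is masked off,
  // the second lands on row 42 / column 32 and is allowed.
  std::printf("%d\n", in_bounds(7, 0, 0, 0, 32, 100, 100));   // prints 0
  std::printf("%d\n", in_bounds(2, 1, 0, 8, 32, 100, 100));   // prints 1
  return 0;
}
```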
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Volta884 Epilogue helper
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to define the traits for a Volta884 Epilogue
|
||||
template <
|
||||
typename GemmConfig_,
|
||||
typename EpilogueFunctor_,
|
||||
typename MultiplyAdd_ = typename GemmConfig_::MultiplyAdd,
|
||||
typename Index_ = int>
|
||||
struct Volta884GemmEpilogueTraitsHelper {
|
||||
|
||||
/// Configuration object defining GEMM properties
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
|
||||
/// Warp-level tile
|
||||
typedef typename GemmConfig::AccumulatorsPerWarp WarpGemmShape;
|
||||
|
||||
/// Warp delta
|
||||
typedef typename ShapeDiv<
|
||||
typename GemmConfig::OutputTile,
|
||||
WarpGemmShape>::Shape WarpDelta;
|
||||
|
||||
/// Thread-block scoped tile
|
||||
typedef typename cutlass::ShapeMul<
|
||||
WarpGemmShape,
|
||||
WarpDelta
|
||||
>::Shape OutputTile;
|
||||
|
||||
/// Multiply-add operation
|
||||
typedef MultiplyAdd_ MultiplyAdd;
|
||||
|
||||
/// Epilogue functor
|
||||
typedef EpilogueFunctor_ Functor;
|
||||
|
||||
/// Traits for global tile access
|
||||
typedef cutlass::gemm::Volta884EpilogueGlobalTileTraits<
|
||||
WarpGemmShape,
|
||||
WarpDelta,
|
||||
1,
|
||||
typename MultiplyAdd::ScalarC
|
||||
> EpilogueGlobalTileTraits;
|
||||
|
||||
/// Iterator to load a slice of the C matrix from global memory
|
||||
typedef cutlass::TileLoadIterator<
|
||||
EpilogueGlobalTileTraits,
|
||||
typename GemmConfig::ScalarC,
|
||||
cutlass::IteratorAdvance::kW,
|
||||
cutlass::MemorySpace::kGlobal
|
||||
> TileLoadIteratorC;
|
||||
|
||||
/// Conversion from C data type to accumulator data type
|
||||
typedef Convert<
|
||||
typename TileLoadIteratorC::Fragment,
|
||||
Fragment<typename MultiplyAdd::ScalarC, TileLoadIteratorC::Fragment::kElements>
|
||||
> ConvertSourceFragment;
|
||||
|
||||
/// Iterator to store a slice of the D matrix to global memory
|
||||
typedef cutlass::TileStoreIterator<
|
||||
EpilogueGlobalTileTraits,
|
||||
typename GemmConfig::ScalarD,
|
||||
cutlass::IteratorAdvance::kW,
|
||||
cutlass::MemorySpace::kGlobal
|
||||
> TileStoreIteratorD;
|
||||
|
||||
/// Conversion from accumulator data type to D data type
|
||||
typedef Convert<
|
||||
Fragment<typename MultiplyAdd::ScalarC, TileStoreIteratorD::Fragment::kElements>,
|
||||
typename TileStoreIteratorD::Fragment
|
||||
> ConvertDestinationFragment;
|
||||
|
||||
/// Defines traits for an epilogue of a Volta884 GEMM
|
||||
typedef cutlass::gemm::Volta884EpilogueTraits<
|
||||
OutputTile,
|
||||
WarpGemmShape,
|
||||
WarpDelta,
|
||||
typename MultiplyAdd::Accumulators,
|
||||
cutlass::gemm::Volta884SelectAccumulators<
|
||||
WarpGemmShape,
|
||||
WarpDelta,
|
||||
typename MultiplyAdd::ScalarC
|
||||
>,
|
||||
cutlass::PredicatedTileLoadStream<
|
||||
TileLoadIteratorC,
|
||||
cutlass::gemm::Volta884EpiloguePredicateFunctor<
|
||||
EpilogueGlobalTileTraits,
|
||||
typename MultiplyAdd::ScalarC>,
|
||||
ConvertSourceFragment
|
||||
>,
|
||||
cutlass::PredicatedTileStoreStream<
|
||||
TileStoreIteratorD,
|
||||
cutlass::gemm::Volta884EpiloguePredicateFunctor<
|
||||
EpilogueGlobalTileTraits,
|
||||
typename MultiplyAdd::ScalarC>,
|
||||
ConvertDestinationFragment
|
||||
>,
|
||||
cutlass::TileStoreStream<
|
||||
cutlass::gemm::Volta884EpilogueSharedStoreIterator<
|
||||
WarpGemmShape,
|
||||
WarpDelta,
|
||||
typename MultiplyAdd::ScalarC,
|
||||
typename MultiplyAdd::ScalarC
|
||||
>
|
||||
>,
|
||||
cutlass::TileLoadStream<
|
||||
cutlass::gemm::Volta884EpilogueSharedLoadIterator<
|
||||
WarpGemmShape,
|
||||
WarpDelta,
|
||||
typename MultiplyAdd::ScalarC,
|
||||
1,
|
||||
typename MultiplyAdd::ScalarC
|
||||
>
|
||||
>,
|
||||
Functor
|
||||
> EpilogueTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
// clang-format on
|
||||
@@ -1,585 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/gemm/clear_accumulators.h"
|
||||
#include "cutlass/gemm/gemm_config.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_stream_pair.h"
|
||||
#include "cutlass/gemm/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
#include "cutlass/kernel_launch.h"
|
||||
|
||||
#include "cutlass/gemm/gemm_desc.h"
|
||||
#include "cutlass/gemm/volta884_multiplicand.h"
|
||||
#include "cutlass/gemm/volta884_multiply_add.h"
|
||||
#include "cutlass/gemm/mma_global_stream.h"
|
||||
#include "cutlass/gemm/mma_shared_stream.h"
|
||||
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/mma_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_mainloop.h"
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines configuration for Volta884 GEMM
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The source matrix type.
|
||||
typename ScalarC_,
|
||||
/// The destination matrix type
|
||||
typename ScalarD_,
|
||||
/// Number of stages in shared memory
|
||||
int StageCount,
|
||||
|
||||
/// If true, kernel is launched with CUDA launch bounds specified
|
||||
bool kLaunchBounds = true,
|
||||
/// If true, residue is computed in mainloop. If false, separate loops are instantiated.
|
||||
bool kResidueSeparate = true,
|
||||
/// Is residue performed in prologue?
|
||||
bool kResidueInProlog = false>
|
||||
struct Volta884GemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The threadblock tile size
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
Volta884MultiplyAdd<WarpGemmShape_,
|
||||
LayoutA,
|
||||
half,
|
||||
LayoutB,
|
||||
half,
|
||||
Accumulator_>,
|
||||
/// The number of scalars per LDG for A.
|
||||
8,
|
||||
/// The number of scalars per STS for A.
|
||||
8,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
8,
|
||||
/// The number of scalars per STS for B.
|
||||
8,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
16 / int(sizeof(ScalarD_)),
|
||||
/// The number of scalars per STS for D.
|
||||
16 / int(sizeof(ScalarD_)),
|
||||
/// The number of scalars per LDS for D.
|
||||
16 / int(sizeof(ScalarD_)),
|
||||
/// The number of stages in shared memory.
|
||||
StageCount,
|
||||
/// If true, separate mainloop is instantiated
|
||||
kResidueSeparate,
|
||||
/// If true, compute residue in prolog
|
||||
kResidueInProlog,
|
||||
/// If true, kernel is launched with CUDA launch bounds specified
|
||||
kLaunchBounds> {};
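
As a concrete illustration, the following sketch instantiates the config for half-precision operands with FP32 accumulation; the threadblock and warp tile shapes are example values only, not ones required by the library.

```cpp
// Sketch only: an illustrative Volta884GemmConfig instantiation.
// The 32x128x128 threadblock tile and 32x64x64 warp tile are assumptions.
typedef cutlass::gemm::Volta884GemmConfig<
    cutlass::MatrixLayout::kColumnMajor,   // layout of A
    cutlass::MatrixLayout::kRowMajor,      // layout of B
    cutlass::Shape<32, 128, 128>,          // threadblock tile (K-by-N-by-M)
    cutlass::Shape<32, 64, 64>,            // warp-level tile (K-by-N-by-M)
    float,                                 // accumulator type
    float,                                 // source matrix (C) type
    float,                                 // destination matrix (D) type
    2                                      // stages in shared memory
> ExampleVolta884Config;
```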
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines components of Volta884 GEMM
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The input matrix type.
|
||||
typename ScalarC_,
|
||||
/// The output matrix type.
|
||||
typename ScalarD_,
|
||||
/// Number of buffers in shared memory to use
|
||||
int StageCount,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<Accumulator_>,
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typename BlockSwizzle_ = IdentityBlockSwizzle,
|
||||
/// Selectively enables launch bounds
|
||||
bool LaunchBounds = false
|
||||
>
|
||||
struct Volta884GemmTraits {
|
||||
/// This traits type
|
||||
typedef Volta884GemmTraits<
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
OutputTile_,
|
||||
WarpGemmShape_,
|
||||
Accumulator_,
|
||||
ScalarC_,
|
||||
ScalarD_,
|
||||
StageCount,
|
||||
EpilogueFunctor_,
|
||||
BlockSwizzle_,
|
||||
LaunchBounds> This_;
|
||||
/// The struct that consumes this Traits
|
||||
typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;
|
||||
|
||||
/// Layout of multiplicand A matrix
|
||||
static MatrixLayout::Kind const kLayoutA = LayoutA;
|
||||
|
||||
/// Layout of multiplicand B matrix
|
||||
static MatrixLayout::Kind const kLayoutB = LayoutB;
|
||||
|
||||
/// Dimensions of threadblock tile (concept Shape)
|
||||
typedef OutputTile_ OutputTile;
|
||||
|
||||
/// Shape of warp-level accumulators
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
|
||||
/// Multiplicand A scalar type
|
||||
typedef half ScalarA;
|
||||
|
||||
/// Multiplicand B scalar type
|
||||
typedef half ScalarB;
|
||||
|
||||
/// Data type of internal accumulator
|
||||
typedef Accumulator_ Accumulator;
|
||||
|
||||
/// Data type of input accumulator matrix operand
|
||||
typedef ScalarC_ ScalarC;
|
||||
|
||||
/// Data type of output accumulator matrix operand
|
||||
typedef ScalarD_ ScalarD;
|
||||
|
||||
/// Shape of individual mma.sync instruction
|
||||
typedef Shape<4, 16, 16> InstructionShape;
|
||||
|
||||
/// Tile size for an individual warp-level multiply-add
|
||||
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
|
||||
|
||||
/// Defines properties about GEMM needed by host code
|
||||
typedef Volta884GemmConfig<kLayoutA,
|
||||
kLayoutB,
|
||||
OutputTile,
|
||||
WarpGemmShape,
|
||||
Accumulator,
|
||||
ScalarC,
|
||||
ScalarD,
|
||||
StageCount,
|
||||
LaunchBounds>
|
||||
GemmConfig;
|
||||
|
||||
//
|
||||
// Derived types
|
||||
//
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Partitioning of threadblock into warps
|
||||
typedef typename ShapeDiv<OutputTile, WarpGemmShape>::Shape WarpDelta;
|
||||
|
||||
/// Number of warps per threadblock
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Defines iterators for A matrix
|
||||
typedef Volta884Multiplicand<GemmOperand::kA, kLayoutA, OutputTile, WarpTile, kWarpCount, WarpDelta>
|
||||
MultiplicandA;
|
||||
|
||||
/// Defines iterators for B matrix
|
||||
typedef Volta884Multiplicand<GemmOperand::kB, kLayoutB, OutputTile, WarpTile, kWarpCount, WarpDelta>
|
||||
MultiplicandB;
|
||||
|
||||
//
|
||||
// GemmTraits mandatory type definitions
|
||||
//
|
||||
|
||||
/// Maps hardware threadblocks to logical partitions of the GEMM
|
||||
typedef BlockSwizzle_ BlockSwizzle;
|
||||
|
||||
/// Clears accumulators
|
||||
typedef ClearAccumulators<ScalarC> ClearAccumulators;
|
||||
|
||||
/// Loads multiplicands from global memory
|
||||
typedef GlobalLoadStreamPair<
|
||||
MMAGlobalLoadStream<GemmOperand::kA,
|
||||
kLayoutA,
|
||||
typename MultiplicandA::LoadIterator,
|
||||
Copy<typename MultiplicandA::LoadIterator::Fragment>,
|
||||
typename MultiplicandA::StoreIterator,
|
||||
StageCount>,
|
||||
MMAGlobalLoadStream<GemmOperand::kB,
|
||||
kLayoutB,
|
||||
typename MultiplicandB::LoadIterator,
|
||||
Copy<typename MultiplicandB::LoadIterator::Fragment>,
|
||||
typename MultiplicandB::StoreIterator,
|
||||
StageCount>,
|
||||
GemmConfig::kResidueInProlog >
|
||||
GlobalLoadStream;
|
||||
|
||||
/// Memory needed to store the threadblock-scoped GEMM tile
|
||||
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
|
||||
union MainLoopStorage {
|
||||
|
||||
/// Stores the threadblock tile
|
||||
ThreadblockTileStorage threadblock_tile;
|
||||
|
||||
/// Storage for GEMM global stream
|
||||
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
|
||||
};
|
||||
|
||||
/// Loads multiplicands from shared memory
|
||||
typedef SharedStreamPair<
|
||||
MMASharedLoadStream<typename MultiplicandA::WarpLoadIterator,
|
||||
Copy<typename MultiplicandA::WarpLoadIterator::Fragment>,
|
||||
StageCount>,
|
||||
MMASharedLoadStream<typename MultiplicandB::WarpLoadIterator,
|
||||
Copy<typename MultiplicandB::WarpLoadIterator::Fragment>,
|
||||
StageCount> >
|
||||
SharedStream;
|
||||
|
||||
// Multiply-add object specialized for Volta mma.sync
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
|
||||
#if 0
|
||||
/// Naive epilogue for updating the output matrix
|
||||
typedef cutlass::gemm::Volta884NaiveEpilogue<ScalarC,
|
||||
typename MultiplicandA::WarpDelta,
|
||||
typename MultiplyAdd::Iterations>
|
||||
Epilogue;
|
||||
#else
|
||||
|
||||
/// Efficient epilogue
|
||||
typedef cutlass::gemm::MMAEpilogue<
|
||||
typename Volta884GemmEpilogueTraitsHelper<
|
||||
GemmConfig,
|
||||
EpilogueFunctor_
|
||||
>::EpilogueTraits
|
||||
> Epilogue;
|
||||
|
||||
#endif
|
||||
|
||||
/// Parameters structure
|
||||
struct Params : public KernelLaunchConfiguration {
|
||||
/// The dimensions of the GEMM.
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// The K range for every partition except the last one
|
||||
int partitionK_range;
|
||||
|
||||
/// The params for the global load stream
|
||||
typename GlobalLoadStream::Params global_to_shared_stream;
|
||||
|
||||
/// The params for the shared load stream
|
||||
typename SharedStream::Params shared_stream;
|
||||
|
||||
/// The params for the epilogue.
|
||||
typename Epilogue::Params epilogue;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE Params(GemmDesc_ const& desc) {
|
||||
initialize(desc);
|
||||
}
|
||||
|
||||
/// Initialize the Params struct
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
|
||||
// Problem size
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// there is no partitionK in the default case
|
||||
partitionK_range = problem_size[0];
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue_last_partition = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
|
||||
// Initialize parameter objects for the global load streams
|
||||
global_to_shared_stream.stream_a.initialize(
|
||||
desc.A,
|
||||
desc.batch_stride_A,
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition);
|
||||
|
||||
global_to_shared_stream.stream_b.initialize(
|
||||
desc.B,
|
||||
desc.batch_stride_B,
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition);
|
||||
|
||||
// The epilogue.
|
||||
epilogue.initialize(desc);
|
||||
return 0;
|
||||
}
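
The offset_to_residue values computed above round the K extent down to a multiple of the threadblock K tile, so the mainloop covers whole tiles and a residue pass covers the remainder. A minimal sketch of that rule, with OutputTile::kD passed in as a parameter:

```cpp
// Sketch of the residue-offset rule used above; tile_k stands for OutputTile::kD.
inline int offset_to_residue(int gemm_k, int tile_k) {
  return (gemm_k % tile_k) ? gemm_k - (gemm_k % tile_k) : 0;
}
// e.g. offset_to_residue(1000, 32) == 992 (an 8-element residue),
//      offset_to_residue(1024, 32) == 0   (no residue pass needed).
```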
|
||||
|
||||
/// Helper to construct a GEMM params using a BLAS-like API
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, 1),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
TensorRef<ScalarD, 2>(d_d, ldd)
|
||||
);
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a batched GEMM params
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
long long int batch_stride_A,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
long long int batch_stride_B,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
long long int batch_stride_C,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
long long int batch_stride_D,
|
||||
Index batch_count) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
make_Coord(k, n, m, batch_count),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
batch_stride_A,
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
batch_stride_B,
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
batch_stride_C,
|
||||
TensorRef<ScalarD, 2>(d_d, ldd),
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a partitionedK GEMM params
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_, Index partitionK_multiple_ = 1) {
|
||||
// partitionK GEMM is a specialized batched strided GEMM with different K ranges per batch
|
||||
// the problem_size of each batch is (lastK_size, n, m)
|
||||
// add more comments here
|
||||
// the K range for every batch except the last one
|
||||
|
||||
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
|
||||
partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);
|
||||
// the k range of the last batch
|
||||
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
|
||||
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
|
||||
|
||||
assert((partitionK_range % partitionK_multiple_) == 0);
|
||||
assert(partitionK_range > 0);
|
||||
assert((lastK_range % partitionK_multiple_) == 0);
|
||||
assert(lastK_range > 0);
|
||||
|
||||
int k_size = lastK_range;
|
||||
int lda = partitonK_desc.A.stride(0);
|
||||
int ldb = partitonK_desc.B.stride(0);
|
||||
int ldc = partitonK_desc.C.stride(0);
|
||||
int ldd = partitonK_desc.D.stride(0);
|
||||
int n = partitonK_desc.problem_size.n();
|
||||
|
||||
long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
|
||||
long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
|
||||
long long int batch_stride_C = ldc * n;
|
||||
long long int batch_stride_D = ldd * n;
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
// we pass lastK_range as the per-batch K; every other batch covers partitionK_range
|
||||
GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_),
|
||||
partitonK_desc.alpha,
|
||||
partitonK_desc.A,
|
||||
batch_stride_A,
|
||||
partitonK_desc.B,
|
||||
batch_stride_B,
|
||||
partitonK_desc.beta,
|
||||
partitonK_desc.C,
|
||||
batch_stride_C,
|
||||
partitonK_desc.D,
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
// Set the problem size.
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue.
|
||||
// partitionK_range <= problem_size[0]
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
|
||||
|
||||
// Initialize parameter objects for the global load streams
|
||||
int error_code = global_to_shared_stream.stream_a.initialize(
|
||||
desc.A,
|
||||
desc.batch_stride_A,
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
error_code = global_to_shared_stream.stream_b.initialize(
|
||||
desc.B,
|
||||
desc.batch_stride_B,
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
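
The partitioned-K path above splits K into partitionK_count batches, rounds each batch's range down to a multiple of partitionK_multiple, and lets the last batch absorb the remainder; the batch strides then depend on layout. A small worked sketch of that split, with illustrative numbers:

```cpp
#include <cassert>

// Sketch of the partitioned-K split above. K, the partition count, and the
// multiple are illustrative values, not defaults of the library.
int main() {
  int const K = 1000, partitionK_count = 3, partitionK_multiple = 8;

  int partitionK_range = K / partitionK_count;                      // 333
  partitionK_range -= partitionK_range % partitionK_multiple;       // 328
  int lastK_range = K - partitionK_range * (partitionK_count - 1);  // 344

  assert(partitionK_range > 0 && partitionK_range % partitionK_multiple == 0);
  assert(lastK_range > 0 && lastK_range % partitionK_multiple == 0);

  // For column-major A, batch_stride_A = lda * partitionK_range;
  // for row-major B,    batch_stride_B = partitionK_range * ldb.
  return 0;
}
```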
|
||||
|
||||
/// Helper to construct a partitionedK GEMM params
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
Index partitionK_count_,
|
||||
Index partitionK_multiple_ = 1) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, 1),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
TensorRef<ScalarD, 2>(d_d, ldd)
|
||||
);
|
||||
|
||||
|
||||
return this->initialize(desc, partitionK_count_, partitionK_multiple_);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage
|
||||
union SharedStorage {
|
||||
/// Storage required during mainloop phase
|
||||
MainLoopStorage main_loop;
|
||||
|
||||
/// Shared storage needed for epilogue
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
|
||||
if (StageCount < 2) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
|
||||
__syncthreads();
|
||||
}
|
||||
};
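
Putting the pieces together, the sketch below defines a traits type and fills its Params with the BLAS-like helper shown earlier. The tile shapes are illustrative, the traits header is assumed to be included, and the actual kernel launch through the GemmMainloop kernel is left to the library's usual dispatch path.

```cpp
// Sketch only: traits instantiation plus host-side parameter setup.
// Assumes the CUTLASS 1.x Volta884 GEMM traits header has been included.
typedef cutlass::gemm::Volta884GemmTraits<
    cutlass::MatrixLayout::kColumnMajor,   // layout of A
    cutlass::MatrixLayout::kRowMajor,      // layout of B
    cutlass::Shape<32, 128, 128>,          // threadblock tile (KxNxM) - example value
    cutlass::Shape<32, 64, 64>,            // warp-level tile (KxNxM) - example value
    float,                                 // accumulator
    float,                                 // ScalarC
    float,                                 // ScalarD
    2                                      // shared-memory stages
> ExampleTraits;

// Fill the kernel parameters for D = alpha * A * B + beta * C (alpha = 1, beta = 0).
int setup_example(int m, int n, int k,
                  ExampleTraits::ScalarA const *A, int lda,
                  ExampleTraits::ScalarB const *B, int ldb,
                  ExampleTraits::ScalarC const *C, int ldc,
                  ExampleTraits::ScalarD *D, int ldd) {
  ExampleTraits::Params params;
  // initialize() computes grid/block dimensions, residue offsets, and stream/epilogue params.
  return params.initialize(m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc, D, ldd);
}
```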
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
// clang-format on
|
||||
@@ -1,298 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
#include "cutlass/gemm/mma_global_tile.h"
|
||||
#include "cutlass/gemm/volta884_shared_tile.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines iterators for loading and storing multiplicands
|
||||
template <
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
GemmOperand::Kind Operand,
|
||||
/// Specifies layout of data in source memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile,
|
||||
/// Specifies warp tile shape
|
||||
typename WarpTile,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp tiles
|
||||
typename WarpDelta_>
|
||||
struct Volta884Multiplicand;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines iterators for loading and storing multiplicands for A.column_major
|
||||
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
|
||||
struct Volta884Multiplicand<GemmOperand::kA,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta_> {
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// Thread-block tile shape
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Warp-level matrix multiply-add shape
|
||||
typedef WarpTile_ WarpTile;
|
||||
|
||||
/// Total number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
//
|
||||
// Thread-block load iterator
|
||||
//
|
||||
typedef
|
||||
typename MMAThreadblockCongruousLoad<kOperand, Tile_, WarpCount, WarpDelta::kW>::Iterator
|
||||
LoadIterator;
|
||||
|
||||
//
|
||||
// Thread-block store iterator
|
||||
//
|
||||
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpCount,
|
||||
WarpDelta::kW>
|
||||
StoreIterator;
|
||||
|
||||
//
|
||||
// Warp-level load iterator
|
||||
//
|
||||
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta>
|
||||
WarpLoadIterator;
|
||||
};
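As a usage sketch (not part of the original sources): the traits above are consumed purely through their nested typedefs. In the snippet below, the tile shapes, warp count, and warp delta are illustrative assumptions following the CUTLASS 1.x `Shape<D, H, W>` convention, not values taken from this diff.

```cpp
// Sketch only -- all template arguments below are assumed example values.
typedef cutlass::gemm::Volta884Multiplicand<
    cutlass::GemmOperand::kA,             // multiplicand identifier
    cutlass::MatrixLayout::kColumnMajor,  // layout in source memory
    cutlass::Shape<32, 64, 64>,           // threadblock tile (assumed)
    cutlass::Shape<4, 64, 64>,            // warp tile (assumed)
    4,                                    // participating warps (assumed)
    cutlass::Shape<1, 2, 2> >             // delta between warp tiles (assumed)
    MultiplicandA;

// The three nested iterators cover the staging pipeline for operand A:
typedef MultiplicandA::LoadIterator     GlobalLoadIterator;   // global memory -> registers
typedef MultiplicandA::StoreIterator    SharedStoreIterator;  // registers -> shared memory
typedef MultiplicandA::WarpLoadIterator WarpLoadIterator;     // shared memory -> mma operands
```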
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines iterators for loading and storing multiplicands for B.row_major
|
||||
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
|
||||
struct Volta884Multiplicand<GemmOperand::kB,
|
||||
MatrixLayout::kRowMajor,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta_> {
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kB;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// Thread-block tile shape
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Warp-level matrix multiply-add shape
|
||||
typedef WarpTile_ WarpTile;
|
||||
|
||||
/// Total number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
//
|
||||
// Thread-block load iterator
|
||||
//
|
||||
typedef
|
||||
typename MMAThreadblockCongruousLoad<kOperand, Tile_, WarpCount, WarpDelta::kH>::Iterator
|
||||
LoadIterator;
|
||||
|
||||
//
|
||||
// Thread-block store iterator
|
||||
//
|
||||
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpCount,
|
||||
WarpDelta::kH>
|
||||
StoreIterator;
|
||||
|
||||
//
|
||||
// Warp-level load iterator
|
||||
//
|
||||
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta>
|
||||
WarpLoadIterator;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines iterators for loading and storing multiplicands for A.row_major
|
||||
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
|
||||
struct Volta884Multiplicand<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta_> {
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// Thread-block tile shape
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Warp-level matrix multiply-add shape
|
||||
typedef WarpTile_ WarpTile;
|
||||
|
||||
/// Total number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
//
|
||||
// Thread-block load iterator
|
||||
//
|
||||
typedef
|
||||
typename MMAThreadblockCrosswiseLoad<kOperand, Tile_, WarpCount, WarpDelta::kW>::Iterator
|
||||
LoadIterator;
|
||||
|
||||
//
|
||||
// Thread-block store iterator
|
||||
//
|
||||
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpCount,
|
||||
WarpDelta::kW>
|
||||
StoreIterator;
|
||||
|
||||
//
|
||||
// Warp-level load iterator
|
||||
//
|
||||
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta>
|
||||
WarpLoadIterator;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines iterators for loading and storing multiplicands for B.column_major
|
||||
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
|
||||
struct Volta884Multiplicand<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta_> {
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kB;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// Thread-block tile shape
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Warp-level matrix multiply-add shape
|
||||
typedef WarpTile_ WarpTile;
|
||||
|
||||
/// Total number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
//
|
||||
// Thread-block load iterator
|
||||
//
|
||||
typedef
|
||||
typename MMAThreadblockCrosswiseLoad<kOperand, Tile_, WarpCount, WarpDelta::kH>::Iterator
|
||||
LoadIterator;
|
||||
|
||||
//
|
||||
// Thread-block store iterator
|
||||
//
|
||||
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpCount,
|
||||
WarpDelta::kH>
|
||||
StoreIterator;
|
||||
|
||||
//
|
||||
// Warp-level load iterator
|
||||
//
|
||||
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
|
||||
kLayout,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta>
|
||||
WarpLoadIterator;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,704 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Shape of a warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// Layout of A multiplicand
|
||||
MatrixLayout::Kind LayoutA,
|
||||
/// Data type of A multiplicand
|
||||
typename ScalarA_,
|
||||
/// Layout of B multiplicand
|
||||
MatrixLayout::Kind LayoutB,
|
||||
/// Data type of B multiplicand
|
||||
typename ScalarB_,
|
||||
/// Data type of accumulators
|
||||
typename ScalarC_>
|
||||
struct Volta884MultiplyAdd {
|
||||
//
|
||||
// Constant and type definitions
|
||||
//
|
||||
|
||||
/// Shape of a warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
|
||||
/// Shape of a warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
|
||||
/// Most of the Volta884 code assumes interleaved 32x32 tiles
|
||||
typedef Shape<4, 32, 32> InterleavedTileShape;
|
||||
|
||||
/// Shape of an individual warp-wide Volta mma.sync instruction
|
||||
typedef Shape<4, 16, 16> InstructionShape;
|
||||
|
||||
/// Shape of a warp-level matrix multiply operation
|
||||
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
|
||||
|
||||
/// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
|
||||
static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
|
||||
!(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
|
||||
"WarpTile must be a multiple of InterleavedTileShape.");
|
||||
|
||||
/// Layout of A multiplicand
|
||||
static MatrixLayout::Kind const kLayoutA = LayoutA;
|
||||
/// Layout of B multiplicand
|
||||
static MatrixLayout::Kind const kLayoutB = LayoutB;
|
||||
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
|
||||
/// Hard-coded compute type supported on Volta
|
||||
static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;
|
||||
|
||||
/// Defines a warp-level matrix multiply-accumulate operation performed by a warp.
|
||||
//
|
||||
// The layout is as follows. The entire warp performs a 64x64x4 GEMM using Volta mma.sync macros
|
||||
// arranged as a 2x2 tile of adjacent, 32x32x4 matrix products. These are implemented as a
|
||||
// 2x2 arrangement of spatially interleaved Volta mma.sync macros.
|
||||
//
|
||||
// The Iterations shape maps to the following dimensions of the above warp-level GEMM:
|
||||
//
|
||||
// kC: number of rows of Volta mma.sync macros in 32x32x4 tile
|
||||
// kW: number of columns of Volta mma.sync macros in 32x32x4 tile
|
||||
// kH: number of rows of 32x32x4 macros in larger 64x64x4 tile
|
||||
// kD: number of columns of 32x32x4 macros in larger 64x64x4 tile
|
||||
//
|
||||
// A column-major ordering would arrange C and H as the inner-most loops, with W and D as the
|
||||
// outer-most.
|
||||
//
|
||||
typedef Shape<WarpTile::kH / InterleavedTileShape::kH,
|
||||
WarpTile::kW / InterleavedTileShape::kW,
|
||||
InterleavedTileShape::kH / InstructionShape::kH,
|
||||
InterleavedTileShape::kW / InstructionShape::kW>
|
||||
Iterations;
|
||||
|
||||
/// Number of multiplicand elements per instruction
|
||||
static int const kMultElementsPerInst = 4;
|
||||
|
||||
/// Number of accumulator elements per instruction
|
||||
static int const kAccumElementsPerInst = 8;
|
||||
|
||||
/// Fragment definition for A multiplicand
|
||||
typedef Fragment<ScalarA, Iterations::kH * Iterations::kC * kMultElementsPerInst> FragmentA;
|
||||
|
||||
/// Fragment definition for B multiplicand
|
||||
typedef Fragment<ScalarB, Iterations::kW * Iterations::kD * kMultElementsPerInst> FragmentB;
|
||||
|
||||
/// Fragment definition for accumulators
|
||||
typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Volta884MultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = (-)a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& A,
|
||||
FragmentB const& B,
|
||||
Accumulators const& C,
|
||||
Accumulators& D,
|
||||
bool negate = false) {
|
||||
// Guard conditional needed for __hneg2
|
||||
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) { // Outer column
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) { // Inner column
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h_raw = 0; h_raw < Iterations::kH; ++h_raw) { // Outer row
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c_raw = 0; c_raw < Iterations::kC; ++c_raw) { // Inner row
|
||||
|
||||
int op_col = (w + Iterations::kW * d);
|
||||
|
||||
// Column-major serpentine sequence to maximize reuse of B operand.
|
||||
int h = h_raw;
|
||||
int c = c_raw;
|
||||
|
||||
if (op_col & 1) {
|
||||
h = Iterations::kH - h_raw - 1;
|
||||
c = Iterations::kC - c_raw - 1;
|
||||
}
|
||||
|
||||
int op_row = (c + Iterations::kC * h);
|
||||
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
|
||||
|
||||
ScalarA operand_A[kMultElementsPerInst];
|
||||
|
||||
reinterpret_cast<uint64_t&>(operand_A[0]) =
|
||||
reinterpret_cast<uint64_t const&>(A[op_row * kMultElementsPerInst]);
|
||||
|
||||
if (negate) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMultElementsPerInst; i += 2) {
|
||||
reinterpret_cast<half2&>(operand_A[i]) =
|
||||
__hneg2(reinterpret_cast<half2 const&>(A[op_row * kMultElementsPerInst + i]));
|
||||
}
|
||||
}
|
||||
|
||||
// Issue a Volta mma.sync instruction
|
||||
arch::mma<InstructionShape,
|
||||
kLayoutA,
|
||||
ScalarA,
|
||||
kLayoutB,
|
||||
ScalarB,
|
||||
ScalarC,
|
||||
kComputeType>(
|
||||
|
||||
operand_A, //&A[op_row * kMultElementsPerInst],
|
||||
&B[op_col * kMultElementsPerInst],
|
||||
&C[op_idx * kAccumElementsPerInst],
|
||||
&D[op_idx * kAccumElementsPerInst]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <=750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
}
|
||||
};
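The index arithmetic inside multiply_add is compact enough to obscure the traversal order. The following host-side sketch (plain C++; the 2x2x2x2 Iterations extents are assumed purely for illustration) prints the (op_row, op_col) sequence and makes the serpentine reversal on odd output columns visible:

```cpp
#include <cstdio>

int main() {
  // Assumed Iterations extents for illustration only.
  const int kD = 2, kW = 2, kH = 2, kC = 2;

  for (int d = 0; d < kD; ++d) {                    // outer column
    for (int w = 0; w < kW; ++w) {                  // inner column
      for (int h_raw = 0; h_raw < kH; ++h_raw) {    // outer row
        for (int c_raw = 0; c_raw < kC; ++c_raw) {  // inner row
          int op_col = w + kW * d;

          // Odd output columns traverse the rows in reverse, producing the
          // serpentine visitation order described in multiply_add above.
          int h = h_raw;
          int c = c_raw;
          if (op_col & 1) {
            h = kH - h_raw - 1;
            c = kC - c_raw - 1;
          }

          int op_row = c + kC * h;
          int op_idx = c + kC * (w + kW * (h + kH * d));
          printf("op_col=%d op_row=%d op_idx=%2d\n", op_col, op_row, op_idx);
        }
      }
    }
  }
  return 0;
}
```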
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Accumulator, typename WarpDelta, typename Iterations>
|
||||
struct Volta884NaiveEpilogue;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Naive epilogue specialized for f32 accumulators - may be considered authoritative mapping of
|
||||
/// accumulators to mma.sync operations.
|
||||
template <typename WarpDelta_, typename Iterations_>
|
||||
struct Volta884NaiveEpilogue<float, WarpDelta_, Iterations_> {
|
||||
/// Accumulator data type
|
||||
typedef float ScalarC;
|
||||
|
||||
/// Output accumulator type
|
||||
typedef float ScalarD;
|
||||
|
||||
/// BLAS Scalar type
|
||||
typedef float Scalar;
|
||||
|
||||
/// Delta among warp tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Number of Volta mma.sync operations
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
/// Most of the Volta884 code assumes interleaved 32x32 tiles
|
||||
typedef Shape<4, 32, 32> InterleavedTileShape;
|
||||
|
||||
/// Number of accumulator elements per instruction
|
||||
static int const kAccumElementsPerInst = 8;
|
||||
|
||||
/// Fragment definition for accumulators
|
||||
typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
|
||||
|
||||
/// Params object
|
||||
struct Params {
|
||||
/// Pointer to output matrix
|
||||
ScalarC* ptr;
|
||||
|
||||
/// stride
|
||||
int ldm;
|
||||
|
||||
/// Scalar alpha
|
||||
float alpha;
|
||||
|
||||
/// Scalar beta
|
||||
float beta;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : ptr(0), ldm(0), alpha(1), beta(0) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0)
|
||||
: ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {}
|
||||
|
||||
/// Initialize method
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) {
|
||||
ptr = _ptr;
|
||||
ldm = _ldm;
|
||||
alpha = _alpha;
|
||||
beta = _beta;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
ptr = reinterpret_cast<ScalarC*>(desc.D.data());
|
||||
ldm = desc.D.leading_dim();
|
||||
alpha = desc.alpha;
|
||||
beta = desc.beta;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Helper used to compute initial offset for each thread
|
||||
struct InitialOffset {
|
||||
int row_offset;
|
||||
int col_offset;
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
InitialOffset() {
|
||||
int warp_id = (threadIdx.x >> 5);
|
||||
int lane_id = (threadIdx.x & 0x1f);
|
||||
int quad_id = (lane_id >> 2);
|
||||
int quadpair_id = (quad_id & 0x3);
|
||||
|
||||
int quadpair_row = (quadpair_id & 1);
|
||||
int quadpair_col = (quadpair_id >> 1);
|
||||
int quad_hilo = (quad_id >> 2) & 1;
|
||||
|
||||
// compute initial offset
|
||||
int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW;
|
||||
int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH;
|
||||
|
||||
int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 1);
|
||||
int thread_col_offset = (quadpair_col * 2) * 8 + (lane_id & 2);
|
||||
|
||||
row_offset = warp_row_offset + thread_row_offset;
|
||||
col_offset = warp_col_offset + thread_col_offset;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
/// Problem size
|
||||
Coord<3> problem_size;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Computes initial offset for each thread
|
||||
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
|
||||
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
|
||||
: params(_params), problem_size(_problem_size) {}
|
||||
|
||||
/// Computes initial offset for each thread
|
||||
CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr,
|
||||
int _ldm,
|
||||
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
|
||||
: params(_ptr, _ldm), problem_size(_problem_size) {}
|
||||
|
||||
/// Computes initial offset for each thread
|
||||
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
|
||||
SharedStorage& shared_storage,
|
||||
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
|
||||
: params(_params), problem_size(_problem_size) {}
|
||||
|
||||
/// Sets accumulators to zero
|
||||
CUTLASS_DEVICE void clear(Accumulators& C) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
|
||||
C[op_idx * kAccumElementsPerInst + reg] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Naive load operation for debugging
|
||||
CUTLASS_DEVICE void load(Accumulators& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
InitialOffset initial;
|
||||
|
||||
initial.row_offset += threadblock_offset[2];
|
||||
initial.col_offset += threadblock_offset[1];
|
||||
|
||||
ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
|
||||
|
||||
// loads accumulators
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
|
||||
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
|
||||
|
||||
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
|
||||
int tr = (reg & 2) + c * 4;
|
||||
int tc = (reg & 1) + (reg & 4) * 2 + w * 4;
|
||||
|
||||
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
|
||||
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
|
||||
|
||||
if (row < problem_size[2] && column < problem_size[1]) {
|
||||
C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Naive store operation for debugging
|
||||
CUTLASS_DEVICE void store(Accumulators const& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
InitialOffset initial;
|
||||
|
||||
initial.row_offset += threadblock_offset[2];
|
||||
initial.col_offset += threadblock_offset[1];
|
||||
|
||||
ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
|
||||
|
||||
// store out accumulators
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
|
||||
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
|
||||
|
||||
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
|
||||
int tr = (reg & 2) + c * 4;
|
||||
int tc = (reg & 1) + (reg & 4) * 2 + w * 4;
|
||||
|
||||
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
|
||||
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
|
||||
|
||||
if (row < problem_size[2] && column < problem_size[1]) {
|
||||
op_ptr[tr + tc * params.ldm] =
|
||||
params.alpha * C[op_idx * kAccumElementsPerInst + reg] +
|
||||
params.beta * op_ptr[tr + tc * params.ldm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// CUTLASS Epilogue interface
|
||||
CUTLASS_DEVICE void epilogue(Accumulators const& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
store(C, threadblock_offset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void epilogue(Accumulators& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
store(C, threadblock_offset);
|
||||
}
|
||||
};
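A quick way to sanity-check the accumulator mapping encoded by InitialOffset is to replay its lane arithmetic on the host. The sketch below assumes a single warp (warp_id == 0) and is illustrative only; it mirrors the f32 variant line for line:

```cpp
#include <cstdio>

// Replays Volta884NaiveEpilogue<float>::InitialOffset for warp 0.
// Host-side sketch for inspection; not part of the original sources.
int main() {
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int quad_id      = lane_id >> 2;
    int quadpair_id  = quad_id & 0x3;
    int quadpair_row = quadpair_id & 1;
    int quadpair_col = quadpair_id >> 1;
    int quad_hilo    = (quad_id >> 2) & 1;

    int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 1);
    int thread_col_offset = (quadpair_col * 2) * 8 + (lane_id & 2);

    printf("lane %2d -> (row %2d, col %2d)\n",
           lane_id, thread_row_offset, thread_col_offset);
  }
  return 0;
}
```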
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Naive epilogue specialized for f16 accumulators - may be considered authoritative mapping of
|
||||
/// accumulators to mma.sync operations.
|
||||
template <typename WarpDelta_, typename Iterations_>
|
||||
struct Volta884NaiveEpilogue<half, WarpDelta_, Iterations_> {
|
||||
/// Accumulator data type
|
||||
typedef half ScalarC;
|
||||
|
||||
/// Output accumulator type
|
||||
typedef half ScalarD;
|
||||
|
||||
/// BLAS Scalar type
|
||||
typedef half Scalar;
|
||||
|
||||
/// Delta among warp tiles
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Number of Volta mma.sync operations
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
/// Most of the Volta884 code assumes interleaved 32x32 tiles
|
||||
typedef Shape<4, 32, 32> InterleavedTileShape;
|
||||
|
||||
/// Number of accumulator elements per instruction
|
||||
static int const kAccumElementsPerInst = 8;
|
||||
|
||||
/// Fragment definition for accumulators
|
||||
typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
|
||||
|
||||
/// Params object
|
||||
struct Params {
|
||||
/// Pointer to output matrix
|
||||
ScalarC* ptr;
|
||||
|
||||
/// stride
|
||||
int ldm;
|
||||
|
||||
/// Scalar alpha
|
||||
half alpha;
|
||||
|
||||
/// Scalar beta
|
||||
half beta;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : ptr(0), ldm(0), alpha(1), beta(0) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0)
|
||||
: ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {}
|
||||
|
||||
/// Initialize method
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) {
|
||||
ptr = _ptr;
|
||||
ldm = _ldm;
|
||||
alpha = _alpha;
|
||||
beta = _beta;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
ptr = reinterpret_cast<ScalarC*>(desc.D.data());
|
||||
ldm = desc.D.leading_dim();
|
||||
alpha = desc.alpha;
|
||||
beta = desc.beta;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Helper used to compute initial offset for each thread
|
||||
struct InitialOffset {
|
||||
int row_offset;
|
||||
int col_offset;
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
InitialOffset() {
|
||||
int warp_id = (threadIdx.x >> 5);
|
||||
int lane_id = (threadIdx.x & 0x1f);
|
||||
int quad_id = (lane_id >> 2);
|
||||
int quadpair_id = (quad_id & 0x3);
|
||||
|
||||
int quadpair_row = (quadpair_id & 1);
|
||||
int quadpair_col = (quadpair_id >> 1);
|
||||
int quad_hilo = (quad_id >> 2) & 1;
|
||||
|
||||
// compute initial offset
|
||||
int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW;
|
||||
int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH;
|
||||
|
||||
int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
|
||||
int thread_col_offset = (quadpair_col * 2) * 8;
|
||||
|
||||
row_offset = warp_row_offset + thread_row_offset;
|
||||
col_offset = warp_col_offset + thread_col_offset;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
/// Problem size
|
||||
Coord<3> problem_size;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Computes initial offset for each thread
|
||||
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params)
|
||||
: params(_params), problem_size(make_Coord(1024, 1024, 1024)) {}
|
||||
|
||||
/// Computes initial offset for each thread
|
||||
CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr, int _ldm)
|
||||
: params(_ptr, _ldm), problem_size(make_Coord(1024, 1024, 1024)) {}
|
||||
|
||||
/// Computes initial offset for each thread
|
||||
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
|
||||
SharedStorage& shared_storage,
|
||||
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
|
||||
: params(_params), problem_size(_problem_size) {}
|
||||
|
||||
/// Sets accumulators to zero
|
||||
CUTLASS_DEVICE void clear(Accumulators& C) { C.clear(); }
|
||||
|
||||
/// Naive load operation for debugging
|
||||
CUTLASS_DEVICE void load(Accumulators& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
InitialOffset initial;
|
||||
|
||||
initial.row_offset += threadblock_offset[2];
|
||||
initial.col_offset += threadblock_offset[1];
|
||||
|
||||
ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
|
||||
|
||||
// loads accumulators
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
|
||||
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
|
||||
|
||||
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
|
||||
int tr = c * 4;
|
||||
int tc = (reg & 3) + (reg & 4) * 2 + w * 4;
|
||||
|
||||
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
|
||||
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
|
||||
|
||||
if (row < problem_size[2] && column < problem_size[1]) {
|
||||
C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Naive store operation for debugging
|
||||
CUTLASS_DEVICE void store(Accumulators const& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
InitialOffset initial;
|
||||
|
||||
initial.row_offset += threadblock_offset[2];
|
||||
initial.col_offset += threadblock_offset[1];
|
||||
|
||||
ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
|
||||
|
||||
// store out accumulators
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
|
||||
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
|
||||
|
||||
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
|
||||
int tr = c * 4;
|
||||
int tc = (reg & 3) + (reg & 4) * 2 + w * 4;
|
||||
|
||||
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
|
||||
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
|
||||
|
||||
if (row < problem_size[2] && column < problem_size[1]) {
|
||||
op_ptr[tr + tc * params.ldm] =
|
||||
params.alpha * C[op_idx * kAccumElementsPerInst + reg] +
|
||||
params.beta * op_ptr[tr + tc * params.ldm];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// CUTLASS Epilogue interface
|
||||
CUTLASS_DEVICE void epilogue(Accumulators const& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
store(C, threadblock_offset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void epilogue(Accumulators& C,
|
||||
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
|
||||
store(C, threadblock_offset);
|
||||
}
|
||||
};
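For completeness, a minimal usage sketch of the naive epilogues (CUDA C++; `Epilogue` stands for a concrete Volta884NaiveEpilogue specialization taken from the surrounding GEMM traits, and the leading dimension, scalars, and offsets are assumed values):

```cpp
// Sketch only -- assumes Epilogue is a Volta884NaiveEpilogue specialization.
template <typename Epilogue>
__device__ void naive_epilogue_sketch(typename Epilogue::ScalarC *D,
                                      int ldm,
                                      typename Epilogue::Accumulators const &accum,
                                      cutlass::Coord<3> const &threadblock_offset) {
  // alpha = 1, beta = 0: plain store of the accumulator tile.
  typename Epilogue::Params params(D, ldm, 1.0f, 0.0f);
  typename Epilogue::SharedStorage shared_storage;

  Epilogue epilogue(params, shared_storage, cutlass::make_Coord(1024, 1024, 1024));

  // Scales by alpha/beta and writes accumulators at this block's offset.
  epilogue.epilogue(accum, threadblock_offset);
}
```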
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,142 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Warp-scoped shared memory load iterators
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///! Iterator to store a thread-block scoped fragment to shared memory
|
||||
template <
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
GemmOperand::Kind Operand,
|
||||
/// Specifies layout of data in source memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp accesses along the outer dimension
|
||||
int WarpDelta>
|
||||
struct Volta884ThreadblockMultiplicandStoreIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Iterator to load a fragment for each warp-level tile
|
||||
template <
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
GemmOperand::Kind Operand,
|
||||
/// Specifies layout of data in source memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile,
|
||||
/// Specifies the warp tile shape
|
||||
typename WarpTile,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp accesses along the outer dimension
|
||||
typename WarpDelta>
|
||||
struct Volta884WarpMultiplicandLoadIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Fully-specialized implementations extracted in separate headers.
|
||||
//
|
||||
|
||||
#include "cutlass/gemm/volta884_shared_tile_contiguous.h"
|
||||
#include "cutlass/gemm/volta884_shared_tile_crosswise.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Epilogue shared memory iterators
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores an accumulator fragment to shared memory
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_,
|
||||
/// Data type of mma.sync accumulator - this is used to infer layout.
|
||||
typename Accumulator_>
|
||||
struct Volta884EpilogueSharedStoreIterator;
|
||||
|
||||
/// Loads an accumulator fragment from shared memory
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_,
|
||||
/// Number of scalar elements loaded
|
||||
int AccessSize_,
|
||||
/// Data type of mma.sync accumulator - this is used to infer layout.
|
||||
typename Accumulator_>
|
||||
struct Volta884EpilogueSharedLoadIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
//
|
||||
// Partially-specialized implementations extracted in separate header.
|
||||
//
|
||||
|
||||
#include "cutlass/gemm/volta884_shared_tile_epilogue.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,974 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Congruous loading
|
||||
//
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Store iterator specialized for A.column_major
|
||||
template <
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile_,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp accesses along the outer dimension
|
||||
int WarpDelta>
|
||||
struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Tile_,
|
||||
WarpCount,
|
||||
WarpDelta> {
|
||||
//
|
||||
// Constant and type definitions
|
||||
//
|
||||
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// Shape of thread-block multiplicand
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp accumulator tiles along the outer dimension
|
||||
static int const kWarpDelta = WarpDelta;
|
||||
|
||||
/// This implementation is specialized for 128b stores
|
||||
static int const kAccessSize = 8;
|
||||
|
||||
/// Swizzled store iterator
|
||||
struct ThreadOffset {
|
||||
__device__ Coord<4> operator()() const {
|
||||
int warp_id = (threadIdx.x >> 5);
|
||||
int lane_id = (threadIdx.x & 0x1f);
|
||||
|
||||
int k_idx = warp_id;
|
||||
|
||||
// This is an 8-element vector within one 32x32 tile
|
||||
int vec_idx = lane_id & 3;
|
||||
int vec_col = (vec_idx / 2);
|
||||
|
||||
int t4t3 = (lane_id >> 3);
|
||||
int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);
|
||||
|
||||
int t_col = (vec_col << 2) | (col_rotate ^ t4t3);
|
||||
|
||||
Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0);
|
||||
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
/// Projects the threadblock tile
|
||||
typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
|
||||
|
||||
/// Stored tile has a structure designed for efficient MIO storing and loading
|
||||
typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension
|
||||
(OperandShape::kW >> 4), // four rows of SMEM per 64xK tile
|
||||
kAccessSize, // Eight banks of MIO
|
||||
kAccessSize>
|
||||
VectorizedShape; // 128b stores
|
||||
|
||||
/// Offset between stores
|
||||
typedef Shape<WarpCount, 1, 1, 1> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations;
|
||||
|
||||
/// Source tile traits
|
||||
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;
|
||||
|
||||
/// Scalar type
|
||||
typedef half Scalar;
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Index type
|
||||
typedef int LongIndex;
|
||||
|
||||
//
|
||||
// Derived types
|
||||
//
|
||||
|
||||
/// Tensor reference
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// Predicate vector
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Fragment definition
|
||||
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// Elements loaded by one instruction
|
||||
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
|
||||
/// Strides into expected SMEM tile
|
||||
typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;
|
||||
|
||||
/// Memory space access
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
|
||||
|
||||
/// Parameters object
|
||||
struct Params {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to element type
|
||||
Scalar *pointer;
|
||||
|
||||
/// Strides
|
||||
Coord<4> stride;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a parameters object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar *_pointer = 0)
|
||||
: pointer(_pointer),
|
||||
stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
|
||||
|
||||
/// Constructs a params object from a TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { }
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a store iterator
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator(
|
||||
Params const &_params,
|
||||
Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
|
||||
ThreadOffset offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
// Compute initial thread offset
|
||||
Coord<4> offset = offset_func();
|
||||
|
||||
params.pointer += (_block_offset + offset).template dot<int>(params.stride);
|
||||
}
|
||||
|
||||
/// Stores a fragment
|
||||
CUTLASS_DEVICE void store(Fragment const &fragment,
|
||||
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
|
||||
FragmentConstIterator frag_iterator(fragment);
|
||||
|
||||
// Iterate over each store
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
int idx = w + Iterations::kW * h;
|
||||
|
||||
int row = idx * 4;
|
||||
|
||||
Coord<4> sts_offset =
|
||||
make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
|
||||
|
||||
Store<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::store(
|
||||
reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)),
|
||||
params.pointer,
|
||||
params.stride.template dot<int>(sts_offset + offset));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Increments store iterator to next tile
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
|
||||
params.pointer +=
|
||||
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Increments to next tile
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); }
|
||||
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) {
|
||||
return increment(count);
|
||||
}
|
||||
|
||||
/// Increments store iterator to previous tile
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
|
||||
params.pointer -=
|
||||
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Increments to subsequent tile
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); }
|
||||
|
||||
/// Decrements to previous tile
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) {
|
||||
return decrement(count);
|
||||
}
|
||||
|
||||
/// Stores a fragment and increments in the K dimension
|
||||
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment(
|
||||
Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
|
||||
store(fragment, offset);
|
||||
return increment();
|
||||
}
|
||||
};
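The ThreadOffset functor above is the heart of this store iterator: it swizzles each lane's 128-bit store, with the intent of keeping a quad of threads on distinct shared memory banks. Replaying the arithmetic on the host (warp 0 only; a sketch, not part of the original sources) shows the XOR-based column rotation:

```cpp
#include <cstdio>

// Replays the A.column_major store iterator's ThreadOffset for warp 0.
int main() {
  const int warp_id = 0;  // k_idx follows the warp index
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int vec_idx    = lane_id & 3;
    int vec_col    = vec_idx / 2;
    int t4t3       = lane_id >> 3;
    int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);

    // XOR with lane bits t4..t3 rotates the column within a 4-column group,
    // spreading the vectorized stores across the MIO banks.
    int t_col = (vec_col << 2) | (col_rotate ^ t4t3);

    printf("lane %2d -> (k %d, row %d, col %2d)\n",
           lane_id, warp_id, col_rotate, t_col);
  }
  return 0;
}
```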
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Iterator to load a fragment for each warp-level tile specialized for A.column_major
|
||||
template <
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile_,
|
||||
/// Specifies the warp tile shape
|
||||
typename WarpTile_,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp accesses along the outer dimension
|
||||
typename WarpDelta_>
|
||||
struct Volta884WarpMultiplicandLoadIterator<GemmOperand::kA,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Tile_,
|
||||
WarpTile_,
|
||||
WarpCount,
|
||||
WarpDelta_> {
|
||||
//
|
||||
// Constant and type definitions
|
||||
//
|
||||
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
|
||||
/// Specifies layout of data in source memory
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// Shape of thread-block multiplicand
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Shape of warp-tile matrix operation
|
||||
typedef WarpTile_ WarpTile;
|
||||
|
||||
/// Hard-coded tile shape
|
||||
typedef Shape<4, 32, 32> InterleavedTileShape;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = WarpCount;
|
||||
|
||||
/// Delta between warp accumulator tiles along the outer dimension
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Two SMEM read pointers are needed when WarpDelta::kW == 1; otherwise one suffices
|
||||
static int const kPointerCount = (WarpDelta::kW == 1 ? 2 : 1);
|
||||
|
||||
/// This implementation is specialized for 128b loads
|
||||
static int const kAccessSize = 8;
|
||||
|
||||
/// Swizzled store iterator
|
||||
struct ThreadOffset {
|
||||
/// Compute thread offset coordinate for each pointer
|
||||
CUTLASS_DEVICE Coord<4> operator()(int pointer_idx = 0) const {
|
||||
// Determine the warp's reading location within the SMEM tile
|
||||
int warp_id = ((threadIdx.x >> 5) % WarpDelta::kW);
|
||||
|
||||
// This is an 8-element vector within one 32x32 tile
|
||||
int lane_id = (threadIdx.x & 0x1f);
|
||||
int vec_row = (lane_id >> 4);
|
||||
int vec_col = ((lane_id & 4) >> 2);
|
||||
|
||||
int tile_row = pointer_idx * 2 + vec_row;
|
||||
|
||||
// Column rotation function
|
||||
int t_col = (vec_col * 4);
|
||||
if (pointer_idx == 1 || (WarpDelta::kW > 1 && (warp_id & 1))) {
|
||||
vec_row |= 2;
|
||||
}
|
||||
|
||||
t_col = t_col | ((lane_id & 3) ^ vec_row);
|
||||
|
||||
Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0);
|
||||
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
/// Projects the threadblock tile
|
||||
typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
|
||||
|
||||
/// Stored tile has a structure designed for efficient MIO storing and loading
|
||||
typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension
|
||||
(OperandShape::kW >> 4), // four rows of SMEM per 64xK tile
|
||||
kAccessSize, // Eight banks of MIO
|
||||
kAccessSize>
|
||||
VectorizedShape; // 128b stores
|
||||
|
||||
/// Offset between accesses
|
||||
typedef typename platform::conditional<WarpDelta::kW == 1,
|
||||
Shape<1, 0, 0, 0>,
|
||||
Shape<1, 2 * WarpDelta::kW, 0, 0> >::type Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, WarpTile::kW / InterleavedTileShape::kW, 1, 1> Iterations;
|
||||
|
||||
/// Source tile traits
|
||||
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;
|
||||
|
||||
/// Scalar type
|
||||
typedef half Scalar;
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Index type
|
||||
typedef int LongIndex;
|
||||
|
||||
//
|
||||
// Derived types
|
||||
//
|
||||
|
||||
/// Tensor reference
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// Predicate vector
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Fragment definition
|
||||
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// Elements loaded by one instruction
|
||||
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
|
||||
/// Strides into expected SMEM tile
|
||||
typedef typename ShapeStrides<VectorizedShape, kAccessSize>::Shape Strides;
|
||||
|
||||
/// Memory space access
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
|
||||
|
||||
/// Parameters object
|
||||
struct Params {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Base pointer to SMEM allocation
|
||||
Scalar const *pointer;
|
||||
|
||||
/// SMEM strides
|
||||
Coord<4> stride;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a parameters object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar const *_pointer = 0)
|
||||
: pointer(_pointer),
|
||||
stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
|
||||
|
||||
/// Constructs a params object from a TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { }
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
// A.column requires two SMEM pointers.
|
||||
// Because Params only supplies a base pointer and strides, there is no usual params
|
||||
// data member. Instead, it is used to initialize the following.
|
||||
|
||||
  /// Pointer to SMEM allocation.
  Scalar const *pointer[kPointerCount];

  /// SMEM strides
  Coord<4> stride;

  //
  // Methods
  //

  /// Constructs a load iterator
  CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : stride(_params.stride) {
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < kPointerCount; ++idx) {
      Coord<4> offset = offset_func(idx);

      pointer[idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
    }
  }

  /// Loads a fragment
  CUTLASS_DEVICE void load(Fragment &fragment,
                           Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentIterator frag_iterator(fragment);

    // Iterate over each load
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          // Pointers mapped to Iterations::kH dimension
          Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)];

          Coord<4> lds_offset =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Load<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::load(
              reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
              _pointer,
              stride.template dot<int>(lds_offset + offset));
        }
      }
    }
  }

  /// Loads a fragment and increments to next K-index
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
                                          Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    load(fragment, offset);

    for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
      pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot<int>(stride);
    }
  }
};

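// Usage sketch (illustrative only; the helper below is not part of the original source):
// a warp constructs the load iterator from a Params object describing the SMEM tile and
// pulls one fragment per K-group, advancing the swizzled pointers with each call.
template <typename WarpLoadIterator>
CUTLASS_DEVICE void example_warp_load_sketch(typename WarpLoadIterator::Params const &params,
                                             int k_groups) {
  WarpLoadIterator load_iterator(params);
  typename WarpLoadIterator::Fragment fragment;

  for (int k = 0; k < k_groups; ++k) {
    // Issues the swizzled shared-memory loads for this K-group, then steps forward one K.
    load_iterator.load_post_increment(fragment);
    // ... hand 'fragment' to the warp-scoped multiply-accumulate here ...
  }
}
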
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Store iterator specialized for B.row_major
template <
    /// Specifies threadblock tile shape
    typename Tile_,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp accesses along the outer dimension
    int WarpDelta>
struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kB,
                                                    MatrixLayout::kRowMajor,
                                                    Tile_,
                                                    WarpCount,
                                                    WarpDelta> {
  //
  // Constant and type definitions
  //

  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kB;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  static int const kWarpDelta = WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Index type
  typedef int Index;

  /// Index type
  typedef int LongIndex;

  /// Functor computing the swizzled SMEM store offset for each thread
  struct ThreadOffset {
    CUTLASS_DEVICE Coord<4> operator()() const {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);

      int k_idx = warp_id;

      // This is an 8-element vector within one 32x32 tile
      int vec_idx = lane_id & 3;
      int vec_col = (vec_idx / 2);

      int t4t3 = (lane_id >> 3);
      int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);

      int t_col = (vec_col << 2) | (col_rotate ^ t4t3);

      Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0);

      return offset;
    }
  };

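  // Worked example of the swizzle above (illustrative, not part of the original source):
  // for warp 0, lane 0 computes vec_col = 0, col_rotate = 0, t4t3 = 0, so the offset is
  // (0, 0, 0, 0); lane 5 computes col_rotate = 3 and t_col = 3, giving (0, 3, 3, 0). The
  // XOR with t4t3 rotates the column index so that the vectors written by different lane
  // groups fall into different positions, which helps avoid SMEM bank conflicts.
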
  /// Projects the threadblock tile
  typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Stored tile has a structure designed for efficient MIO storing and loading
  typedef Shape<(OperandShape::kH >> 2),  // one 3D tile per four elements in the K dimension
                (OperandShape::kW >> 4),  // four rows of SMEM per 64xK tile
                kAccessSize,              // Eight banks of MIO
                kAccessSize>
      VectorizedShape;  // 128b stores

  /// Offset between stores
  typedef Shape<WarpCount, 1, 1, 1> Delta;

  /// Number of iterations
  typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations;

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;

  /// Scalar type
  typedef half Scalar;

  //
  // Derived types
  //

  /// Tensor reference
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Strides into expected SMEM tile
  typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Parameters object
  struct Params {
    //
    // Data members
    //

    /// Pointer to element type
    Scalar *pointer;

    /// Strides
    Coord<4> stride;

    //
    // Methods
    //

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params(Scalar *_pointer = 0)
        : pointer(_pointer),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}

    /// Constructs a params object from a TensorRef
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref)
        : pointer(ref.data()),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Constructs a store iterator
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : params(_params) {
    // Compute initial offset for each thread
    Coord<4> offset = offset_func();

    params.pointer += (_block_offset + offset).template dot<int>(params.stride);
  }

  /// Stores a fragment
  CUTLASS_DEVICE void store(Fragment const &fragment,
                            Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentConstIterator frag_iterator(fragment);

    // Iterate over each store
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          int idx = w + Iterations::kW * h;
          int row = idx * 4;

          Coord<4> sts_offset =
              make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Index _offset = params.stride.template dot<int>(sts_offset + offset);

          Store<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::store(
              reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)),
              params.pointer,
              _offset);
        }
      }
    }
  }

  /// Increments store iterator to next tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
    params.pointer +=
        make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
    return *this;
  }

  /// Increments to next tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); }

  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) {
    return increment(count);
  }

  /// Decrements store iterator to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
    params.pointer -=
        make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
    return *this;
  }

  /// Decrements to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); }

  /// Decrements to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) {
    return decrement(count);
  }

  /// Stores a fragment and increments in the K dimension
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment(
      Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    store(fragment, offset);
    return increment();
  }
};

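// Staging sketch (illustrative only; the helper below is not part of the original source):
// each K-block fetched from global memory is written into the swizzled SMEM tile with
// store_post_increment, then made visible to the warp-level load iterators.
template <typename ThreadblockStoreIterator>
CUTLASS_DEVICE void example_smem_stage_sketch(
    typename ThreadblockStoreIterator::Params const &params,
    typename ThreadblockStoreIterator::Fragment const &fetched_fragment) {
  ThreadblockStoreIterator store_iterator(params);

  // One swizzled STS pass for this K-block; the iterator then points at the next tile.
  store_iterator.store_post_increment(fetched_fragment);

  __syncthreads();  // the staged tile may now be read by the warp-level load iterators
}
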
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterator to load a fragment for each warp-level tile specialized for B.row_major
template <
    /// Specifies threadblock tile shape
    typename Tile_,
    /// Specifies the warp tile shape
    typename WarpTile_,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp accesses along the outer dimension
    typename WarpDelta_>
struct Volta884WarpMultiplicandLoadIterator<GemmOperand::kB,
                                            MatrixLayout::kRowMajor,
                                            Tile_,
                                            WarpTile_,
                                            WarpCount,
                                            WarpDelta_> {
  //
  // Constant and type definitions
  //

  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kB;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Shape of warp-tile matrix operation
  typedef WarpTile_ WarpTile;

  /// Hard-coded tile shape
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  typedef WarpDelta_ WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Functor computing the swizzled SMEM read offset for each pointer index
  struct ThreadOffset {
    /// Computes the initial offset
    CUTLASS_DEVICE Coord<4> operator()(int pointer_idx) const {
      // Determine the warp's reading location within the SMEM tile
      int warp_id = ((threadIdx.x >> 5) / WarpDelta::kW);

      // This is an 8-element vector within one 32x32 tile
      int lane_id = (threadIdx.x & 0x1f);
      int vec_row = (lane_id >> 4);
      int vec_col = ((lane_id & 8) >> 3);

      int tile_row = pointer_idx * 2 + vec_row;

      // Column rotation function
      int t_col = (vec_col * 4);
      if (pointer_idx == 1 || (WarpDelta::kH > 1 && (warp_id & 1))) {
        vec_row |= 2;
      }

      t_col = t_col | ((lane_id & 3) ^ vec_row);
      Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0);

      return offset;
    }
  };

  /// Projects the threadblock tile
  typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Stored tile has a structure designed for efficient MIO storing and loading
  typedef Shape<(OperandShape::kH >> 2),  // one 3D tile per four elements in the K dimension
                (OperandShape::kW >> 4),  // four rows of SMEM per 64xK tile
                kAccessSize,              // Eight banks of MIO
                kAccessSize>
      VectorizedShape;  // 128b stores

  /// Delta between accesses
  typedef typename platform::conditional<WarpDelta::kH == 1,
                                         Shape<1, 0, 0, 0>,
                                         Shape<1, 2 * WarpDelta::kH, 0, 0> >::type Delta;

  /// Number of iterations
  typedef Shape<1, WarpTile::kH / InterleavedTileShape::kH, 1, 1> Iterations;

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;

  /// Scalar type
  typedef half Scalar;

  /// Index type
  typedef int Index;

  /// Index type
  typedef int LongIndex;

  //
  // Derived types
  //

  /// Tensor reference
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Strides into expected SMEM tile
  typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Number of SMEM read pointers needed
  static int const kPointerCount = (WarpDelta::kH == 1 ? 2 : 1);

  /// Parameters object
  struct Params {
    //
    // Data members
    //

    /// Pointer to element type
    Scalar const *pointer;

    /// Strides
    Coord<4> stride;

    //
    // Methods
    //

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params(Scalar const *_pointer = 0)
        : pointer(_pointer),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}

    /// Constructs a params object from a TensorRef
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref)
        : pointer(ref.data()),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
  };

  //
  // Data members
  //

  /// Pointer to element type
  Scalar const *pointer[kPointerCount];

  /// Strides
  Coord<4> stride;

  //
  // Methods
  //

  /// Constructs a load iterator
  CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : stride(_params.stride) {
    CUTLASS_PRAGMA_UNROLL
    for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
      Coord<4> offset = offset_func(ptr_idx);

      pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
    }
  }

  /// Loads a fragment
  CUTLASS_DEVICE void load(Fragment &fragment,
                           Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentIterator frag_iterator(fragment);

    // Iterate over each load
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          // Pointers mapped to Iterations::kH dimension
          Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)];

          Coord<4> lds_offset =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Load<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::load(
              reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
              _pointer,
              stride.template dot<int>(lds_offset + offset));
        }
      }
    }
  }

  /// Loads a fragment and increments to next K-index
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
                                          Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    load(fragment, offset);

    CUTLASS_PRAGMA_UNROLL
    for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
      pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot<int>(stride);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

@ -1,629 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
|
||||
|
||||
DO NOT INCLUDE THIS FILE DIRECTLY.
|
||||
|
||||
This file is intended to be included by <cutlass/gemm/volta884_shared_tile.h> and defines
|
||||
partial specializations for templates specified therein.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Partial specializations for FP32 accumulator layouts
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue shared memory store iterator specialized for Volta's mma.sync.FP32 layout
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_>
|
||||
struct Volta884EpilogueSharedStoreIterator<WarpGemmTile_, WarpDelta_, Scalar_, float> {
|
||||
/// Warp-scoped GEMM tile size
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Tiling of warp elements across threadblock
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Scalar data type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Accumulator data type (and layout)
|
||||
typedef float Accumulator;
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Index type
|
||||
typedef int LongIndex;
|
||||
|
||||
// Host-side params
|
||||
struct Params {};
|
||||
|
||||
/// Access size
|
||||
static int const kAccessSize = 1;
|
||||
|
||||
/// Skew elements to ensure conflict free stores
|
||||
static int const kSkew = 2;
|
||||
|
||||
/// Shape of one interleaved mma.sync tile
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// Four element fragment
|
||||
typedef Shape<WarpGemmTile::kW / MmaTileShape::kW, 1, 4, 1> Iterations;
|
||||
|
||||
/// Delta separated by two elements
|
||||
typedef Shape<MmaTileShape::kW * WarpDelta::kW, 1, 2, 1> Delta;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Predicate vector
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Memory space access
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
|
||||
|
||||
/// Fragment definition
|
||||
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// Elements loaded by one instruction
|
||||
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
|
||||
/// Tensor reference type
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Base pointer to SMEM allocation
|
||||
Scalar *pointer;
|
||||
|
||||
/// Stride in shared memory
|
||||
Coord<4> strides;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref)
|
||||
: pointer(ref.data()), strides(make_Coord(1, WarpDelta::kW * WarpGemmTile::kW + kSkew, 1, 1)) {
|
||||
|
||||
int warp_id = (threadIdx.x / kWarpSize);
|
||||
int lane_id = (threadIdx.x % kWarpSize);
|
||||
|
||||
Coord<4> warp_idx = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0);
|
||||
|
||||
Coord<4> warp_base = warp_idx * make_Coord(0, 4, MmaTileShape::kW, 0);
|
||||
|
||||
Coord<4> thread_idx = make_Coord(0,
|
||||
(((lane_id >> 1) & 4) | (lane_id & 2)) >> 1,
|
||||
(lane_id & 1) | ((lane_id >> 1) & 8) | ((lane_id << 2) & 16),
|
||||
0);
|
||||
|
||||
int offset = strides.template dot<int>(warp_base + thread_idx);
|
||||
|
||||
pointer += offset;
|
||||
}
|
||||
|
||||
/// Store to the epilogue tile.
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &fragment) const {
|
||||
FragmentConstIterator frag_iterator(fragment);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
Coord<4> coord =
|
||||
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
|
||||
|
||||
int _offset = coord.template dot<int>(strides);
|
||||
|
||||
Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
|
||||
reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)), pointer,
|
||||
_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores to the epilogue tile - this iterator does not advance, so increment is null.
|
||||
CUTLASS_DEVICE
|
||||
void store_post_increment(Fragment const &fragment) { store(fragment); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue shared memory load iterator specialized for Volta's mma.sync.FP32 layout
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_,
|
||||
/// Number of elements loaded per access
|
||||
int AccessSize_>
|
||||
struct Volta884EpilogueSharedLoadIterator<WarpGemmTile_, WarpDelta_, Scalar_, AccessSize_, float> {
|
||||
/// Warp-scoped GEMM tile size
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Tiling of warp elements across threadblock
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Scalar data type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Accumulator data type (and layout)
|
||||
typedef float Accumulator;
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Index type
|
||||
typedef int LongIndex;
|
||||
|
||||
/// Number of elements accessed at once
|
||||
static int const kAccessSize = AccessSize_;
|
||||
|
||||
/// Shape of one interleaved mma.sync tile
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// Total participating warps
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Total participating threads
|
||||
static int const kThreadCount = kWarpCount * kWarpSize;
|
||||
|
||||
/// Skew elements
|
||||
static int const kSkew = 2;
|
||||
|
||||
/// This tile is to be strip-mined with a swizzling function
|
||||
typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<2 * WarpDelta::kH,
|
||||
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
|
||||
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
|
||||
1>
|
||||
Iterations;
|
||||
|
||||
/// Delta between accesses
|
||||
typedef Shape<2, 1, kThreadCount, 1> Delta;
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Predicate vector
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Fragment of elements to load
|
||||
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// Elements loaded by one instruction
|
||||
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
|
||||
static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew");
|
||||
|
||||
/// Memory space access
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
|
||||
|
||||
/// Tensor reference type
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// Host-side params
|
||||
struct Params {};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer
|
||||
Scalar const *pointer;
|
||||
|
||||
/// Strides
|
||||
Coord<4> strides;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref)
|
||||
: pointer(ref.data()),
|
||||
strides(make_Coord((WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
|
||||
(WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
|
||||
kAccessSize,
|
||||
1)) {
|
||||
// strip-mine this tile
|
||||
int tid = threadIdx.x;
|
||||
|
||||
int residual_w = (tid / (Tile::kW));
|
||||
int offset_w = (tid % (Tile::kW));
|
||||
|
||||
int offset_h = (residual_w % Tile::kH);
|
||||
int offset_d = (residual_w / Tile::kH);
|
||||
|
||||
Coord<4> offset = make_Coord(offset_d * Delta::kW, offset_h * Delta::kH, offset_w, 0);
|
||||
|
||||
pointer += strides.template dot<int>(offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from the epilogue tile.
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &fragment) const {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
Coord<4> coord =
|
||||
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kW);
|
||||
|
||||
int _offset = coord.template dot<int>(strides);
|
||||
|
||||
Load<typename Fragment::Element, kAccessSize, kMemorySpace>::load(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), pointer, _offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment - iterator does not actually advance, so increment operation is null.
|
||||
CUTLASS_DEVICE
|
||||
void load_post_increment(Fragment &fragment) { load(fragment); }
|
||||
};
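
// Round-trip sketch (illustrative only; the helper below is not part of the original source):
// in the epilogue, accumulators are first stored to SMEM in mma.sync order and then read
// back in an order convenient for coalesced global stores.
template <typename SharedStoreIterator, typename SharedLoadIterator>
CUTLASS_DEVICE void example_epilogue_roundtrip_sketch(
    typename SharedStoreIterator::TensorRef store_ref,
    typename SharedLoadIterator::TensorRef load_ref,
    typename SharedStoreIterator::Fragment const &accumulators,
    typename SharedLoadIterator::Fragment &output_fragment) {
  typename SharedStoreIterator::Params store_params;
  typename SharedLoadIterator::Params load_params;

  SharedStoreIterator store_iterator(store_params, store_ref);
  store_iterator.store(accumulators);

  __syncthreads();  // the swizzled SMEM tile must be complete before it is read back

  SharedLoadIterator load_iterator(load_params, load_ref);
  load_iterator.load(output_fragment);
}
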
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Partial specializations for FP16 accumulator layouts
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue shared memory store iterator specialized for Volta's mma.sync.FP16 layout
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_>
|
||||
struct Volta884EpilogueSharedStoreIterator<WarpGemmTile_, WarpDelta_, Scalar_, half> {
|
||||
/// Warp-scoped GEMM tile size
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Tiling of warp elements across threadblock
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Scalar data type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Accumulator data type (and layout)
|
||||
typedef half Accumulator;
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Index type
|
||||
typedef int LongIndex;
|
||||
|
||||
/// Host-side params
|
||||
struct Params {};
|
||||
|
||||
/// Dimensions of contiguous 32x32x4 Volta's mma.sync tile
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// Accumulator fragment
|
||||
typedef Shape<WarpGemmTile::kW / MmaTileShape::kW, 1, 2, 1> Iterations;
|
||||
|
||||
/// Delta separated by two elements
|
||||
typedef Shape<MmaTileShape::kW * WarpDelta::kW, 1, 4, 1> Delta;
|
||||
|
||||
/// Access size
|
||||
static int const kAccessSize = 1;
|
||||
|
||||
/// Skew elements to ensure conflict free stores
|
||||
static int const kSkew = 2;
|
||||
|
||||
/// Tensor reference type
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Predicate vector
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Memory space access
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
|
||||
|
||||
/// Fragment definition
|
||||
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// Elements loaded by one instruction
|
||||
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Base pointer to SMEM allocation
|
||||
Scalar *pointer;
|
||||
|
||||
/// Stride in shared memory
|
||||
Coord<4> strides;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref)
|
||||
: pointer(ref.data()), strides(make_Coord(1, WarpGemmTile::kW * WarpDelta::kW + kSkew, 1, 1)) {
|
||||
|
||||
int warp_id = (threadIdx.x / kWarpSize);
|
||||
int lane_id = (threadIdx.x % kWarpSize);
|
||||
|
||||
int quad_id = (lane_id >> 2);
|
||||
int quadpair_id = (quad_id & 0x3);
|
||||
|
||||
int quadpair_row = (quadpair_id & 1);
|
||||
int quadpair_col = (quadpair_id >> 1);
|
||||
int quad_hilo = (quad_id >> 2) & 1;
|
||||
|
||||
int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
|
||||
int thread_col_offset = quadpair_col;
|
||||
|
||||
Coord<4> thread_idx = make_Coord(0, thread_col_offset, thread_row_offset, 0);
|
||||
|
||||
Coord<4> warp_base = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0) *
|
||||
make_Coord(0, 2, kWarpSize, 0);
|
||||
Coord<4> offset = warp_base + thread_idx;
|
||||
|
||||
pointer += strides.template dot<int>(offset);
|
||||
}
|
||||
|
||||
/// Store to the epilogue tile.
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &fragment) const {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
Coord<4> coord =
|
||||
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
|
||||
|
||||
int _offset = coord.template dot<int>(strides);
|
||||
|
||||
Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
|
||||
reinterpret_cast<AccessType const &>(fragment[w + Iterations::kW * d]),
|
||||
pointer,
|
||||
_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores to the epilogue tile - this iterator does not advance, so increment is null.
|
||||
CUTLASS_DEVICE
|
||||
void store_post_increment(Fragment const &fragment) { store(fragment); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue shared memory load iterator specialized for Volta's mma.sync.FP16 layout
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_,
|
||||
/// Number of elements loaded per access
|
||||
int AccessSize_>
|
||||
struct Volta884EpilogueSharedLoadIterator<WarpGemmTile_, WarpDelta_, Scalar_, AccessSize_, half> {
|
||||
/// Warp-scoped GEMM tile size
|
||||
typedef WarpGemmTile_ WarpGemmTile;
|
||||
|
||||
/// Tiling of warp elements across threadblock
|
||||
typedef WarpDelta_ WarpDelta;
|
||||
|
||||
/// Scalar data type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Accumulator data type (and layout)
|
||||
typedef half Accumulator;
|
||||
|
||||
/// Number of elements accessed at once
|
||||
static int const kAccessSize = AccessSize_;
|
||||
|
||||
/// Shape of one interleaved mma.sync tile
|
||||
typedef Shape<4, 32, 32> MmaTileShape;
|
||||
|
||||
/// This tile is to be strip-mined with a swizzling function
|
||||
typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW / kAccessSize, kAccessSize>
|
||||
Tile;
|
||||
|
||||
/// Index type
|
||||
typedef int Index;
|
||||
|
||||
/// Index type
|
||||
typedef int LongIndex;
|
||||
|
||||
/// Total participating warps
|
||||
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreadCount = kWarpSize * kWarpCount;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1,
|
||||
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
|
||||
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
|
||||
1>
|
||||
Iterations;
|
||||
|
||||
/// Delta between thread-level accesses
|
||||
typedef typename platform::conditional<kThreadCount >= Tile::kW,
|
||||
Shape<1, (kThreadCount / Tile::kW), 1, 1>,
|
||||
Shape<1, 1, kThreadCount, 1> >::type Delta;
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Predicate vector
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Fragment of elements to load
|
||||
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// Elements loaded by one instruction
|
||||
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
|
||||
/// Skew elements
|
||||
static int const kSkew = 2;
|
||||
|
||||
static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew");
|
||||
|
||||
/// Memory space access
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
|
||||
|
||||
/// Tensor reference type
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// Host-side params
|
||||
struct Params {};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer
|
||||
Scalar const *pointer;
|
||||
|
||||
/// Strides
|
||||
Coord<4> strides;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref)
|
||||
: pointer(ref.data()),
|
||||
strides(make_Coord(2 * (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
|
||||
(WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
|
||||
kAccessSize,
|
||||
1)) {
|
||||
// strip-mine this tile
|
||||
Coord<4> offset = make_Coord(0, threadIdx.x / Tile::kW, threadIdx.x % Tile::kW, 0);
|
||||
|
||||
pointer += strides.template dot<int>(offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from the epilogue tile.
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &fragment) const {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
Coord<4> coord =
|
||||
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kW);
|
||||
|
||||
int _offset = coord.template dot<int>(strides);
|
||||
|
||||
Load<typename Fragment::Element, kAccessSize, kMemorySpace>::load(
|
||||
reinterpret_cast<AccessType &>(fragment[w + Iterations::kW * h]), pointer, _offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment - iterator does not actually advance, so increment operation is null.
|
||||
CUTLASS_DEVICE
|
||||
void load_post_increment(Fragment &fragment) { load(fragment); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,167 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of WMMA GEMM's epilogue phase.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
#include "cutlass/gemm/wmma_gemm_global_tile.h"
|
||||
#include "cutlass/gemm/wmma_gemm_shared_tile.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Accumulator_, typename EpilogueFunctor_, typename Index_ = int>
|
||||
struct WmmaGemmEpilogueTraitsHelper {
|
||||
/// The scalar.
|
||||
typedef typename EpilogueFunctor_::Scalar Scalar;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig_::OutputTile OutputTile;
|
||||
|
||||
/// The number of WMMAs in the H dimension.
|
||||
static int const kWmmasPerH =
|
||||
GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
|
||||
/// The number of iterations in the epilogue. That's the number of "horizontal" WMMAs.
|
||||
typedef Shape<1, 1, kWmmasPerH> Iterations;
|
||||
// The iteration strides in the H/W dimension.
|
||||
typedef Shape<0, 0, 0> Delta;
|
||||
/// The functor to do the math in the epilogue.
|
||||
typedef EpilogueFunctor_ Functor;
|
||||
|
||||
/// The traits class to build the iterator to store to shared memory for D.
|
||||
typedef WmmaGemmSharedStoreTileDTraits<
|
||||
// The output layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
typedef WmmaMatrix<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The iterator to store D to shared memory.
|
||||
typedef TileStoreIterator<SharedStoreTileTraits,
|
||||
typename SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
WmmaMatrix,
|
||||
FragmentElementType::kWmmaMatrix>
|
||||
SharedStoreIteratorD;
|
||||
|
||||
/// The shared store transformer for D.
|
||||
typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for D.
|
||||
typedef WmmaGemmSharedLoadTileDTraits<
|
||||
// The pointer.
|
||||
typename Functor::Scalar,
|
||||
// The tile size.
|
||||
typename SharedStoreIteratorD::Tile,
|
||||
// The number of threads.
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsD,
|
||||
// this parameter helps with swizzling when accum is fp32 and output is fp16
|
||||
int(sizeof(Accumulator_)) / int(sizeof(typename GemmConfig_::ScalarD))
|
||||
>
|
||||
SharedLoadTileTraits;
|
||||
|
||||
/// The iterator to load D from shared memory.
|
||||
typedef TileLoadIterator<SharedLoadTileTraits,
|
||||
typename SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
|
||||
/// The stream to load D.
|
||||
typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float const.
|
||||
typename GemmConfig_::ScalarC const,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgC>
|
||||
GlobalLoadTileTraits;
|
||||
|
||||
/// The iterator to load C.
|
||||
typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
|
||||
|
||||
/// The traits class to build the iterator to store data to global memory for D^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float.
|
||||
typename GemmConfig_::ScalarD,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerStgD>
|
||||
GlobalStoreTileTraits;
|
||||
|
||||
/// The iterator to store D.
|
||||
typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
|
||||
/// The transformer for D.
|
||||
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
@ -1,167 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines tile iterator traits for loading thread block-level tile from global memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kAccessSize_>
|
||||
struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_>
|
||||
Base;
|
||||
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Base::Threads::kW;
|
||||
int thread_offset_w = threadIdx.x % Base::Threads::kW * Base::ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index_> {
|
||||
/// This class.
|
||||
typedef WmmaGemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The traits.
|
||||
typedef TileTraits_ Traits;
|
||||
/// The base class.
|
||||
typedef GemmGlobalIteratorCd<Traits, Index_> Base;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> ImmediateOffsetStrides;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename TileTraits_::Pointer Pointer;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset functor.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
/// Base parameters.
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// The params.
|
||||
struct Params : public BaseParams {
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
long long batch_stride,
|
||||
Index ldm,
|
||||
Index n,
|
||||
Index epilogue_stride_w,
|
||||
Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
BaseParams::pointer = pointer;
|
||||
// Stride between GEMMs
|
||||
this->stride_d = batch_stride;
|
||||
// Setup the base stride. One "group of threads" per column.
|
||||
this->stride_h = ldm;
|
||||
// Each thread outputs 1 column per iteration.
|
||||
this->inc_h = ldm * TileTraits_::Threads::kH;
|
||||
this->inc_advance = this->inc_h + epilogue_stride_w;
|
||||
|
||||
this->predicate_offset = n;
|
||||
this->predicate_inc_h = TileTraits_::Threads::kH;
|
||||
this->predicate_inc_advance = this->predicate_inc_h + epilogue_delta_w;
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int const pointer_offset = 0,
|
||||
int const pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
|
||||
: Base(params, bounds, block, pointer_offset, pred_offset, thread_offset_func) {}
|
||||
|
||||
/// Loads a single fragment element from memory
|
||||
CUTLASS_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::load_element(value, d, h, w, c);
|
||||
}
|
||||
|
||||
/// Stores a single fragment element into memory
|
||||
CUTLASS_DEVICE void store_element(
|
||||
typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
|
||||
Store<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW>::store(value, Base::params.pointer, offset);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void load_post_increment(Fragment& fragment) {
|
||||
Base::load_post_increment(fragment);
|
||||
}
|
||||
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void store_post_increment(Fragment& fragment) {
|
||||
Base::store_post_increment(fragment);
|
||||
}
|
||||
};
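
// Configuration sketch (illustrative only; variable names are assumed): the epilogue fills
// this iterator's Params from the problem size and leading dimensions before launch, e.g.
//
//   WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits>::Params params;
//   params.initialize(d_ptr, batch_stride, ldd, n, epilogue_stride_w, epilogue_delta_w);
//
// after which the iterator is constructed on the device with the threadblock's bounds and
// block coordinates.
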
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,367 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayoutA_,
|
||||
typename ScalarA_,
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
typename ScalarB_,
|
||||
MatrixLayout::Kind kLayoutC_,
|
||||
typename ScalarC_,
|
||||
typename WarpGemmShape_,
|
||||
typename InstructionShape_>
|
||||
struct WmmaGemmMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef InstructionShape_ InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
|
||||
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The number of iterations.
|
||||
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
|
||||
|
||||
/// The element for A.
|
||||
typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ElementA, Iterations::kW> FragmentA;
|
||||
|
||||
/// The element for B.
|
||||
typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ElementB, Iterations::kH> FragmentB;
|
||||
|
||||
/// The element for C.
|
||||
typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
|
||||
/// The fragment for C.
|
||||
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Iterations::kH; ++j) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Iterations::kW; ++i) {
|
||||
// The input elements.
|
||||
ElementA const& elt_a = a[i];
|
||||
ElementB const& elt_b = b[j];
|
||||
ElementC const& elt_c = c[j * Iterations::kW + i];
|
||||
|
||||
// The output element.
|
||||
ElementC& elt_d = d[j * Iterations::kW + i];
|
||||
|
||||
// The wmma instruction.
|
||||
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
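The template above just fans the warp's tile out into per-fragment nvcuda::wmma operations. For reference, here is a minimal, self-contained sketch of the raw CUDA WMMA calls that the mma_sync line builds on; the 16x16x16 half/float configuration and the tile_mma_16x16x16 name are illustrative assumptions, not something this header fixes.

#include <cuda_fp16.h>
#include <mma.h>

// One warp computes a 16x16x16 tile D = A * B + C straight from memory.
__device__ void tile_mma_16x16x16(half const* ptr_a, half const* ptr_b,
                                  float const* ptr_c, float* ptr_d) {
  using namespace nvcuda;
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> frag_b;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> frag_c;

  wmma::load_matrix_sync(frag_a, ptr_a, 16);                       // leading dimension = 16
  wmma::load_matrix_sync(frag_b, ptr_b, 16);
  wmma::load_matrix_sync(frag_c, ptr_c, 16, wmma::mem_col_major);
  wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);                  // frag_c = a * b + frag_c
  wmma::store_matrix_sync(ptr_d, frag_c, 16, wmma::mem_col_major);
}

In the struct above, the operands are already resident as WmmaMatrix fragments inside FragmentA, FragmentB, and Accumulators, so only the mma_sync step appears inside the i/j loop nest.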
////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization for WMMA GEMM with binary (1-bit) operands.
template <typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd<MatrixLayout::kRowMajor,
                           Vector<bin1_t, 32>,
                           MatrixLayout::kColumnMajor,
                           Vector<bin1_t, 32>,
                           MatrixLayout::kColumnMajor,
                           int,
                           WarpGemmShape_,
                           Shape<128, 8, 8> > {
  /// The shape of the instruction.
  typedef Shape<128, 8, 8> InstructionShape;
  /// The number of threads per warp. That's a dummy configuration.
  typedef Shape<1, 4, 8> ThreadsPerWarp;
  /// Dimensions of the warp-level GEMM (K-by-N-by-M).
  typedef WarpGemmShape_ WarpGemmShape;
  /// Aliased for compatibility. Will be removed in CUTLASS v2.0.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// The type for A.
  typedef Vector<bin1_t, 32> ScalarA;
  /// The type for B.
  typedef Vector<bin1_t, 32> ScalarB;
  /// The type for C and D.
  typedef int ScalarC;
  /// The number of iterations.
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// The element for A.
  typedef WmmaMatrix<GemmOperand::kA,
                     MatrixLayout::kRowMajor,
                     Vector<bin1_t, 32>,
                     InstructionShape> ElementA;
  /// The fragment for A.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// The element for B.
  typedef WmmaMatrix<GemmOperand::kB,
                     MatrixLayout::kColumnMajor,
                     Vector<bin1_t, 32>,
                     InstructionShape> ElementB;
  /// The fragment for B.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// The element for C.
  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kColumnMajor,
                     int,
                     InstructionShape> ElementC;
  /// The fragment for C.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Multiply-accumulate: d = a * b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The wmma instruction: XOR the operand bits, accumulate with a population count.
        nvcuda::wmma::bmma_sync(elt_d,
                                elt_a,
                                elt_b,
                                elt_c,
                                nvcuda::wmma::experimental::bmmaBitOpXOR,
                                nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
      }
    }
  }
};
#endif
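For readers unfamiliar with the experimental binary path: bmmaBitOpXOR combined with bmmaAccumulateOpPOPC means each output element accumulates the population count of the XOR of a 128-bit row of A with a 128-bit column of B. A scalar reference of that semantics, with an assumed 4-word operand layout and a hypothetical function name, looks like this:

#include <bitset>
#include <cstdint>

// d = c + popcount(a_row XOR b_col), where a_row/b_col hold the 128 operand bits as 4 x 32-bit words.
inline int bmma_xor_popc_reference(uint32_t const a_row[4], uint32_t const b_col[4], int c) {
  int d = c;
  for (int w = 0; w < 4; ++w) {
    d += static_cast<int>(std::bitset<32>(a_row[w] ^ b_col[w]).count());  // XOR, then count set bits
  }
  return d;
}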
////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization for WMMA GEMM with signed 4-bit integer operands.
template <typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd<MatrixLayout::kRowMajor,
                           Vector<int4_t, 8>,
                           MatrixLayout::kColumnMajor,
                           Vector<int4_t, 8>,
                           MatrixLayout::kColumnMajor,
                           int,
                           WarpGemmShape_,
                           Shape<32, 8, 8> > {
  /// The shape of the instruction.
  typedef Shape<32, 8, 8> InstructionShape;
  /// The number of threads per warp. That's a dummy configuration.
  typedef Shape<1, 4, 8> ThreadsPerWarp;
  /// Dimensions of the warp-level GEMM (K-by-N-by-M).
  typedef WarpGemmShape_ WarpGemmShape;
  /// Aliased for compatibility. Will be removed in CUTLASS v2.0.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// The type for A.
  typedef Vector<int4_t, 8> ScalarA;
  /// The type for B.
  typedef Vector<int4_t, 8> ScalarB;
  /// The type for C and D.
  typedef int ScalarC;
  /// The number of iterations.
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// The element for A.
  typedef WmmaMatrix<GemmOperand::kA,
                     MatrixLayout::kRowMajor,
                     Vector<int4_t, 8>,
                     InstructionShape> ElementA;
  /// The fragment for A.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// The element for B.
  typedef WmmaMatrix<GemmOperand::kB,
                     MatrixLayout::kColumnMajor,
                     Vector<int4_t, 8>,
                     InstructionShape> ElementB;
  /// The fragment for B.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// The element for C.
  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kColumnMajor,
                     int,
                     InstructionShape> ElementC;
  /// The fragment for C.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Multiply-accumulate: d = a * b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The wmma instruction.
        nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
      }
    }
  }
};
#endif
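To make the Iterations typedef concrete for this Shape<32, 8, 8> instruction (K = 32, N = 8, M = 8): with a hypothetical warp tile of Shape<64, 32, 32>, ShapeDiv yields Iterations = Shape<2, 4, 4>, so each multiply_add call issues Iterations::kH * Iterations::kW = 16 wmma instructions. A small stand-alone sketch of that arithmetic (the MiniShape helpers are stand-ins for cutlass::Shape/ShapeDiv, and the warp tile is an assumption):

// Minimal stand-ins for cutlass::Shape and ShapeDiv, just to show the division.
template <int kD_, int kH_, int kW_>
struct MiniShape { static int const kD = kD_, kH = kH_, kW = kW_; };

template <typename A, typename B>
struct MiniShapeDiv { typedef MiniShape<A::kD / B::kD, A::kH / B::kH, A::kW / B::kW> Shape; };

typedef MiniShape<64, 32, 32> WarpTile;      // assumed K-by-N-by-M warp tile
typedef MiniShape<32, 8, 8>   Instruction;   // the sub-byte WMMA instruction shape above
typedef MiniShapeDiv<WarpTile, Instruction>::Shape Iterations;

static_assert(Iterations::kD == 2 && Iterations::kH == 4 && Iterations::kW == 4,
              "multiply_add issues kH * kW = 16 wmma instructions per call");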
////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization for WMMA GEMM with unsigned 4-bit integer operands.
template <typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd<MatrixLayout::kRowMajor,
                           Vector<uint4_t, 8>,
                           MatrixLayout::kColumnMajor,
                           Vector<uint4_t, 8>,
                           MatrixLayout::kColumnMajor,
                           int,
                           WarpGemmShape_,
                           Shape<32, 8, 8> > {
  /// The shape of the instruction.
  typedef Shape<32, 8, 8> InstructionShape;
  /// The number of threads per warp. That's a dummy configuration.
  typedef Shape<1, 4, 8> ThreadsPerWarp;
  /// Dimensions of the warp-level GEMM (K-by-N-by-M).
  typedef WarpGemmShape_ WarpGemmShape;
  /// Aliased for compatibility. Will be removed in CUTLASS v2.0.
  typedef WarpGemmShape_ AccumulatorsPerWarp;
  /// The type for A.
  typedef Vector<uint4_t, 8> ScalarA;
  /// The type for B.
  typedef Vector<uint4_t, 8> ScalarB;
  /// The type for C and D.
  typedef int ScalarC;
  /// The number of iterations.
  typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;

  /// The element for A.
  typedef WmmaMatrix<GemmOperand::kA,
                     MatrixLayout::kRowMajor,
                     Vector<uint4_t, 8>,
                     InstructionShape> ElementA;
  /// The fragment for A.
  typedef Fragment<ElementA, Iterations::kW> FragmentA;

  /// The element for B.
  typedef WmmaMatrix<GemmOperand::kB,
                     MatrixLayout::kColumnMajor,
                     Vector<uint4_t, 8>,
                     InstructionShape> ElementB;
  /// The fragment for B.
  typedef Fragment<ElementB, Iterations::kH> FragmentB;

  /// The element for C.
  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kColumnMajor,
                     int,
                     InstructionShape> ElementC;
  /// The fragment for C.
  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}

  /// Multiply-accumulate: d = a * b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
        ElementB const& elt_b = b[j];
        ElementC const& elt_c = c[j * Iterations::kW + i];

        // The output element.
        ElementC& elt_d = d[j * Iterations::kW + i];

        // The wmma instruction.
        nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
      }
    }
  }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

#endif  // defined CUTLASS_USE_WMMA_API
@@ -1,239 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterator traits for efficiently loading and storing fragments to and from shared
           memory, specialized for WMMA GEMM.
*/
#pragma once

#include "cutlass/wmma_matrix.h"
#ifdef CUTLASS_USE_WMMA_API

#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Warps_,
          int kWarpStride_,
          typename Iterations_,
          typename Delta_,
          typename WmmaShape_>
struct WmmaGemmSharedLoadTileATraits {
  /// The operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kA;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The pointer.
  typedef Scalar const* Pointer;
  /// The access size.
  static int const kAccessSize = 1;
  /// The tile with skew.
  typedef Tile_ Tile;
  /// The number of warps.
  typedef Warps_ Warps;
  /// The warp stride.
  static int const kWarpStride = kWarpStride_;
  /// The number of iterations.
  typedef Iterations_ Iterations;
  /// The strides between iterations.
  typedef Delta_ Delta;
  /// The strides between iterations.
  typedef Delta_ ImmediateOffsetStrides;
  /// The shape of the WMMA instruction.
  typedef WmmaShape_ WmmaShape;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
  /// ThreadOffset
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE
    Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The offset.
      int const offset = warp % Warps::kW * kWarpStride;
      return make_Coord(0, 0, offset, 0);
    }
  };
};
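The ThreadOffset functor above resolves, per warp, to warp % Warps::kW * kWarpStride: consecutive warps along the W dimension read consecutive strips of the A tile from shared memory. A tiny sketch with assumed numbers (a 4-wide warp arrangement and a stride of 256 scalars, neither of which is fixed by this header):

// kWarpSize is 32 threads; warps_kw and warp_stride are illustrative values.
inline int load_tile_a_offset(int thread_idx, int warps_kw = 4, int warp_stride = 256) {
  int const warp = thread_idx / 32;         // the warp id
  return (warp % warps_kw) * warp_stride;   // warps 0..3 map to offsets 0, 256, 512, 768
}

The B-operand traits that follow use warp / Warps::kW instead, so the two offsets together tile the warp grid over the H and W dimensions of the threadblock tile.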
////////////////////////////////////////////////////////////////////////////////////////////////////

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Warps_,
          int kWarpStride_,
          typename Iterations_,
          typename Delta_,
          typename WmmaShape_>
struct WmmaGemmSharedLoadTileBTraits {
  /// The operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kB;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The pointer.
  typedef Scalar const* Pointer;
  /// The access size.
  static int const kAccessSize = 1;
  /// The tile with skew.
  typedef Tile_ Tile;
  /// The number of warps.
  typedef Warps_ Warps;
  /// The warp stride.
  static int const kWarpStride = kWarpStride_;
  /// The number of iterations.
  typedef Iterations_ Iterations;
  /// The strides between iterations.
  typedef Delta_ Delta;
  /// The strides between iterations.
  typedef Delta_ ImmediateOffsetStrides;
  /// The shape of the WMMA instruction.
  typedef WmmaShape_ WmmaShape;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
  /// ThreadOffset
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE
    Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The offset.
      int const offset = warp / Warps::kW * kWarpStride;
      return make_Coord(0, 0, offset, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename WmmaShape_,
          int kSkew_ = 0>
struct WmmaGemmSharedStoreTileDTraits {
  /// The operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kC;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The access size.
  static int const kAccessSize = 1;
  /// The pointer.
  typedef Scalar* Pointer;
  /// The number of warps.
  typedef Warps_ Warps;
  /// The shape of the WMMA instruction.
  typedef WmmaShape_ WmmaShape;
  /// The skew.
  static int const kSkew = kSkew_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
  /// The tile with skew.
  typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
  /// The number of iterations needed to store the tile.
  typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;

  /// ThreadOffset
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE
    Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The starting row of this warp's tile.
      int const h = warp / Warps::kW * WmmaShape::kH;
      // The starting column of this warp's tile.
      int const w = warp % Warps::kW * WmmaShape::kW;
      // The offset.
      int const offset = h * Tile::kW + w;
      return make_Coord(0, 0, offset, 0);
    }
  };
};
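To make the D-store offset above concrete: with an assumed 2x2 warp arrangement (Warps::kW = 2), a 16x16x16 WMMA shape, a 64-wide output tile and zero skew, Tile::kW = 64 and warp 3 lands at h = 16, w = 16, i.e. offset = 16 * 64 + 16 = 1040. A hedged sketch of that computation (all parameter defaults are illustrative):

inline int store_tile_d_offset(int warp, int warps_kw = 2, int wmma_h = 16, int wmma_w = 16,
                               int tile_kw = 64) {
  int const h = (warp / warps_kw) * wmma_h;  // starting row of this warp's WMMA tile
  int const w = (warp % warps_kw) * wmma_w;  // starting column of this warp's WMMA tile
  return h * tile_kw + w;                    // e.g. warp 3 -> 16 * 64 + 16 = 1040
}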
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_, int kLdsPerAccess_ = 1>
struct WmmaGemmSharedLoadTileDTraits {
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The pointer.
  typedef Scalar const* Pointer;
  /// The access size.
  static int const kAccessSize = kScalarsPerLds_;
  /// The tile.
  typedef typename WmmaReshapeTile<Tile_, kScalarsPerLds_, kLdsPerAccess_>::Tile Tile;
  /// The threads.
  typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
  /// The thread strides.
  typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_, kScalarsPerLds_>
      ImmediateOffsetStrides;
  /// The number of iterations needed to load/store the tile.
  typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
      Iterations;

  /// ThreadOffset
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE
    Coord<4> operator()() const {
      // The offset.
      int const offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
      return make_Coord(0, 0, offset, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

#endif  // defined CUTLASS_USE_WMMA_API
Some files were not shown because too many files have changed in this diff.