# Compare commits

72 commits

| SHA1 |
|---|
| e45e773436 |
| dae6b6893b |
| ba18ea9c32 |
| 9ab9110168 |
| e5d4669f16 |
| 94f01f19d5 |
| fa56763c25 |
| 25e26a6e51 |
| f248e9bdb4 |
| dceefe4f64 |
| c3881d097e |
| a29dfb1c63 |
| 0abaac84ea |
| 858c735856 |
| d6f58b2d14 |
| c4cf0dad82 |
| 57551902d0 |
| 1604ebaf10 |
| 6023038bae |
| ddd8f9cf41 |
| ec2b4fd85d |
| 86ce09aed1 |
| 21c1fa3849 |
| 8c339ac039 |
| e49f690fd7 |
| 96dad61a75 |
| cc2ea4c3fc |
| a0de301283 |
| 319a389f42 |
| 71def2f084 |
| 70f3ba57f5 |
| dd77fadc70 |
| be4578d517 |
| d7b499deff |
| 310ed81ac3 |
| 4c0d6e1eb4 |
| 167ac54c65 |
| 12f4108ac2 |
| dd571f0edb |
| 6d0d265047 |
| f11fa975a5 |
| 0e71d9b450 |
| eb0d4c9213 |
| bc45e2c023 |
| 095cbba57c |
| 8f1fe7a132 |
| 3ab1eacf09 |
| cd39c75e25 |
| b2e1e97cb1 |
| 96a11a1ef3 |
| e96f00586c |
| 3cfa5db2a2 |
| 1db6971a8d |
| b954127297 |
| d0d941efc7 |
| 8a951b2940 |
| 1e4703cbab |
| c3353add63 |
| ac8825b941 |
| 8fd94806e5 |
| d7c9cbf0b9 |
| c2ee13a0fe |
| f78994bb40 |
| dceabd4c5a |
| 86fa1dc30b |
| 288af365db |
| 0dc3ba60b3 |
| ec4f7e5194 |
| 4e666e1dfd |
| 3799e12f25 |
| fc3bc85db8 |
| 49c0a58d50 |
## .github/ISSUE_TEMPLATE/bug_report.md (vendored, new file, 23 lines)

```diff
@@ -0,0 +1,23 @@
+---
+name: Bug report
+about: Create a bug report to help us improve CUTLASS
+title: "[BUG]"
+labels: "? - Needs Triage, bug"
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**Steps/Code to reproduce bug**
+Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Environment details (please complete the following information):**
+ - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
+
+**Additional context**
+Add any other context about the problem here.
```
## .github/ISSUE_TEMPLATE/documentation_request.md (vendored, new file, 35 lines)

```diff
@@ -0,0 +1,35 @@
+---
+name: Documentation request
+about: Report incorrect or needed documentation to improve CUTLASS
+title: "[DOC]"
+labels: "? - Needs Triage, documentation"
+assignees: ''
+
+---
+
+## Report incorrect documentation
+
+**Location of incorrect documentation**
+Provide links and line numbers if applicable.
+
+**Describe the problems or issues found in the documentation**
+A clear and concise description of what you found to be incorrect.
+
+**Steps taken to verify documentation is incorrect**
+List any steps you have taken:
+
+**Suggested fix for documentation**
+Detail proposed changes to fix the documentation if you have any.
+
+---
+
+## Report needed documentation
+
+**Report needed documentation**
+A clear and concise description of what documentation you believe is needed and why.
+
+**Describe the documentation you'd like**
+A clear and concise description of what you want to happen.
+
+**Steps taken to search for needed documentation**
+List any steps you have taken:
```
## .github/ISSUE_TEMPLATE/feature_request.md (vendored, new file, 20 lines)

```diff
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for CUTLASS
+title: "[FEA]"
+labels: "? - Needs Triage, feature request"
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context, code examples, or references to existing implementations about the feature request here.
```
## .github/ISSUE_TEMPLATE/submit_question.md (vendored, new file, 10 lines)

```diff
@@ -0,0 +1,10 @@
+---
+name: Submit question
+about: Ask a general question about CUTLASS
+title: "[QST]"
+labels: "? - Needs Triage, question"
+assignees: ''
+
+---
+
+**What is your question?**
```
## .github/workflows/labeler.yml (vendored, new file, 11 lines)

```diff
@@ -0,0 +1,11 @@
+name: "Pull Request Labeler"
+on:
+- pull_request_target
+
+jobs:
+  triage:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/labeler@main
+      with:
+        repo-token: "${{ secrets.GITHUB_TOKEN }}"
```
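The labeler action reads its label rules from a separate `.github/labeler.yml` configuration file, which is not part of this diff. A minimal sketch of what such a config could look like (the label names and globs below are hypothetical, not taken from the CUTLASS repository):

```yaml
# Hypothetical .github/labeler.yml consumed by actions/labeler:
# each top-level key is a label; its entries are globs matched
# against the files changed by the pull request.
documentation:
  - 'media/docs/**'
build:
  - 'CMakeLists.txt'
  - '**/*.cmake'
```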
## .github/workflows/new-issues-to-triage-projects.yml (vendored, new file, 35 lines)

```diff
@@ -0,0 +1,35 @@
+name: Auto Assign New Issues to Triage Project
+
+on:
+  issues:
+    types: [opened]
+
+env:
+  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+jobs:
+  assign_one_project:
+    runs-on: ubuntu-latest
+    name: Assign New Issues to Triage Project
+    steps:
+    - name: Process bug issues
+      uses: docker://takanabe/github-actions-automate-projects:v0.0.1
+      if: contains(github.event.issue.labels.*.name, 'bug') && contains(github.event.issue.labels.*.name, '? - Needs Triage')
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
+        GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
+    - name: Process feature issues
+      uses: docker://takanabe/github-actions-automate-projects:v0.0.1
+      if: contains(github.event.issue.labels.*.name, 'feature request') && contains(github.event.issue.labels.*.name, '? - Needs Triage')
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
+        GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
+    - name: Process other issues
+      uses: docker://takanabe/github-actions-automate-projects:v0.0.1
+      if: contains(github.event.issue.labels.*.name, '? - Needs Triage') && (!contains(github.event.issue.labels.*.name, 'bug') && !contains(github.event.issue.labels.*.name, 'feature request'))
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
+        GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
```
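The `if:` conditions above rely on GitHub's object filter syntax: `github.event.issue.labels.*.name` flattens the issue's label objects into an array of names, and `contains()` tests membership. A minimal standalone sketch of the same pattern (the workflow, job, and label names here are hypothetical):

```yaml
# Hypothetical workflow: run a step only when an issue carries the 'bug' label.
on:
  issues:
    types: [opened, labeled]
jobs:
  demo:
    runs-on: ubuntu-latest
    steps:
      - if: contains(github.event.issue.labels.*.name, 'bug')
        run: echo "Issue is labeled as a bug."
```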
## .github/workflows/stale.yml (vendored, new file, 57 lines)

```diff
@@ -0,0 +1,57 @@
+name: Mark inactive issues and pull requests
+
+on:
+  schedule:
+    - cron: "0 * * * *"
+
+jobs:
+  mark-inactive-30d:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Mark 30 day inactive issues and pull requests
+        uses: actions/stale@v3
+        with:
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+          stale-issue-message: >
+            This issue has been labeled `inactive-30d` due to no recent activity in the past 30 days.
+            Please close this issue if no further response or action is needed.
+            Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
+            This issue will be labeled `inactive-90d` if there is no activity in the next 60 days.
+          stale-issue-label: "inactive-30d"
+          exempt-issue-labels: "0 - Blocked,0 - Backlog,good first issue"
+          days-before-issue-stale: 30
+          days-before-issue-close: -1
+          stale-pr-message: >
+            This PR has been labeled `inactive-30d` due to no recent activity in the past 30 days.
+            Please close this PR if it is no longer required.
+            Otherwise, please respond with a comment indicating any updates.
+            This PR will be labeled `inactive-90d` if there is no activity in the next 60 days.
+          stale-pr-label: "inactive-30d"
+          exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
+          days-before-pr-stale: 30
+          days-before-pr-close: -1
+          operations-per-run: 50
+  mark-inactive-90d:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Mark 90 day inactive issues and pull requests
+        uses: actions/stale@v3
+        with:
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+          stale-issue-message: >
+            This issue has been labeled `inactive-90d` due to no recent activity in the past 90 days.
+            Please close this issue if no further response or action is needed.
+            Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
+          stale-issue-label: "inactive-90d"
+          exempt-issue-labels: "0 - Blocked,0 - Backlog,good first issue"
+          days-before-issue-stale: 90
+          days-before-issue-close: -1
+          stale-pr-message: >
+            This PR has been labeled `inactive-90d` due to no recent activity in the past 90 days.
+            Please close this PR if it is no longer required.
+            Otherwise, please respond with a comment indicating any updates.
+          stale-pr-label: "inactive-90d"
+          exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
+          days-before-pr-stale: 90
+          days-before-pr-close: -1
+          operations-per-run: 50
```
## CHANGELOG.md (80 changed lines)

````diff
@@ -1,5 +1,42 @@
 # NVIDIA CUTLASS Changelog
 
+## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
+
+* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
+  * [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
+  * [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
+  * [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
+  * [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
+* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
+  * Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
+  * [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
+  * [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
+  * [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/symm_operation.py)
+  * [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/trmm_operation.py)
+  * [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h)
+* [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
+  * [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters
+* [GEMM + Softmax example](/examples/35_gemm_softmax)
+* [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatter outputs based on indices vectors in the same GEMM kernel.
+  * It can select random rows in a row major matrix.
+  * It can select random columns in a column major matrix.
+* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
+  * Supported kernels: GEMM and CONV.
+  * Supported types: fp16 and int8.
+  * Supported architectures: Turing and Ampere.
+* [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a. Deconvolution) support which reuses the Dgrad implementation.
+* [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
+* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels.
+* Epilogue enhancement:
+  * Eliminate bank conflicts in int8 tensor core kernels.
+  * Half2 usage if epilogue compute type is fp16.
+  * More activation functions: Silu, Hardswish, Leaky Relu.
+  * New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
+* [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
+* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
+* Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads)
+* Updates and bugfixes from the community (thanks!)
+
 ## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)
 
 * **TF32x3:** emulated single-precision using Tensor Cores
@@ -23,7 +60,6 @@
 * Ubuntu 16.04
 * CUDA 10.2
 
-
 ## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
 * Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
 * [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
@@ -210,27 +246,33 @@
 
 ## Copyright
 
-Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
 
 ```
-Redistribution and use in source and binary forms, with or without modification, are permitted
-provided that the following conditions are met:
-    * Redistributions of source code must retain the above copyright notice, this list of
-      conditions and the following disclaimer.
-    * Redistributions in binary form must reproduce the above copyright notice, this list of
-      conditions and the following disclaimer in the documentation and/or other materials
-      provided with the distribution.
-    * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
-      to endorse or promote products derived from this software without specific prior written
-      permission.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
 
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
-IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
-STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ```
````
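The changelog entry above notes parallel split-k GEMM support in the CUTLASS profiler. A sketch of how that might be invoked is shown below; the flag spellings are assumptions based on the profiler's usual command-line conventions, not taken from this diff:

```sh
# Sketch: profile a large GEMM with a parallel split-k reduction.
# --split_k_mode and --split_k_slices are assumed flag names.
./tools/profiler/cutlass_profiler --operation=Gemm \
  --m=4096 --n=4096 --k=4096 \
  --split_k_mode=parallel --split_k_slices=4
```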
## CITATION.cff (new file, 82 lines)

```diff
@@ -0,0 +1,82 @@
+cff-version: 1.2.0
+title: CUTLASS
+message: >-
+  If you use this software, please cite using the
+  following metadata.
+type: software
+authors:
+  - given-names: Andrew
+    email: akerr@nvidia.com
+    family-names: Kerr
+    affiliation: NVIDIA
+  - given-names: Haicheng
+    family-names: Wu
+    affiliation: NVIDIA
+    email: haichengw@nvidia.com
+  - given-names: Manish
+    family-names: Gupta
+    affiliation: Google
+    email: manigupta@google.com
+  - given-names: Dustyn
+    family-names: Blasig
+    email: dblasig@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Pradeep
+    family-names: Ramani
+    email: prramani@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Duane
+    family-names: Merrill
+    email: dumerrill@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Aniket
+    family-names: Shivam
+    email: ashivam@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Piotr
+    family-names: Majcher
+    email: pmajcher@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Paul
+    family-names: Springer
+    email: pspringer@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Markus
+    family-names: Hohnerbach
+    affiliation: NVIDIA
+    email: mhohnerbach@nvidia.com
+  - given-names: Jin
+    family-names: Wang
+    email: jinw@nvidia.com
+    affiliation: NVIDIA
+  - given-names: Matt
+    family-names: Nicely
+    email: mnicely@nvidia.com
+    affiliation: NVIDIA
+repository-code: 'https://github.com/NVIDIA/cutlass'
+abstract: >-
+  CUTLASS is a collection of CUDA C++ template
+  abstractions for implementing high-performance
+  matrix-multiplication (GEMM) and related
+  computations at all levels and scales within CUDA.
+  It incorporates strategies for hierarchical
+  decomposition and data movement similar to those
+  used to implement cuBLAS and cuDNN. CUTLASS
+  decomposes these "moving parts" into reusable,
+  modular software components abstracted by C++
+  template classes. These thread-wide, warp-wide,
+  block-wide, and device-wide primitives can be
+  specialized and tuned via custom tiling sizes, data
+  types, and other algorithmic policy. The resulting
+  flexibility simplifies their use as building blocks
+  within custom kernels and applications.
+keywords:
+  - 'cutlass, tensor cores, cuda'
+license: BSD-3-Clause
+license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt
+version: '2.9'
+date-released: '2022-04-27'
+identifiers:
+  - type: url
+    value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0"
+    description: The GitHub release URL of tag 2.9.0
```
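For readers who cite with BibTeX rather than CFF, the metadata above maps to roughly the following entry (a sketch assembled from the fields shown; the entry type and key are assumptions):

```bibtex
@software{nvidia_cutlass_2_9,
  title   = {CUTLASS},
  author  = {Kerr, Andrew and Wu, Haicheng and Gupta, Manish and Blasig, Dustyn
             and Ramani, Pradeep and Merrill, Duane and Shivam, Aniket
             and Majcher, Piotr and Springer, Paul and Hohnerbach, Markus
             and Wang, Jin and Nicely, Matt},
  version = {2.9},
  year    = {2022},
  month   = apr,
  url     = {https://github.com/NVIDIA/cutlass/tree/v2.9.0},
  license = {BSD-3-Clause}
}
```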
## CMakeLists.txt

```diff
@@ -1,23 +1,29 @@
-# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
 #
-# Redistribution and use in source and binary forms, with or without modification, are permitted
-# provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright notice, this list of
-#       conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright notice, this list of
-#       conditions and the following disclaimer in the documentation and/or other materials
-#       provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
-#       to endorse or promote products derived from this software without specific prior written
-#       permission.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
 #
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
-# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
-# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
@@ -32,7 +38,7 @@ endif()
 
 message(STATUS "CMake Version: ${CMAKE_VERSION}")
 
-project(CUTLASS VERSION 2.7.0 LANGUAGES CXX)
+project(CUTLASS VERSION 2.9.0 LANGUAGES CXX)
 include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
 
 if (CUDA_VERSION VERSION_LESS 10.2)
@@ -67,21 +73,23 @@ set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
 
 if(CUTLASS_ENABLE_HEADERS_ONLY)
   set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
-  set(CUTLASS_ENABLE_TOOLS_INIT OFF)
+  set(CUTLASS_ENABLE_TOOLS_INIT ON)
   set(CUTLASS_ENABLE_LIBRARY_INIT OFF)
 else()
   set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
   set(CUTLASS_ENABLE_TOOLS_INIT ON)
   set(CUTLASS_ENABLE_LIBRARY_INIT ON)
 endif()
 
 set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.")
 
 set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
 set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
-set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Library")
-set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Profiler")
+set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
+set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
 
 if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
-  set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
+  set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}})
 else()
   set(CUTLASS_ENABLE_TESTS_INIT OFF)
 endif()
@@ -185,11 +193,9 @@ set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of opera
 set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
 set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
 
-
 # Test Levels L0, L1, L2
 set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
 
-
 set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
 
 set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
@@ -197,7 +203,15 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
 list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
 
 if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
-  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
+  message(STATUS "Enable caching of reference results in conv unit tests")
+  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
 endif()
 
+set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests")
+
+if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
+  message(STATUS "Enable rigorous conv problem sizes in conv unit tests")
+  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
+endif()
 
 #
@@ -320,6 +334,11 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
   link_libraries(nvidia::cudart)
 endif()
 
+# Support for 128-bit integers if using NVIDIA C++ compiler
+if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ")
+endif()
+
 if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
   # CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this
   # property for CMake 3.18+, so we request the NEW behavior for correct compatibility.
@@ -736,7 +755,7 @@ if (CUTLASS_INSTALL_TESTS)
 
   install(
     FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
-    DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
+    DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/"
     )
 
 endif()
```
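The options touched in this diff are ordinary CMake cache variables, so they can be set at configure time. A sketch of a configure line exercising them (the build directory layout and the `CUTLASS_NVCC_ARCHS=80` architecture choice are placeholders):

```sh
# Sketch: configure CUTLASS with the test/library options shown above.
mkdir build && cd build
cmake .. \
  -DCUTLASS_NVCC_ARCHS=80 \
  -DCUTLASS_LIBRARY_KERNELS=all \
  -DCUTLASS_TEST_LEVEL=0 \
  -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=ON \
  -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=ON
```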
## CUDA.cmake (45 changed lines)

```diff
@@ -1,23 +1,29 @@
-# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
 #
-# Redistribution and use in source and binary forms, with or without modification, are permitted
-# provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright notice, this list of
-#       conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright notice, this list of
-#       conditions and the following disclaimer in the documentation and/or other materials
-#       provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
-#       to endorse or promote products derived from this software without specific prior written
-#       permission.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
 #
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
-# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
-# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 if(CUDA_COMPILER MATCHES "[Cc]lang")
@@ -213,8 +219,7 @@ function(cutlass_correct_source_file_language_property)
   endif()
 endfunction()
 
-# If building with all kernels, set UNITY build on by default.
-if (CUTLASS_LIBRARY_KERNELS MATCHES "all")
+if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all")
   set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
 else()
   set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF)
```
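This change makes unity builds the default under MSVC as well as for "all kernels" builds, which reduces the number of translation units the compiler must process. Presumably the computed default can still be overridden at configure time; a hypothetical override, assuming the cache variable drops the `_INIT` suffix and is named `CUTLASS_UNITY_BUILD_ENABLED`:

```sh
# Hypothetical: keep per-kernel translation units even when building all kernels.
cmake .. -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=OFF
```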
## LICENSE.txt (42 changed lines)

```diff
@@ -1,23 +1,27 @@
-Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
 
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are met:
-    * Redistributions of source code must retain the above copyright
-      notice, this list of conditions and the following disclaimer.
-    * Redistributions in binary form must reproduce the above copyright
-      notice, this list of conditions and the following disclaimer in the
-      documentation and/or other materials provided with the distribution.
-    * Neither the name of the NVIDIA CORPORATION nor the
-      names of its contributors may be used to endorse or promote products
-      derived from this software without specific prior written permission.
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
 
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
## PUBLICATIONS.md

```diff
@@ -1,8 +1,22 @@
 # Publications Using Cutlass
 
+## 2022
+
+- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.
+
+- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.
+
+## 2021
+
+- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.
+
+- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
+
 ## 2020
 
 - ["Scalable Knowledge Graph Analytics at 136 Petaflop/s"](https://www.computer.org/csdl/proceedings-article/sc/2020/999800a061/1oeORDgCM0g). Ramakrishnan Kannan, Piyush Sao, Hao Lu, Drahomira Herrmannova, Vijay Thakkar, Robert Patton, Richard Vuduc, Thomas Potok. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
 
 - ["Accelerating Sparse DNN Models without Hardware-Support via Tile-Wise Sparsity"](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
 
 - ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. _ACM Transactions on Mathematical Software_, March 2020.
```
## README.md (113 changed lines)

````diff
@@ -1,8 +1,8 @@
 
 
-# CUTLASS 2.8
+# CUTLASS 2.9
 
-_CUTLASS 2.8 - November 2021_
+_CUTLASS 2.9 - April 2022_
 
 CUTLASS is a collection of CUDA C++ template abstractions for implementing
 high-performance matrix-multiplication (GEMM) and related computations at all levels
@@ -34,18 +34,31 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
 See the [functionality listing](/media/docs/functionality.md) for the list of operations
 supported at each level of the execution model hierarchy.
 
-# What's New in CUTLASS 2.8
-CUTLASS 2.8 is an update to CUTLASS adding:
-- [TF32x3:](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm) emulated single-precision using Tensor Cores; 45+ TFLOPs on NVIDIA A100
-- [Mainloop fusion for Convolution:](/examples/25_ampere_fprop_mainloop_fusion) convolution with fused per-channel bias-add
-- [Grouped GEMM:](/examples/24_gemm_grouped) similar to batched GEMM with distinct problem size per group
-- [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing.
-- Optimal performance using [CUDA 11.5](https://developer.nvidia.com/cuda-downloads)
+# What's New in CUTLASS 2.9
+
+CUTLASS 2.9 is an update to CUTLASS adding:
+- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
+- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
+  - [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
+  - [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
+  - [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and
+  - [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu)
+- [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
+- [GEMM + Softmax example](/examples/35_gemm_softmax)
+- [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatter outputs based on indices vectors in the same GEMM kernel.
+- [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. Bias vector add is also supported in the first GEMM/CONV.
+- [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a. Deconvolution) support which reuses the Dgrad implementation.
+- [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
+- [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores.
+- Epilogue enhancement with performance improvement, more activation functions, and more fusion patterns.
+- [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix.
+- Optimal performance using [CUDA 11.7](https://developer.nvidia.com/cuda-downloads)
+- [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
+- Updates and bugfixes from the community (thanks!)
 - **Deprecation announcement:** CUTLASS plans to deprecate the following:
   - Maxwell and Pascal GPU architectures
   - Ubuntu 16.04
   - CUDA 10.2
-- Updates and bugfixes from the community (thanks!)
 
 **See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
 
@@ -56,15 +69,25 @@ CUTLASS 2.8 is an update to CUTLASS adding:
 CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
 they exhibit performance comparable to cuBLAS for scalar GEMM
 computations. The above figure shows CUTLASS performance relative to cuBLAS
-for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
-using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
+for large matrix dimensions on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/),
+an [NVIDIA A2](https://www.nvidia.com/en-us/data-center/products/a2/),
+an [NVIDIA TitanV](https://www.nvidia.com/en-us/titan/titan-v/),
+and an [NVIDIA GeForce 2080 Ti](https://www.nvidia.com/en-us/geforce/graphics-cards/rtx-2080-ti/)
+compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's
 [mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
 
+<p align="center"><img src=/media/images/cutlass-2.9-implicit-gemm-performance.png></p>
+
+When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
+kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
+as shown in the above figure. Tensor Core operations are still implemented using CUDA's
+[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
+
 # Compatibility
 
 CUTLASS requires a C++11 host compiler and
-performs best when built with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-toolkit).
-It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, and CUDA 11.4.
+performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
+It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5.
 
 We have tested the following environments.
 
@@ -72,8 +95,10 @@ We have tested the following environments.
 |-----------------|----------|
 | Windows 10 | Microsoft Visual Studio 2015|
 |            | Microsoft Visual Studio 2017|
+|            | Microsoft Visual Studio 2019|
 | Ubuntu 18.04 | GCC 7.5.0 |
 | Ubuntu 20.04 | GCC 10.3.0 |
+| Ubuntu 21.04 | GCC 11.2.0 |
 
 Additionally, CUTLASS may be built with clang.
 See [these instructions](media/docs/quickstart.md#clang) for more details.
@@ -81,10 +106,7 @@ See [these instructions](media/docs/quickstart.md#clang) for more details.
 CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
 any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
 
-For all GPUs, we recommend compiling with the [**CUDA 11.5 Toolkit**](https://developer.nvidia.com/cuda-toolkit)
-for best performance.
-
-|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
+|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**|
 |---|---|---|---|
 |NVIDIA Tesla V100|7.0|9.2|10.1|
 |NVIDIA TitanV|7.0|9.2|10.1|
@@ -94,6 +116,9 @@ for best performance.
 |NVIDIA A10 |8.6|11.1|11.1|
 |NVIDIA GeForce 3090|8.6|11.1|11.1|
 
+For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit)
+for best performance.
+
 # Documentation
 
 CUTLASS is described in the following documents and the accompanying
@@ -227,6 +252,16 @@ examples/
   13_fused_two_gemms/                # example demonstrating two GEMMs fused in one kernel
 
   22_ampere_tensorop_conv2dfprop/    # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
 
+  31_basic_syrk                      # example demonstrating Symmetric rank-K update
+
+  32_basic_trmm                      #
+
+  33_ampere_3xtf32_tensorop_symm     #
+
+  35_gemm_softmax                    # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
+
+  40_cutlass_py                      # example demonstrating CUTLASS with CUDA Python
 ```
@@ -482,27 +517,33 @@ The official list of CUTLASS developers and contributors is available here: [CON
 
 # Copyright
 
-Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
 
 ```
-Redistribution and use in source and binary forms, with or without modification, are permitted
-provided that the following conditions are met:
-    * Redistributions of source code must retain the above copyright notice, this list of
-      conditions and the following disclaimer.
-    * Redistributions in binary form must reproduce the above copyright notice, this list of
-      conditions and the following disclaimer in the documentation and/or other materials
-      provided with the distribution.
-    * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
-      to endorse or promote products derived from this software without specific prior written
-      permission.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
 
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
-IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
-STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ```
````
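The README above describes CUTLASS as CUDA C++ template abstractions whose device-wide GEMM kernels are comparable to cuBLAS. For orientation, a minimal sketch of the 2.x device-level GEMM API, after the pattern in the Quick Start Guide (this code is illustrative and not part of the diff):

```cpp
// Minimal sketch: single-precision GEMM (C = alpha * A * B + beta * C)
// using the CUTLASS 2.x device-level API. A, B, C are assumed to be
// device pointers to column-major matrices with leading dimensions
// lda, ldb, ldc.
#include "cutlass/gemm/device/gemm.h"

cutlass::Status run_sgemm(int M, int N, int K,
                          float alpha, float const *A, int lda,
                          float const *B, int ldb,
                          float beta, float *C, int ldc) {
  using ColumnMajor = cutlass::layout::ColumnMajor;
  using Gemm = cutlass::gemm::device::Gemm<
      float, ColumnMajor,   // element and layout of A
      float, ColumnMajor,   // element and layout of B
      float, ColumnMajor>;  // element and layout of C/D

  Gemm gemm_op;
  // C serves both as the accumulator source (scaled by beta) and the output.
  return gemm_op({{M, N, K}, {A, lda}, {B, ldb}, {C, ldc}, {C, ldc}, {alpha, beta}});
}
```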
## cmake/nop.cu (42 changed lines)

```diff
@@ -1,24 +1,30 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright notice, this list of
- *       conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright notice, this list of
- *       conditions and the following disclaimer in the documentation and/or other materials
- *       provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- *       to endorse or promote products derived from this software without specific prior written
- *       permission.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
  *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
 **************************************************************************************************/
```
42
cuBLAS.cmake
42
cuBLAS.cmake
@ -1,23 +1,29 @@
|
||||
-# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
 #
-# Redistribution and use in source and binary forms, with or without modification, are permitted
-# provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright notice, this list of
-#       conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright notice, this list of
-#       conditions and the following disclaimer in the documentation and/or other materials
-#       provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
-#       to endorse or promote products derived from this software without specific prior written
-#       permission.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
 #
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
-# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
-# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

message(STATUS "Configuring cublas ...")

43  cuDNN.cmake
@@ -1,24 +1,29 @@
if(DEFINED CUDNN_ENABLED)

1  docs/_config.yml  Normal file
@@ -0,0 +1 @@
+theme: jekyll-theme-minimal

@@ -1,25 +1,33 @@
cutlass_example_add_executable(
  00_basic_gemm
  basic_gemm.cu

@@ -1,24 +1,30 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright notice, this list of
- *       conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright notice, this list of
- *       conditions and the following disclaimer in the documentation and/or other materials
- *       provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- *       to endorse or promote products derived from this software without specific prior written
- *       permission.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
  *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

@@ -1,25 +1,33 @@
cutlass_example_add_executable(
  01_cutlass_utilities
  cutlass_utilities.cu

@@ -1,24 +1,30 @@

@@ -1,25 +1,33 @@
cutlass_example_add_executable(
  02_dump_reg_shmem
  dump_reg_shmem.cu

@@ -1,27 +1,31 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
- *modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright notice,
- *this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- *notice, this list of conditions and the following disclaimer in the
- *documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the names of its
- *contributors may be used to endorse or promote products derived from this
- *software without specific prior written permission.
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
- *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
- *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
- *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
- *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
 **************************************************************************************************/

@@ -1,25 +1,33 @@
set(TEST_COMMAND_00 RowMajor --extent=16,16)
set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)
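Read as usage documentation, each TEST_COMMAND above is an argument list handed to the example binary: a layout name followed by --extent, --output-shape, and --vectorize options (the backslash-escaped quotes survive CMake's own quoting and reach the program as literal quotes). Assuming the executable target declared for this directory follows the same numbering convention as the other examples here, 03_visualize_layout, the two test cases correspond to invocations like:

    03_visualize_layout RowMajor --extent=16,16
    03_visualize_layout "ColumnMajorInterleaved<4>" --extent=32,8 --output-shape=16 --vectorize=4

The layout string is matched against the name-to-visualizer registry populated by RegisterLayouts in the code hunk further below.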

@@ -1,24 +1,30 @@

@@ -1,24 +1,30 @@

@@ -55,27 +61,49 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
       new VisualizeLayout<cutlass::layout::ColumnMajorInterleaved<4>>},
      {"RowMajorInterleaved<4>",
       new VisualizeLayout<cutlass::layout::RowMajorInterleaved<4>>},
      // All Ampere/Turing H/Integer matrix multiply tensor core kernels use the same swizzling
      // layout implementation with different templates.
      //
      // BMMA 88128  Interleaved-256
      // BMMA 168256 Interleaved-256
      {"TensorOpMultiplicand<1,256>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 256>>},
      // BMMA 88128  TN kblock512
      // BMMA 168256 TN kblock512
      {"TensorOpMultiplicand<1,512>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 512>>},
      // BMMA 168256 TN kblock1024
      {"TensorOpMultiplicand<1,1024>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 1024>>},
      // Integer matrix multiply.int4 8832  Interleaved-64
      // Integer matrix multiply.int4 16864 Interleaved-64
      {"TensorOpMultiplicand<4,64>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 64>>},
      // Integer matrix multiply.int4 8832  TN kblock128
      // Integer matrix multiply.int4 16864 TN kblock128
      {"TensorOpMultiplicand<4,128>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 128>>},
      // Integer matrix multiply.int4 16864 TN kblock256
      {"TensorOpMultiplicand<4,256>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 256>>},
      // Integer matrix multiply 8816  Interleaved-32
      // Integer matrix multiply 16832 Interleaved-32
      {"TensorOpMultiplicand<8,32>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 32>>},
      // Integer matrix multiply 8816  TN kblock64
      // Integer matrix multiply 16832 TN kblock64
      {"TensorOpMultiplicand<8,64>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 64>>},
      // Integer matrix multiply 16832 TN kblock128
      {"TensorOpMultiplicand<8,128>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 128>>},
      // Matrix multiply 1688  TN kblock32
      // Matrix multiply 16816 TN kblock32
      {"TensorOpMultiplicand<16,32>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 32>>},
      // Matrix multiply 1688  NT
      // Matrix multiply 16816 NT
      // Matrix multiply 16816 TN kblock64
      {"TensorOpMultiplicand<16,64>",
       new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 64>>},
      // Matrix multiply 1688.TF32 TN kblock16
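The hunk above grows the visualizer registry: each printable layout name maps to an owning pointer to a type-erased VisualizeLayout instance, so the layout named on the command line can be dispatched at run time, and the comments record which tensor-core instruction shapes each swizzled layout serves (for example, BMMA 168256 denotes the 16x8x256 binary MMA shape). A minimal self-contained sketch of the same registry pattern, using illustrative names (VisualizerBase and Visualizer are stand-ins, not CUTLASS types):

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Type-erased base so one map can hold visualizers for any layout.
struct VisualizerBase {
  virtual ~VisualizerBase() = default;
  virtual void visualize() const = 0;
};

// One concrete visualizer per layout template instantiation.
template <int ElementBits, int Crosswise>
struct Visualizer : VisualizerBase {
  void visualize() const override {
    std::cout << "TensorOpMultiplicand<" << ElementBits << ","
              << Crosswise << ">\n";
  }
};

int main() {
  // Registry: printable name -> owning pointer, mirroring RegisterLayouts.
  std::map<std::string, std::unique_ptr<VisualizerBase>> layouts;
  layouts.emplace("TensorOpMultiplicand<1,256>",
                  std::make_unique<Visualizer<1, 256>>());
  layouts.emplace("TensorOpMultiplicand<4,64>",
                  std::make_unique<Visualizer<4, 64>>());

  // Dispatch on the name a user would pass on the command line.
  auto it = layouts.find("TensorOpMultiplicand<4,64>");
  if (it != layouts.end()) {
    it->second->visualize();
  }
  return 0;
}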

@@ -1,24 +1,30 @@

@@ -1,24 +1,30 @@

@@ -1,24 +1,30 @@

@@ -1,25 +1,33 @@
cutlass_example_add_executable(
  04_tile_iterator
  tile_iterator.cu

@@ -1,24 +1,30 @@
@@ -1,25 +1,33 @@
 [CMakeLists.txt: the same 2017-2021 -> 2017-2022 BSD-3-Clause license-header update shown in the hunk above, in '#' comment style.]
cutlass_example_add_executable(
  05_batched_gemm
  batched_gemm.cu
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -28,12 +34,16 @@
 
 #include "cutlass/cutlass.h"
 #include "cutlass/layout/matrix.h"
+#include "cutlass/gemm/device/gemm_array.h"
 #include "cutlass/gemm/device/gemm_batched.h"
 
 #pragma warning( disable : 4503)
 
 /*
-This example demonstrates how to use cutlass to compute a batched strided gemm.
+This example demonstrates how to use cutlass to compute a batched strided gemm in two different ways:
+  1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
+     matrices of the batch (this is called a strided batched gemm).
+  2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
 In this example, both A and B matrix are non-transpose and column major matrix
 batched_C = batched_A x batched_B
 As an example, matrix C can be seen as
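The comment updated above contrasts the two batch-addressing modes. Here is a minimal stand-alone sketch (not part of this commit; base_A, batch_stride_A, and ptr_A mirror the example's variable names) of how each mode locates the b-th A matrix of the batch:

    #include <cstdint>
    #include <vector>

    // Strided batched gemm: one base pointer, batch element b at a fixed offset.
    float const* strided_matrix(float const* base_A, int64_t batch_stride_A, int b) {
      return base_A + b * batch_stride_A;  // assumes equally spaced matrices
    }

    // Array gemm: an explicit table of pointers, one entry per batch element.
    float const* array_matrix(std::vector<float const*> const& ptr_A, int b) {
      return ptr_A[b];  // matrices may live anywhere in device memory
    }

The strided form needs only one pointer and one stride per operand; the array form tolerates arbitrarily scattered matrices at the cost of an extra pointer table, which is exactly what the new code later in this diff builds.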
@@ -89,6 +99,45 @@ The stride (batch_stride_C) between the first element of two batches is k
 
 */
 
+cudaError_t cutlass_array_sgemm(
+  int m,
+  int n,
+  int k,
+  float alpha,
+  float const * const *A,
+  int lda,
+  float const * const *B,
+  int ldb,
+  float * const *C,
+  int ldc,
+  float beta,
+  int batch_count) {
+
+  using Gemm = cutlass::gemm::device::GemmArray<
+    float, cutlass::layout::ColumnMajor,
+    float, cutlass::layout::ColumnMajor,
+    float, cutlass::layout::ColumnMajor
+  >;
+
+  Gemm gemm_op;
+
+  cutlass::Status status = gemm_op({
+    {m, n, k},
+    A, lda,
+    B, ldb,
+    C, ldc,
+    C, ldc,
+    {alpha, beta},
+    batch_count
+  });
+
+  if (status != cutlass::Status::kSuccess) {
+    return cudaErrorUnknown;
+  }
+
+  return cudaSuccess;
+}
+
 cudaError_t cutlass_strided_batched_sgemm(
   int m,
   int n,
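One reading note on the new function above: the C, ldc pair is passed twice because the device-level GEMM takes separate source and destination operands for its epilogue, so supplying C for both computes C = alpha * A x B + beta * C in place. An annotated restatement of the call (the comments are illustrative only; the arguments themselves are positional):

    cutlass::Status status = gemm_op({
      {m, n, k},       // GEMM problem size shared by every batch element
      A, lda,          // device-resident array of A pointers + leading dimension
      B, ldb,          // device-resident array of B pointers + leading dimension
      C, ldc,          // epilogue source operand (the "C" in alpha*AB + beta*C)
      C, ldc,          // epilogue destination; same pointers, so C is updated in place
      {alpha, beta},   // linear-combination epilogue coefficients
      batch_count      // number of independent GEMMs in the batch
    });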
@@ -188,7 +237,11 @@ cudaError_t strided_batched_gemm_nn_reference(
   return result;
 }
 
-int main() {
+cudaError_t run_batched_gemm(bool use_array) {
+
+  const char* gemm_desc = use_array ? "array" : "strided batched";
+  std::cout << "Running " << gemm_desc << " gemm" << std::endl;
 
   // Arbitrary problem size
   int const m = 520;
@@ -293,11 +346,69 @@ int main() {
   }
 
   // run cutlass
-  result = cutlass_strided_batched_sgemm(
-    m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
-    beta, batch_count);
-  if (result != cudaSuccess)
-    return result;
+  if (use_array) {
+    // allocate the host memory for the pointers to the matrices of the batch
+    std::vector<float*> host_ptr_A(batch_count);
+    std::vector<float*> host_ptr_B(batch_count);
+    std::vector<float*> host_ptr_C(batch_count);
+
+    // permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
+    std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
+    for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
+      host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
+      host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
+      host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
+    }
+
+    // allocate the corresponding device memory
+    float const **ptr_A;
+    float const **ptr_B;
+    float **ptr_C;
+
+    result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMalloc result = " << result << std::endl;
+      return result;
+    }
+
+    // copy the matrix pointers to the device
+    result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+    result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
+    if (result != cudaSuccess) {
+      std::cerr << "cudaMemcpy result = " << result << std::endl;
+      return result;
+    }
+
+    result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
+
+    if (result != cudaSuccess)
+      return result;
+  } else {
+    result = cutlass_strided_batched_sgemm(
+      m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
+      beta, batch_count);
+    if (result != cudaSuccess)
+      return result;
+  }
 
   // copy device memory to host
   result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
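A subtlety in the hunk above: ptr_A is a pointer table that itself lives in device memory, so two levels of device allocation are involved, the table (cudaMalloc'd here) and the matrices its entries reference, which must already be device addresses when the table is copied over. A self-contained sketch of the same pattern, with hypothetical names:

    #include <cuda_runtime.h>
    #include <vector>

    // Build a device-resident table of (already device-resident) matrix pointers.
    cudaError_t make_device_ptr_table(std::vector<float*> const& host_ptrs,
                                      float*** d_table_out) {
      float** d_table = nullptr;
      size_t bytes = host_ptrs.size() * sizeof(float*);

      // Allocate device memory for the pointer table itself.
      cudaError_t err = cudaMalloc(&d_table, bytes);
      if (err != cudaSuccess) return err;

      // Copy the pointer values; each entry must be a device address.
      err = cudaMemcpy(d_table, host_ptrs.data(), bytes, cudaMemcpyHostToDevice);
      if (err != cudaSuccess) { cudaFree(d_table); return err; }

      *d_table_out = d_table;
      return cudaSuccess;
    }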
@@ -314,7 +425,7 @@ int main() {
 
   // Expect bit-level accuracy for this simple example
   if (ref_C != result_C) {
-    std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
+    std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
     return cudaErrorUnknown;
   }
@@ -335,9 +446,19 @@ int main() {
     return result;
   }
 
-  if (result == cudaSuccess) {
-    std::cout << "Passed." << std::endl;
-  }
+  return result;
+}
+
+int main() {
+
+  cudaError_t result = cudaSuccess;
+  for (bool use_array : {false, true}) {
+    result = run_batched_gemm(use_array);
+    if (result == cudaSuccess) {
+      std::cout << "Passed." << std::endl;
+    } else {
+      break;
+    }
+  }
 
   // Exit.
@@ -1,25 +1,33 @@
 [Same license-header update as above, in '#' comment style.]
cutlass_example_add_executable(
  06_splitK_gemm
  splitk_gemm.cu
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -1,25 +1,33 @@
 [Same license-header update as above, in '#' comment style.]
cutlass_example_add_executable(
  07_volta_tensorop_gemm
  volta_tensorop_gemm.cu
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -1,25 +1,33 @@
 [Same license-header update as above, in '#' comment style.]
cutlass_example_add_executable(
  08_turing_tensorop_gemm
  turing_tensorop_gemm.cu
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -1,26 +1,34 @@
 [Same license-header update as above, in '#' comment style.]
cutlass_example_add_executable(
  09_turing_tensorop_conv2dfprop
  turing_tensorop_conv2dfprop.cu
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -333,21 +339,21 @@ struct Options {
       << "  forward convolution on tensors of layout NHWC.\n\n"
       << "Options:\n\n"
       << "  --help               If specified, displays this usage statement.\n\n"
-      << "  --n <int>            Input tensor extent N\n"
-      << "  --h <int>            Input tensor extent H\n"
-      << "  --w <int>            Input tensor extent W\n"
-      << "  --c <int>            Input tensor extent C\n"
-      << "  --k <int>            Filter extent K\n"
-      << "  --r <int>            Filter extent R\n"
-      << "  --s <int>            Filter extent S\n\n"
-      << "  --alpha <float>      Epilogue scalar alpha\n"
-      << "  --beta <float>       Epilogue scalar beta\n\n"
+      << "  --n=<int>            Input tensor extent N\n"
+      << "  --h=<int>            Input tensor extent H\n"
+      << "  --w=<int>            Input tensor extent W\n"
+      << "  --c=<int>            Input tensor extent C\n"
+      << "  --k=<int>            Filter extent K\n"
+      << "  --r=<int>            Filter extent R\n"
+      << "  --s=<int>            Filter extent S\n\n"
+      << "  --alpha=<float>      Epilogue scalar alpha\n"
+      << "  --beta=<float>       Epilogue scalar beta\n\n"
       << "  --ref-check          If set (true), reference check on the host is computed\n"
       << "  --perf-check         If set (true), performance is measured.\n"
       << "  --benchmark          If set (true), performance benchmarking on several layers and batch-size.\n"
-      << "  --iterations <int>   Number of profiling iterations to perform.\n"
+      << "  --iterations=<int>   Number of profiling iterations to perform.\n"
       << "  --save-workspace     If set, workspace is written to a text file.\n"
-      << "  --tag <string>       String to replicate across the first column in the results table\n";
+      << "  --tag=<string>       String to replicate across the first column in the results table\n";
 
   out << "\n\nExamples:\n\n"
     << "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
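The hunk above only respells the usage text as --flag=<value>; no parsing logic changes here. For the shape of handling that spelling, a minimal stand-alone sketch (a hypothetical helper, independent of whatever command-line utility the example actually uses):

    #include <string>
    #include <utility>

    // Split "--key=value" into {"key", "value"}; a bare "--flag" yields {"flag", ""}.
    std::pair<std::string, std::string> parse_flag(std::string const& arg) {
      std::string body = arg.substr(2);                   // assumes a leading "--"
      std::size_t eq = body.find('=');
      if (eq == std::string::npos) {
        return {body, ""};                                // boolean flag, e.g. --ref-check
      }
      return {body.substr(0, eq), body.substr(eq + 1)};   // e.g. --n=32 -> {"n", "32"}
    }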
@@ -1,26 +1,34 @@
 [Same license-header update as above, in '#' comment style.]
# Planar Complex GEMM example
cutlass_example_add_executable(
  10_planar_complex
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -167,15 +173,15 @@ struct Options {
       << "  This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n"
       << "Options:\n\n"
       << "  --help               If specified, displays this usage statement.\n\n"
-      << "  --m <int>            GEMM M dimension\n"
-      << "  --n <int>            GEMM N dimension\n"
-      << "  --k <int>            GEMM K dimension\n"
-      << "  --batch <int>        Number of GEMM operations executed in one batch\n"
-      << "  --alpha <f32>        Epilogue scalar alpha (real part)\n"
-      << "  --alpha_i <f32>      Epilogue scalar alpha (imaginary part)\n"
-      << "  --beta <f32>         Epilogue scalar beta (real part)\n\n"
-      << "  --beta_i <f32>       Epilogue scalar beta (imaginary part)\n\n"
-      << "  --iterations <int>   Number of profiling iterations to perform.\n\n";
+      << "  --m=<int>            GEMM M dimension\n"
+      << "  --n=<int>            GEMM N dimension\n"
+      << "  --k=<int>            GEMM K dimension\n"
+      << "  --batch=<int>        Number of GEMM operations executed in one batch\n"
+      << "  --alpha=<f32>        Epilogue scalar alpha (real part)\n"
+      << "  --alpha_i=<f32>      Epilogue scalar alpha (imaginary part)\n"
+      << "  --beta=<f32>         Epilogue scalar beta (real part)\n\n"
+      << "  --beta_i=<f32>       Epilogue scalar beta (imaginary part)\n\n"
+      << "  --iterations=<int>   Number of profiling iterations to perform.\n\n";
 
   out << "\n\nExamples:\n\n"
     << "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n"
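As in the previous example, only the help text changes here, but the term "planar complex" deserves a gloss: the real and imaginary parts of each matrix are stored as two separate real-valued planes, addressed by a base pointer plus an imaginary-plane offset. A small illustrative sketch of that addressing (the struct and its names are hypothetical, not this example's API):

    #include <cstdint>

    // Element access for a column-major planar-complex matrix: the real plane
    // starts at 'base'; the imaginary plane sits 'imag_stride' elements later.
    struct PlanarComplexRef {
      float* base;          // start of the real-part plane
      int64_t imag_stride;  // offset from the real plane to the imaginary plane
      int64_t ld;           // leading dimension (column-major)

      float& real(int i, int j) { return base[i + j * ld]; }
      float& imag(int i, int j) { return base[imag_stride + i + j * ld]; }
    };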
@@ -1,26 +1,34 @@
 [Same license-header update as above, in '#' comment style.]
# Planar Complex Array GEMM example
cutlass_example_add_executable(
  11_planar_complex_array
@@ -1,24 +1,30 @@
 [Same license-header update as shown in the first hunk above.]
@@ -165,15 +171,15 @@ struct Options {
<< " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m <int> GEMM M dimension\n"
<< " --n <int> GEMM N dimension\n"
<< " --k <int> GEMM K dimension\n"
<< " --batch <int> Number of GEMM operations executed in one batch\n"
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
<< " --iterations <int> Number of profiling iterations to perform.\n";
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --batch=<int> Number of GEMM operations executed in one batch\n"
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
<< " --alpha_i=<f32> Epilogue scalar alpha (imaginary part)\n"
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
<< " --beta_i=<f32> Epilogue scalar beta (imaginary part)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n";

out << "\n\nExamples:\n\n"
<< "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n";
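The hunk above switches the example's usage text from space-separated flags (`--m <int>`) to `--key=value` flags (`--m=<int>`). A minimal sketch of the parsing this style implies is below; `parse_int_flag` is a hypothetical helper for illustration, not the parser the CUTLASS examples actually use (they share a common command-line utility).

```cpp
#include <cstdlib>
#include <cstring>
#include <string>

// Hypothetical helper: if arg has the form "--<name>=<value>", parse the
// value as an int and return true; otherwise leave `out` untouched.
bool parse_int_flag(char const* arg, char const* name, int& out) {
  std::string prefix = std::string("--") + name + "=";
  if (std::strncmp(arg, prefix.c_str(), prefix.size()) != 0) {
    return false;
  }
  out = std::atoi(arg + prefix.size());
  return true;
}

// Usage sketch: parse_int_flag(argv[i], "m", m) handles "--m=256", etc.
```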
@@ -1,25 +1,33 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
12_gemm_bias_relu
gemm_bias_relu.cu
@@ -1,24 +1,30 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
@@ -50,8 +56,10 @@ using ElementOutput = float; // <- data type of elements

// The code section below describes the matrix layout of input and output matrices.
// Column Major for Matrix A, B and C.
//
// Note this example only works for ColumnMajor output because

// Note that if the output is column major, the bias has to be per row, i.e. every row has a different bias.
// If the output is row major, the bias has to be per column, i.e. every column has a different bias.
// Some other notes:
// 1) we only have a row major epilogue.
// 2) we swap A and B if the output is column major, so we can still use the
// row major epilogue.
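The per-row/per-column bias described above is typically fed to the epilogue as a C operand whose leading stride is zero, so every row of the output reads the same bias vector. A minimal sketch of that zero-stride broadcast, assuming the `ElementOutput` from this hunk plus a hypothetical row-major layout and `tensor_bias` host tensor:

```cpp
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

// Sketch: broadcast a {1 x N} bias across all M output rows by constructing
// a TensorRef whose stride is 0, so every row index maps back to the bias row.
// The same pattern appears later in this diff as
//   {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> bias_ref(
    tensor_bias.device_data(),                 // pointer to the single bias row
    cutlass::layout::RowMajor::Stride(0));     // stride 0 broadcasts that row
// Passed as the C operand, the epilogue then computes
//   D(m, n) = alpha * accumulator(m, n) + beta * bias(n)
```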
@@ -294,4 +302,3 @@ int main() {

return run();
}

@@ -1,45 +1,82 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
13_fused_two_gemms
fused_gemm.cu
)

cutlass_example_add_executable(
13_fused_two_convs
fused_conv2d.cu
)

target_include_directories(
13_fused_two_gemms
PRIVATE
include_directories(
.
)

add_custom_target(13_fused_two_gemms)

add_custom_target(13_fused_two_convs)

add_custom_target(13_two_tensor_op_fusion
DEPENDS 13_fused_two_gemms
13_fused_two_convs
)

foreach(FUSION_CONV_EXAMPLE
fused_two_convs_f16_sm75_rf
fused_two_convs_f16_sm75_shmem
fused_two_convs_f16_sm80_rf
fused_two_convs_f16_sm80_shmem
fused_two_convs_s8_sm75_rf
fused_two_convs_s8_sm75_shmem
fused_two_convs_s8_sm80_rf
fused_two_convs_s8_sm80_shmem
)

cutlass_example_add_executable(
13_${FUSION_CONV_EXAMPLE}
${FUSION_CONV_EXAMPLE}.cu
)

target_include_directories(
13_fused_two_convs
PRIVATE
.
add_dependencies(13_fused_two_convs 13_${FUSION_CONV_EXAMPLE})

endforeach()

foreach(FUSION_GEMM_EXAMPLE
fused_two_gemms_f16_sm75_rf
fused_two_gemms_f16_sm75_shmem
fused_two_gemms_f16_sm80_rf
fused_two_gemms_f16_sm80_shmem
fused_two_gemms_s8_sm75_rf
fused_two_gemms_s8_sm75_shmem
fused_two_gemms_s8_sm80_rf
fused_two_gemms_s8_sm80_shmem
)
cutlass_example_add_executable(
13_${FUSION_GEMM_EXAMPLE}
${FUSION_GEMM_EXAMPLE}.cu
)

add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})

endforeach()

@@ -48,33 +48,71 @@ addition to its own input activation tile. Therefore the input activation warp t
2nd GEMM/Conv only depends on the output warp accumulator of the 1st GEMM/Conv in the
register file, and the operation can be fully register-file-resident.

On the other hand, this constraint can be relaxed if the output accumulator of the 1st GEMM/CONV
is staged in shared memory and then used as the input of the 2nd GEMM/CONV. In this case, the
input of each warp tile can be loaded from shared memory, so it does not need to be RF-resident,
and each warp does not need to store the entire input matrix of the 2nd GEMM in its RF. This is
illustrated in the diagram below.

<p align="center"><img src=/media/images/13_example_shmem_resident_fusion.png></p>

When applying the above constraint to convolutions, the 2nd Convolution kernel must not have
halos, so that the data used by each threadblock does not depend on any other threadblock.
Typically this requires the 2nd Convolution to use a 1x1 filter without any padding.
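In code, the fusion constraints surface as a shape check before launch. A sketch of the conditions, mirroring the requirement list printed by `can_implement` later in this diff (`can_fuse` itself is an illustrative name, not CUTLASS API):

```cpp
#include "cutlass/gemm_coord.h"

// Sketch of the back-to-back fusion shape requirements: the two GEMMs must
// chain (N0 == K1), and each threadblock tile must cover the full N extent of
// its GEMM so the intermediate result stays resident within one threadblock.
template <typename ThreadblockShape0, typename ThreadblockShape1>
bool can_fuse(cutlass::gemm::GemmCoord const& size0,
              cutlass::gemm::GemmCoord const& size1) {
  return size0.m() == size1.m() &&              // same M for both GEMMs
         size0.n() == size1.k() &&              // GEMM0 output feeds GEMM1
         ThreadblockShape0::kN == size0.n() &&  // whole N0 in one tile
         ThreadblockShape1::kN == size1.n();    // whole N1 in one tile
}
```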
# Build and run

- Run cmake at top-level CUTLASS
- `make 13_two_tensor_op_fusion`
- Run individual benchmarks
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`

# Copyright

Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without modification, are permitted
provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of
conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
to endorse or promote products derived from this software without specific prior written
permission.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

@@ -1,30 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implicit GEMM testbed
*/

#pragma once

#include <iostream>
@@ -50,6 +54,7 @@
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -149,6 +154,7 @@ public:
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
std::cerr << "Not implemented\n";
}
}

@@ -275,7 +281,7 @@ public:
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
std::cout << "total time " << totalTime / (float)runs << " ms\n";
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";

tensor_D0_computed.sync_host();
tensor_D1_computed.sync_host();
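The "Fusion"/"Non-fusion" labels above come from the same event-based timing pattern used throughout these testbeds: bracket `runs` launches with CUDA events and report the mean per run. A self-contained sketch of that pattern:

```cpp
#include <cuda_runtime.h>

// Sketch of the event timing used by these testbeds: bracket `runs`
// launches with events and return the mean per-run time in milliseconds.
template <typename Op>
float time_op(Op&& op, int runs) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < runs; ++i) {
    op();                        // launch the operation under test
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);    // wait for the last launch to finish
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms / float(runs);
}
```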
@@ -403,6 +409,7 @@ public:
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;

cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
@@ -483,6 +490,7 @@ public:
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.K});
tensor_Bias0.resize({1, problem_size_0.K});
tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
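In these testbeds `alpha0 == 0` acts as a sentinel meaning "use a per-channel scale vector instead of the scalar alpha", which is why `tensor_Scale0` is only resized in that case. A minimal host-side sketch of the scale+bias math the reference kernels implement (the actual CUDA reference lives in `reference/device/tensor_scale_bias.h`; this stand-alone loop is only illustrative):

```cpp
// D = s_k * Z + bias_k per output channel k, where s_k comes from a
// per-channel vector when per_channel is set (alpha0 == 0), and from the
// scalar alpha0 otherwise. NHWC-style packed layout is assumed.
void scale_bias_reference(float const* Z, float* D,
                          float const* scale, float const* bias,
                          float alpha0, bool per_channel,
                          int num_pixels, int K) {
  for (int p = 0; p < num_pixels; ++p) {
    for (int k = 0; k < K; ++k) {
      float s = per_channel ? scale[k] : alpha0;
      D[p * K + k] = s * Z[p * K + k] + bias[k];
    }
  }
}
```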
@@ -592,7 +600,7 @@ public:
cudaDeviceSynchronize();
float conv2dTime;
cudaEventElapsedTime(&conv2dTime, start, stop);
std::cout << "time " << conv2dTime / (float)runs << " ms\n";
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";

tensor_D1_computed.sync_host();

@@ -603,22 +611,35 @@ public:
typename B2bConv2d::LayoutA,
typename B2bConv2d::ElementB,
typename B2bConv2d::LayoutB,
typename B2bConv2d::ElementC,
ElementAccumulator,
typename B2bConv2d::LayoutC,
ElementCompute,
ElementAccumulator,
ElementAccumulator
>(
kConvolutionalOperator,
problem_size_0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_C0.device_ref(),
tensor_Z0_reference.device_ref(),
tensor_Z0_reference.device_ref(),
ElementAccumulator(1), // intermediate alpha = 1
ElementAccumulator(0) // beta = 0
);

cutlass::reference::device::TensorScaleBiasConv2d<
ElementAccumulator,
typename B2bConv2d::ElementC,
typename B2bConv2d::LayoutC,
ElementCompute,
typename B2bConv2d::LayoutScaleBias
>(
problem_size_0,
tensor_Z0_reference.device_ref(),
tensor_D0_reference.device_ref(),
alpha0,
beta0,
nullptr, // stream
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref());
tensor_Bias0.device_ref()
);

if(relu) {
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
@@ -1,24 +1,30 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
@@ -38,6 +44,7 @@
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -62,6 +69,7 @@ struct B2bNonFusedGemmRun
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;

//
@@ -72,9 +80,10 @@ struct B2bNonFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

/// Helper to initialize a tensor view
template <typename Element, typename Layout>
@@ -91,7 +100,7 @@ struct B2bNonFusedGemmRun
else if (dist_kind == cutlass::Distribution::Identity) {

cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {

cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@@ -100,9 +109,14 @@ struct B2bNonFusedGemmRun

cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
@@ -141,6 +155,10 @@ struct B2bNonFusedGemmRun
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());

cutlass::HostTensor<
ElementCompute,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});

cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
@@ -157,6 +175,10 @@ struct B2bNonFusedGemmRun
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());

cutlass::HostTensor<
ElementCompute,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});

cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
@@ -169,8 +191,10 @@ struct B2bNonFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

cutlass::reference::host::TensorFill(
tensor_D0.host_view());
@@ -184,9 +208,11 @@ struct B2bNonFusedGemmRun
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_C0.sync_device();
tensor_Bias0.sync_device();
tensor_D0.sync_device();
tensor_B1.sync_device();
tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
@@ -199,7 +225,7 @@
problem_size_0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_C0.device_ref(),
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
tensor_D0.device_ref(),
{alpha0, beta0}
};
@@ -208,7 +234,7 @@
problem_size_1,
tensor_D0.device_ref(),
tensor_B1.device_ref(),
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
tensor_D1.device_ref(),
{alpha1, beta1}
};
@@ -235,7 +261,6 @@
//
// Run the GEMM
//

cudaEvent_t start, stop1, stop2;
cudaEventCreate(&start);
cudaEventCreate(&stop1);
@@ -250,7 +275,6 @@
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {

status = gemm_op_1();

CUTLASS_CHECK(status);
@@ -264,7 +288,7 @@
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
std::cout << "total time " << totalTime / (float)runs << " ms\n";
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";

tensor_D0.sync_host();
tensor_D1.sync_host();
@@ -292,7 +316,7 @@
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);

@@ -306,7 +330,7 @@
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);

@@ -319,7 +343,6 @@
reference_D0.sync_host();
reference_D1.sync_host();

CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
@@ -343,13 +366,14 @@ struct B2bNonFusedGemmRun
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}

return passed;
}
};
@@ -366,6 +390,8 @@ struct B2bFusedGemmRun
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;

//
@@ -376,9 +402,12 @@ struct B2bFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

/// Helper to initialize a tensor view
template <typename Element, typename Layout>
@@ -404,9 +433,14 @@ struct B2bFusedGemmRun

cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
@@ -445,6 +479,21 @@ struct B2bFusedGemmRun
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());

cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;

if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});

cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});

cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());

cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
@@ -457,6 +506,10 @@ struct B2bFusedGemmRun
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());

cutlass::HostTensor<
ElementCompute,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});

cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
@@ -469,21 +522,29 @@ struct B2bFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
if(alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));

cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());

tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_C0.sync_device();
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.sync_device();
tensor_Bias0.sync_device();
tensor_B1.sync_device();
tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
@@ -498,8 +559,10 @@ struct B2bFusedGemmRun
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_C0.device_ref(),
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref(),
tensor_B1.device_ref(),
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
{alpha0, beta0},
{alpha1, beta1},
@@ -507,7 +570,18 @@ struct B2bFusedGemmRun

B2bGemm b2b_gemm_op;

cutlass::Status status = b2b_gemm_op.initialize(arguments);
cutlass::Status status = b2b_gemm_op.can_implement(arguments);

if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirements:\n"
<< " problem_size_0.M = problem_size_1.M\n"
<< " problem_size_0.N = problem_size_1.K\n"
<< " ThreadblockShape0::kN = problem_size_0.N\n"
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
}

status = b2b_gemm_op.initialize(arguments);

CUTLASS_CHECK(status);
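The new code above inserts a `can_implement()` check so unsupported problem sizes produce a diagnostic instead of a silent failure. The status-checked call sequence it follows is sketched below, using the same names as the hunk:

```cpp
// Sketch of the check / initialize / launch sequence used above.
B2bGemm b2b_gemm_op;

cutlass::Status status = b2b_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
  // Problem sizes violate the fusion requirements printed above.
  return false;
}

status = b2b_gemm_op.initialize(arguments);
if (status != cutlass::Status::kSuccess) {
  return false;
}

status = b2b_gemm_op();  // launch the fused kernel
return status == cutlass::Status::kSuccess;
```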
@@ -536,28 +610,49 @@ struct B2bFusedGemmRun
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "time " << gemmTime / (float)runs << " ms\n";
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";

tensor_D1.sync_host();

//
// Verify
//

cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;

cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_0, reference_gemm_1;
reference_gemm_1;

reference_gemm_0(
problem_size_0,
alpha0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
reference_D0.device_ref()
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
ElementAccumulator(0)
);

cutlass::reference::device::TensorScaleBiasGemm<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
problem_size_0,
reference_Z0.device_ref(),
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
);

if(relu) {
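The rewritten verification above is two-stage, matching what the fused epilogue actually computes: the reference GEMM now produces the raw accumulator `Z0` (alpha = 1, beta = 0) into `reference_Z0`, and `TensorScaleBiasGemm` then forms `D0` from it. A stand-alone sketch of that first stage, with row-major indexing and `float` types as illustrative assumptions:

```cpp
// Sketch of the two-stage reference above: first the raw accumulator
// Z0 = A0 * B0, then D0[m][n] = scale[n] * Z0[m][n] + bias[n], with an
// optional ReLU matching the testbed's `relu` flag.
void reference_first_stage(int M, int N, int K,
                           float const* A0, float const* B0,
                           float const* scale, float const* bias,
                           float* D0, bool relu) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float z = 0.0f;
      for (int k = 0; k < K; ++k) {
        z += A0[m * K + k] * B0[k * N + n];   // Z0 = A0 * B0 (alpha=1, beta=0)
      }
      float d = scale[n] * z + bias[n];       // TensorScaleBiasGemm step
      D0[m * N + n] = relu ? (d > 0.0f ? d : 0.0f) : d;
    }
  }
}
```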
@@ -570,18 +665,15 @@ struct B2bFusedGemmRun
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
);

if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}

cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();

CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
@@ -592,7 +684,8 @@ struct B2bFusedGemmRun
tensor_D1.host_view());

CHECK_TRUE(passed);
if (!passed) {
if (!passed)
{

std::stringstream fname;

@@ -605,12 +698,14 @@ struct B2bFusedGemmRun
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}

return passed;
}

@@ -1,30 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implicit GEMM testbed
*/

#pragma once

#include <iostream>
@@ -51,6 +55,7 @@
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -87,14 +92,14 @@ public:
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_Bias0;
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;

cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d0::LayoutC> tensor_Bias1;
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;

@@ -286,7 +291,7 @@ public:
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
std::cout << "total time " << totalTime / (float)runs << " ms\n";
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";

tensor_D0_computed.sync_host();
tensor_D1_computed.sync_host();
@@ -375,11 +380,13 @@ public:
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
<< "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
<< "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();

@@ -417,12 +424,13 @@ public:
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;

cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_Bias1;
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;

@@ -499,6 +507,7 @@ public:
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.K});
tensor_Bias0.resize({1, problem_size_0.K});
tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
@ -617,7 +626,7 @@ public:
|
||||
cudaDeviceSynchronize();
|
||||
float conv2dTime;
|
||||
cudaEventElapsedTime(&conv2dTime, start, stop);
|
||||
std::cout << "time " << conv2dTime / (float)runs << " ms\n";
|
||||
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
@@ -628,23 +637,36 @@ public:
typename B2bConv2d::LayoutA,
typename B2bConv2d::ElementB,
typename B2bConv2d::LayoutB,
typename B2bConv2d::ElementC,
typename B2bConv2d::LayoutC,
ElementCompute,
ElementAccumulator,
cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
typename B2bConv2d::LayoutC,
ElementAccumulator,
ElementAccumulator
>(
kConvolutionalOperator,
problem_size_0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_C0.device_ref(),
tensor_Z0_reference.device_ref(),
tensor_Z0_reference.device_ref(),
ElementAccumulator(1), // intermediate alpha = 1
ElementAccumulator(0) // beta = 0
);

cutlass::reference::device::TensorScaleBiasConv2d<
ElementAccumulator,
typename B2bConv2d::ElementC,
typename B2bConv2d::LayoutC,
ElementCompute,
typename B2bConv2d::LayoutScaleBias,
cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
>(
problem_size_0,
tensor_Z0_reference.device_ref(),
tensor_D0_reference.device_ref(),
alpha0,
beta0,
nullptr, // stream
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref());
tensor_Bias0.device_ref()
);

if(relu) {
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
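A note on the verification math above: the fused kernel's first stage is checked in two steps. A plain reference implicit GEMM accumulates into tensor_Z0_reference with alpha = 1 and beta = 0, and TensorScaleBiasConv2d then applies the per-channel scale and bias (clamped to the output type), with an optional ReLU at the end. A minimal host-side restatement of that second step, using hypothetical names (this helper is not part of the testbed):

#include <algorithm>
#include <vector>

// Hypothetical host-side sketch of the scale/bias/ReLU reference step above;
// z0 holds the conv accumulators in NPQK order, scale/bias hold one value
// per output channel K.
std::vector<float> scale_bias_relu(std::vector<float> const &z0,
                                   std::vector<float> const &scale,
                                   std::vector<float> const &bias,
                                   int K, bool relu) {
  std::vector<float> d0(z0.size());
  for (std::size_t i = 0; i < z0.size(); ++i) {
    int k = int(i % K);                      // channel index (K is innermost)
    float v = scale[k] * z0[i] + bias[k];    // per-channel scale + bias
    d0[i] = relu ? std::max(v, 0.0f) : v;    // optional ReLU
  }
  return d0;
}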
@@ -712,6 +734,7 @@ public:
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();


@@ -1,28 +1,33 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#pragma once

#include <iostream>
@@ -40,6 +45,7 @@
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -62,6 +68,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;

//
@@ -72,9 +79,10 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

/// Helper to initialize a tensor view
template <typename Element, typename Layout>
@@ -91,14 +99,23 @@ struct B2bInterleavedNonFusedGemmRun
else if (dist_kind == cutlass::Distribution::Identity) {

cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {

cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {

cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
@@ -141,6 +158,10 @@ struct B2bInterleavedNonFusedGemmRun
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());

cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});

cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
@@ -161,6 +182,10 @@ struct B2bInterleavedNonFusedGemmRun
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());

cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});

cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
@@ -173,8 +198,10 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

//Reorder B0 and B1
cutlass::reorder_column<InterleavedK_>(
@@ -195,10 +222,12 @@ struct B2bInterleavedNonFusedGemmRun
tensor_B0.sync_device();
tensor_B0_reordered.sync_device();
tensor_C0.sync_device();
tensor_Bias0.sync_device();
tensor_D0.sync_device();
tensor_B1.sync_device();
tensor_B1_reordered.sync_device();
tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
@@ -211,7 +240,7 @@ struct B2bInterleavedNonFusedGemmRun
problem_size_0,
tensor_A0.device_ref(),
tensor_B0_reordered.device_ref(),
tensor_C0.device_ref(),
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
tensor_D0.device_ref(),
{alpha0, beta0}
};
@@ -220,7 +249,7 @@ struct B2bInterleavedNonFusedGemmRun
problem_size_1,
tensor_D0.device_ref(),
tensor_B1_reordered.device_ref(),
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
tensor_D1.device_ref(),
{alpha1, beta1}
};
@@ -259,8 +288,7 @@ struct B2bInterleavedNonFusedGemmRun

CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);

cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();

@@ -275,7 +303,7 @@ struct B2bInterleavedNonFusedGemmRun
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
std::cout << "total time " << totalTime / (float)runs << " ms\n";
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";

tensor_D0.sync_host();
tensor_D1.sync_host();
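The timing code above follows the standard cudaEvent idiom: record an event before and after each kernel loop, synchronize, and divide the elapsed time by the run count. A self-contained sketch of that pattern with a stand-in kernel (the kernel name is illustrative, not from this diff):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}  // stand-in for the gemm_op_0/gemm_op_1 launches

int main() {
  int const runs = 100;
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  for (int i = 0; i < runs; ++i) {
    noop_kernel<<<1, 1>>>();
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);    // make sure the stop event has completed

  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  std::printf("avg time %f ms\n", ms / runs);
  return 0;
}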
@@ -303,7 +331,7 @@ struct B2bInterleavedNonFusedGemmRun
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);

@@ -316,8 +344,8 @@ struct B2bInterleavedNonFusedGemmRun
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);

@@ -325,6 +353,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}

// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@@ -353,14 +382,15 @@ struct B2bInterleavedNonFusedGemmRun
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}

return passed;
}
};
@@ -377,6 +407,8 @@ struct B2bInterleavedFusedGemmRun
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;

//
@@ -387,9 +419,12 @@ struct B2bInterleavedFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

/// Helper to initialize a tensor view
template <typename Element, typename Layout>
@@ -407,13 +442,22 @@ struct B2bInterleavedFusedGemmRun

cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {

cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {

cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
@@ -431,7 +475,7 @@ struct B2bInterleavedFusedGemmRun
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
@@ -456,6 +500,21 @@ struct B2bInterleavedFusedGemmRun
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());

cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;

if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});

cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});

cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());

cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
@@ -472,6 +531,10 @@ struct B2bInterleavedFusedGemmRun
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());

cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});

cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
@@ -484,8 +547,12 @@ struct B2bInterleavedFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
if(alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));

//Reorder B0
cutlass::reorder_column<16>(
@@ -504,9 +571,13 @@ struct B2bInterleavedFusedGemmRun
tensor_B0.sync_device();
tensor_B0_reordered.sync_device();
tensor_C0.sync_device();
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.sync_device();
tensor_Bias0.sync_device();
tensor_B1.sync_device();
tensor_B1_reordered.sync_device();
tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
@@ -521,17 +592,29 @@ struct B2bInterleavedFusedGemmRun
tensor_A0.device_ref(),
tensor_B0_reordered.device_ref(),
tensor_C0.device_ref(),
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref(),
tensor_B1_reordered.device_ref(),
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
{alpha0, beta0},
{alpha1, beta1},
1, /*threadblock_swizzle_k_tile*/
};
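In the fused arguments above, tensor_Scale0 and tensor_Bias0 are passed directly to the kernel instead of going through beta: when alpha0 == ElementCompute(0), the first-stage epilogue applies a per-column scale vector rather than a scalar alpha, then adds the per-column bias and applies ReLU. A hedged scalar restatement of that epilogue (illustrative only, not the CUTLASS API):

#include <algorithm>

// Illustrative per-element form of the first-stage epilogue assumed above:
// scalar alpha0 scaling, or a per-column scale when alpha0 == 0, plus a
// per-column bias, followed by ReLU.
float epilogue0(float acc, float alpha0, float scale_n, float bias_n) {
  float scaled = (alpha0 == 0.0f) ? scale_n * acc   // per-channel scale path
                                  : alpha0 * acc;   // scalar alpha path
  return std::max(scaled + bias_n, 0.0f);           // bias + ReLU
}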

B2bGemm b2b_gemm_op;

cutlass::Status status = b2b_gemm_op.initialize(arguments);
cutlass::Status status = b2b_gemm_op.can_implement(arguments);

if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirements:\n"
<< " problem_size_0.M = problem_size_1.M\n"
<< " problem_size_0.N = problem_size_1.K\n"
<< " ThreadblockShape0::kN = problem_size_0.N\n"
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
}
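The requirements printed above pin down the fused kernel's shapes: both GEMMs share M, GEMM0's N becomes GEMM1's K, and each threadblock tile must span the full N of its stage. One hypothetical problem-size pair that satisfies all four constraints (assuming ThreadblockShape0::kN == 64 and ThreadblockShape1::kN == 128, and the CUTLASS headers already included in this file):

// Hypothetical shapes meeting the four printed requirements:
cutlass::gemm::GemmCoord problem_size_0(1024, 64, 256);  // M=1024, N=64,  K=256
cutlass::gemm::GemmCoord problem_size_1(1024, 128, 64);  // M=1024, N=128, K=64 (= N of GEMM0)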

status = b2b_gemm_op.initialize(arguments);

CUTLASS_CHECK(status);

@@ -560,28 +643,49 @@ struct B2bInterleavedFusedGemmRun
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "time " << gemmTime / (float)runs << " ms\n";
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";

tensor_D1.sync_host();

//
// Verify
//

cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;

cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_0, reference_gemm_1;
reference_gemm_1;

reference_gemm_0(
problem_size_0,
alpha0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
reference_D0.device_ref()
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
ElementAccumulator(0)
);

cutlass::reference::device::TensorScaleBiasGemm<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
problem_size_0,
reference_Z0.device_ref(),
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
);

if(relu) {
@@ -594,18 +698,15 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
);


if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}

cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
reference_D0.sync_host();
reference_D1.sync_host();

CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
@@ -616,7 +717,8 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.host_view());

CHECK_TRUE(passed);
if (!passed) {
if (!passed)
{

std::stringstream fname;

@@ -630,13 +732,15 @@ struct B2bInterleavedFusedGemmRun
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}

return passed;
}


@@ -1,24 +1,30 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
@@ -40,6 +46,7 @@

#include "kernel/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
#include "kernel/default_b2b_gemm_smem_accumulator.h"

////////////////////////////////////////////////////////////////////////////////

@@ -102,6 +109,8 @@ template <
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Stage accumulator in shared memory
bool SmemAccumulator = false,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
@@ -149,6 +158,10 @@ class B2bGemm {
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;

/// Derived types
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
using LayoutScaleBias = layout::RowMajor;
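ElementScaleBias is taken from the first epilogue's compute type, and LayoutScaleBias is fixed to row-major: the scale and bias operands are 1 x N vectors with one entry per output column, matching testbed allocations such as tensor_Bias0({1, problem_size_0.n()}). An illustrative allocation with hypothetical extents (assuming the CUTLASS util headers already included in the testbeds):

// Hypothetical 1 x N scale/bias vectors for a first stage with N = 64:
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> scale0({1, 64});
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> bias0({1, 64});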

/// Define the kernel
using B2bGemmKernel = typename kernel::DefaultB2bGemm<
ElementA,
@@ -172,7 +185,8 @@ class B2bGemm {
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator
Operator,
SmemAccumulator
>::B2bGemmKernel;

/// Argument structure
@@ -187,6 +201,8 @@ class B2bGemm {
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
@@ -212,6 +228,8 @@ class B2bGemm {
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
@@ -226,6 +244,8 @@ class B2bGemm {
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
@@ -338,6 +358,8 @@ public:
args.ref_A0.non_const_ref(),
args.ref_B0.non_const_ref(),
args.ref_C0.non_const_ref(),
args.ref_Scale0.non_const_ref(),
args.ref_Bias0.non_const_ref(),
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
@@ -358,12 +380,14 @@ public:
}
}

params_.ref_A0.reset(args.ref_A.non_const_ref().data());
params_.ref_B0.reset(args.ref_B.non_const_ref().data());
params_.ref_C0.reset(args.ref_C.non_const_ref().data());
params_.ref_B1.reset(args.ref_B.non_const_ref().data());
params_.ref_C1.reset(args.ref_C.non_const_ref().data());
params_.ref_D1.reset(args.ref_D.data());
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data());
params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data());
params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
params_.ref_D1.reset(args.ref_D1.data());
params_.output_op_0 = args.epilogue0;
params_.output_op_1 = args.epilogue1;
params_.semaphore = static_cast<int *>(workspace);
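Taken together with the testbed earlier in this diff, the device-level flow is: construct Arguments, gate on can_implement(), then initialize() (which populates the Params fields above) and invoke the operator. A condensed sketch, under the assumption that `arguments` is built exactly as in the testbed:

B2bGemm b2b_gemm_op;

cutlass::Status status = b2b_gemm_op.can_implement(arguments);  // shape/fusion checks
if (status == cutlass::Status::kSuccess) {
  status = b2b_gemm_op.initialize(arguments);                   // fill Params, workspace
}
if (status == cutlass::Status::kSuccess) {
  status = b2b_gemm_op();                                       // launch the fused kernel
}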

@@ -1,24 +1,30 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
@@ -36,6 +42,10 @@

#include "kernel/b2b_implicit_gemm_convolution.h"
#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/default_b2b_conv2d_fprop_sm75.h"
#include "kernel/default_b2b_conv2d_fprop_sm80.h"
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h"
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h"

namespace cutlass {
namespace conv {

@@ -1,138 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#include "b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.h"
#include "b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.h"
#include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h"
#include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h"

int run_sm75() {
bool notSupported = false;

// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
//
// CUTLASS must be compiled with the CUDA 10.2 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
notSupported = true;
}

cudaDeviceProp props;

cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}

if (!(props.major == 7 && props.minor >= 5)) {
notSupported = true;
}

if (notSupported) {
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}

bool pass = true;

std::cout << "Running on SM75" << std::endl;
pass &= run_nonfused_conv2d_fprop_optimized_f16_sm75();
pass &= run_fused_conv2d_fprop_optimized_f16_sm75();
pass &= run_fused_conv2d_fprop_optimized_f16_sm75_rf_res();
pass &= run_nonfused_conv2d_fprop_optimized_s8_sm75();
pass &= run_fused_conv2d_fprop_optimized_s8_sm75();
pass &= run_fused_conv2d_fprop_optimized_s8_sm75_rf_res();

if(pass)
return 1;
else
return -1;

}

int run_sm80() {
bool notSupported = false;

// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with the CUDA 11 Toolkit to run Conv2dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
notSupported = true;
}

cudaDeviceProp props;

cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}

if (!(props.major == 8 && props.minor >= 0)) {
notSupported = true;
}

if (notSupported) {
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}

bool pass = true;

std::cout << "Running on SM80" << std::endl;
pass &= run_nonfused_conv2d_fprop_optimized_f16_sm80();
pass &= run_fused_conv2d_fprop_optimized_f16_sm80();
pass &= run_nonfused_conv2d_fprop_optimized_s8_sm80();
pass &= run_fused_conv2d_fprop_optimized_s8_sm80();

if(pass)
return 1;
else
return -1;

}


int main() {

int result = 0;

result = run_sm80();

if(!result) { // not supported
result = run_sm75();

if(!result) {
std::cout << "This example isn't supported on the current architecture" << std::endl;
}

}

if(result >= 0)
return 0;
else
return -1;
}

@@ -1,141 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm80.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h"

int run_sm75() {
bool notSupported = false;

// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
//
// CUTLASS must be compiled with the CUDA 10.2 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
notSupported = true;

}

cudaDeviceProp props;

cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}

if (!(props.major == 7 && props.minor >= 5)) {
notSupported = true;
}

if (notSupported) {
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}

bool pass = true;

std::cout << "Running on SM75" << std::endl;
pass &= run_nonfused_gemm_f16();
pass &= run_fused_gemm_f16();
pass &= run_nonfused_gemm_s8();
pass &= run_fused_gemm_s8();

if(pass)
return 1;
else
return -1;


}

int run_sm80() {
bool notSupported = false;

// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with the CUDA 11 Toolkit to run Conv2dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
notSupported = true;

}

cudaDeviceProp props;

cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}

if (!(props.major == 8 && props.minor >= 0)) {
notSupported = true;
}

if (notSupported) {
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}

bool pass = true;

std::cout << "Running on SM80" << std::endl;
pass &= run_nonfused_gemm_f16_sm80();
pass &= run_fused_gemm_f16_sm80();
pass &= run_nonfused_gemm_s8_sm80();
pass &= run_fused_gemm_s8_sm80();

if(pass)
return 1;
else
return -1;

}


int main() {

int result = 0;

result = run_sm80();

if(!result) { // not supported
result = run_sm75();

if(!result) {
std::cout << "This example isn't supported on the current architecture" << std::endl;
}

}

if(result >= 0)
return 0;
else
return -1;
}



@@ -1,30 +1,33 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/

#include <iostream>

@@ -35,24 +38,25 @@

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
{128, 56, 56, 64}, // input size (NHWC)
{32, 56, 56, 64}, // input size (NHWC)
{64, 3, 3, 64}, // filter size (KRSC)
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
{1, 1}, // stride (stride_h, stride_w)
{1, 1}, // dilation (dilation_h, dilation_w)
{128, 56, 56, 64} // output size (NPQK)
{32, 56, 56, 64} // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
{128, 56, 56, 64}, // input size (NHWC)
{256, 1, 1, 64}, // filter size (KRSC)
{32, 56, 56, 64}, // input size (NHWC)
{128, 1, 1, 64}, // filter size (KRSC)
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
{1, 1}, // stride (stride_h, stride_w)
{1, 1}, // dilation (dilation_h, dilation_w)
{128, 56, 56, 256} // output size (NPQK)
{32, 56, 56, 128} // output size (NPQK)
);
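A quick consistency check on these problem sizes: with the standard convolution output formula, conv0's 3x3 filter at padding 1, stride 1, dilation 1 preserves the 56x56 extent, and conv1's 1x1 filter at padding 0 does too; conv0's 64 output channels (K) match conv1's 64 input channels (C), which is what makes the back-to-back fusion legal. A sketch of the arithmetic:

// Standard Conv2d output extent:
//   P = (H + pad_lo + pad_hi - dilation * (R - 1) - 1) / stride + 1
int conv_out_dim(int in, int pad_lo, int pad_hi, int filt, int dilation, int stride) {
  return (in + pad_lo + pad_hi - dilation * (filt - 1) - 1) / stride + 1;
}
// conv2d_f16_sm75_problem_size_0: (56 + 1 + 1 - 1*(3 - 1) - 1) / 1 + 1 = 56
// conv2d_f16_sm75_problem_size_1: (56 + 0 + 0 - 1*(1 - 1) - 1) / 1 + 1 = 56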

bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
@@ -64,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //use beta for bias
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //use beta for bias
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -138,83 +142,6 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
return pass;
}

bool run_fused_conv2d_fprop_optimized_f16_sm75() {

using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //use beta for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementC,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;

using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementC,
128 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;


const bool SmemAccumulator = true;

using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
ElementA, cutlass::layout::TensorNHWC,
ElementB, cutlass::layout::TensorNHWC,
ElementC, cutlass::layout::TensorNHWC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kOptimized,
SmemAccumulator
>::Kernel;

using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;

std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
alpha0, beta0, alpha1, beta1);

if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return pass;
}

bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {

using ElementA = cutlass::half_t;
@@ -224,14 +151,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //use beta for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 256, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using EpilogueOutputOp0 =
@@ -292,5 +220,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
return pass;
}

int main() {

std::vector<bool (*)()>funcs = {
&run_nonfused_conv2d_fprop_optimized_f16_sm75,
&run_fused_conv2d_fprop_optimized_f16_sm75_rf_res
};

return testRun(75, funcs, "conv f16 RF residency");

}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
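In the RF-residency variant just above, the warp tiles are rewritten so that each warp's N extent spans the full threadblock N of its stage (64 and 128 here); that is what allows the first convolution's accumulators to remain in each warp's register file and feed the second convolution directly. A sketch of the invariant these shape choices appear to rely on, assuming the shapes from the listing (not part of the original file):

#include "cutlass/gemm/gemm.h"

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0        = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1        = cutlass::gemm::GemmShape<16, 128, 32>;

// Each warp must own whole rows of the intermediate tile for its
// accumulators to be reusable in place as the next conv's A operand.
static_assert(WarpShape0::kN == ThreadblockShape0::kN, "warp 0 spans full N");
static_assert(WarpShape1::kN == ThreadblockShape1::kN, "warp 1 spans full N");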
@@ -0,0 +1,234 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
  {32, 56, 56, 64},    // input size  (NHWC)
  {64, 3, 3, 64},      // filter size (KRSC)
  {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {32, 56, 56, 64}     // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
  {32, 56, 56, 64},    // input size  (NHWC)
  {256, 1, 1, 64},     // filter size (KRSC)
  {0, 0, 0, 0},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {32, 56, 56, 256}    // output size (NPQK)
);
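// A worked check of the NPQK extents above, assuming the standard fprop
// relation P = (H + pad_top + pad_bottom - dilation_h * (R - 1) - 1) / stride_h + 1
// (an illustrative comment, not part of the original file):
//   3x3 problem:  P = (56 + 1 + 1 - 1 * (3 - 1) - 1) / 1 + 1 = 56
//   1x1 problem:  no padding, R = S = 1, so P = Q = 56 while K grows to 256;
// conv1 therefore consumes conv0's output tensor unchanged.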

bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {

  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;

  using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;

  B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;

  std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
  bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {

  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;


  const bool SmemAccumulator = true;

  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    SmemAccumulator
  >::Kernel;

  using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

  B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;

  std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
  bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_conv2d_fprop_optimized_f16_sm75,
    &run_fused_conv2d_fprop_optimized_f16_sm75_shmem
  };

  return testRun(75, funcs, "conv f16 shmem staging");

}

////////////////////////////////////////////////////////////////////////////////
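The shared-memory-staging variant above sets SmemAccumulator = true, so the first conv's threadblock tile is written to shared memory and re-read as the second conv's A operand; unlike the RF path it therefore tolerates warp tiles narrower than the threadblock N (WarpShape1 is 64x64 under a 64x256 threadblock). A minimal sketch of the one knob that differs between the two fused files, assuming the same DefaultB2bConv2dFprop alias as in the listing:

// Staging choice for the fused b2b kernel: true stages conv0's output tile in
// shared memory; false (the *_rf.cu files) keeps it in the register file and
// instead requires WarpShapeN::kN == ThreadblockShapeN::kN.
const bool SmemAccumulator = true;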
@@ -1,30 +1,33 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

@@ -35,24 +38,25 @@

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
  {128, 56, 56, 64},   // input size  (NHWC)
  {32, 56, 56, 64},    // input size  (NHWC)
  {64, 3, 3, 64},      // filter size (KRSC)
  {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {128, 56, 56, 64}    // output size (NPQK)
  {32, 56, 56, 64}     // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
  {128, 56, 56, 64},   // input size  (NHWC)
  {64, 1, 1, 64},      // filter size (KRSC)
  {32, 56, 56, 64},    // input size  (NHWC)
  {128, 1, 1, 64},     // filter size (KRSC)
  {0, 0, 0, 0},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {128, 56, 56, 64}    // output size (NPQK)
  {32, 56, 56, 128}    // output size (NPQK)
);


@@ -65,14 +69,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(0);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -90,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
@@ -114,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
@@ -138,8 +143,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
  return pass;
}

bool run_fused_conv2d_fprop_optimized_f16_sm80() {

bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
@@ -147,14 +151,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80() {
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(0);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOutputOp0 =
@@ -171,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80() {
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
@@ -198,7 +204,7 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80() {

  B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;

  std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops...\n";
  std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n";
  bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

@@ -208,7 +214,20 @@
    std::cout << "Fail\n";

  return pass;
  return true;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_conv2d_fprop_optimized_f16_sm80,
    &run_fused_conv2d_fprop_optimized_f16_sm80_rf_res
  };

  return testRun(80, funcs, "conv f16 RF residency");

}


////////////////////////////////////////////////////////////////////////////////
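The Sm80 files switch to the Ampere 16x8x16 half-precision tensor-op shape and pass a stage count of 3 after the threadblock swizzle, consistent with a multistage (cp.async) mainloop, where the Sm75 files use 16x8x8 and 2 stages. A minimal sketch of the two knobs that change (illustrative names, not the original file):

#include "cutlass/gemm/gemm.h"

// Ampere (Sm80) f16 tensor-op mma shape; the Turing (Sm75) files use 16x8x8.
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// Mainloop stage count handed to DefaultConv2dFprop / DefaultB2bConv2dFprop:
// 3 on Sm80 vs. 2 on Sm75 in these listings.
constexpr int kStages = 3;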

@@ -0,0 +1,236 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
  {32, 56, 56, 64},    // input size  (NHWC)
  {64, 3, 3, 64},      // filter size (KRSC)
  {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {32, 56, 56, 64}     // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
  {32, 56, 56, 64},    // input size  (NHWC)
  {256, 1, 1, 64},     // filter size (KRSC)
  {0, 0, 0, 0},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {32, 56, 56, 256}    // output size (NPQK)
);


bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {

  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;

  using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;

  B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;

  std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
  bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {

  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      128 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  const bool SmemAccumulator = true;
  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    SmemAccumulator
  >::Kernel;

  using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

  B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;

  std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
  bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}


int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_conv2d_fprop_optimized_f16_sm80,
    &run_fused_conv2d_fprop_optimized_f16_sm80_shmem
  };

  return testRun(80, funcs, "conv f16 shmem staging");

}


////////////////////////////////////////////////////////////////////////////////

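Across these listings the first epilogue of each fused kernel uses OnlyAlphaScaling while every epilogue that consumes a bias tensor uses NoBetaScaling. A hedged summary of the assumed ScaleType semantics (not a quote of the library source):

// cutlass::epilogue::thread::ScaleType, as used above (assumed semantics):
//   Default:           D = act(alpha * accumulator + beta * source)
//   NoBetaScaling:     D = act(alpha * accumulator + source)  // source carries the bias
//   OnlyAlphaScaling:  D = act(alpha * accumulator)           // no source term
// which would explain why the non-fused references pass beta = 1 with
// NoBetaScaling, while the fused kernels pass beta0 = 0 and rely on the b2b
// kernel's built-in bias for the intermediate stage.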
238
examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu
Normal file
@@ -0,0 +1,238 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
  {32, 56, 56, 64},    // input size  (NHWC)
  {64, 3, 3, 64},      // filter size (KRSC)
  {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {32, 56, 56, 64}     // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
  {32, 56, 56, 64},    // input size  (NHWC)
  {128, 1, 1, 64},     // filter size (KRSC)
  {0, 0, 0, 0},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {32, 56, 56, 128}    // output size (NPQK)
);

bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {

  using ElementA = int8_t;
  using ElementB = int8_t;
  using ElementC = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;

  using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;

  B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;

  std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
  bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}


bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {

  using ElementA = int8_t;
  using ElementB = int8_t;
  using ElementC = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;


  const bool SmemAccumulator = false;

  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    SmemAccumulator
  >::Kernel;

  using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

  B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;

  std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
  bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_conv2d_fprop_optimized_s8_sm75,
    &run_fused_conv2d_fprop_optimized_s8_sm75_rf_res
  };

  return testRun(75, funcs, "conv int8 RF residency");

}



////////////////////////////////////////////////////////////////////////////////

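For the INT8 path the per-thread fragment arithmetic changes with the 8x8x16 integer mma: 8 * 8 = 64 int32 accumulators per warp-level instruction, i.e. 2 per thread, which is the elements-per-access the first epilogue requests; the second epilogue's 64 / sizeof_bits<int8_t> works out to 8 int8 outputs per 64-bit access. A sketch of both counts, assuming the aliases from the listing (not part of the original file):

#include <cstdint>
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"

using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
// 64 accumulators spread over a 32-thread warp -> 2 per thread.
static_assert(InstructionShape::kM * InstructionShape::kN / 32 == 2,
              "per-thread fragment of one 8x8x16 integer mma");
// 64-bit vectorized stores of int8 carry 8 elements each.
static_assert(64 / cutlass::sizeof_bits<int8_t>::value == 8,
              "int8 elements per 64-bit access");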
@@ -1,30 +1,33 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

@@ -35,24 +38,25 @@

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
  {128, 56, 56, 64},   // input size  (NHWC)
  {32, 56, 56, 64},    // input size  (NHWC)
  {64, 3, 3, 64},      // filter size (KRSC)
  {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {128, 56, 56, 64}    // output size (NPQK)
  {32, 56, 56, 64}     // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
  {128, 56, 56, 64},   // input size  (NHWC)
  {32, 56, 56, 64},    // input size  (NHWC)
  {256, 1, 1, 64},     // filter size (KRSC)
  {0, 0, 0, 0},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {128, 56, 56, 256}   // output size (NPQK)
  {32, 56, 56, 256}    // output size (NPQK)
);

bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
@@ -64,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -89,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
@@ -113,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
@@ -137,7 +142,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
  return pass;
}

bool run_fused_conv2d_fprop_optimized_s8_sm75() {
bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {

  using ElementA = int8_t;
  using ElementB = int8_t;
@@ -146,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75() {
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -170,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75() {
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;


@@ -213,81 +220,19 @@
  return pass;
}

bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {

  using ElementA = int8_t;
  using ElementB = int8_t;
  using ElementC = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;
int main() {

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1);
  std::vector<bool (*)()> funcs = {
    &run_nonfused_conv2d_fprop_optimized_s8_sm75,
    &run_fused_conv2d_fprop_optimized_s8_sm75_shmem
  };

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 256, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
  return testRun(75, funcs, "conv int8 shmem staging");

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
    >;


  const bool SmemAccumulator = false;

  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    SmemAccumulator
  >::Kernel;

  using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

  B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;

  std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
  bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
    alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}



////////////////////////////////////////////////////////////////////////////////

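The interleaved INT8 files store activations in TensorNCxHWx<32> and filters in TensorCxRSKx<32>, i.e. channels blocked in groups of 32 so the integer tensor-op mainloop reads contiguous 32-channel vectors, and they accumulate in int32 with OpMultiplyAddSaturate clamping values on conversion back to int8. A minimal sketch of just these layout and math-operator choices, assuming the element types from the listings:

#include <cstdint>
#include "cutlass/layout/tensor.h"
#include "cutlass/arch/mma.h"

// 32-element channel-interleaved layouts used by the interleaved INT8 kernels.
using LayoutA = cutlass::layout::TensorNCxHWx<32>;  // activations
using LayoutB = cutlass::layout::TensorCxRSKx<32>;  // filters
using LayoutC = cutlass::layout::TensorNCxHWx<32>;  // output
// int8 x int8 -> int32 accumulation with saturating conversion to int8.
using MathOp = cutlass::arch::OpMultiplyAddSaturate;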
@@ -1,30 +1,33 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

@@ -35,24 +38,25 @@

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
  {128, 56, 56, 64},   // input size  (NHWC)
  {32, 56, 56, 64},    // input size  (NHWC)
  {64, 3, 3, 64},      // filter size (KRSC)
  {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {128, 56, 56, 64}    // output size (NPQK)
  {32, 56, 56, 64}     // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
  {128, 56, 56, 64},   // input size  (NHWC)
  {64, 1, 1, 64},      // filter size (KRSC)
  {32, 56, 56, 64},    // input size  (NHWC)
  {128, 1, 1, 64},     // filter size (KRSC)
  {0, 0, 0, 0},        // padding (pad_h, _, pad_w, _)
  {1, 1},              // stride (stride_h, stride_w)
  {1, 1},              // dilation (dilation_h, dilation_w)
  {128, 56, 56, 64}    // output size (NPQK)
  {32, 56, 56, 128}    // output size (NPQK)
);

bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
@@ -64,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(0);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -89,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
@@ -113,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
@@ -137,7 +142,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
  return pass;
}

bool run_fused_conv2d_fprop_optimized_s8_sm80() {

bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {

  using ElementA = int8_t;
  using ElementB = int8_t;
@@ -146,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80() {
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(0);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using EpilogueOutputOp0 =
@@ -170,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80() {
      ElementC,
      64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;


@@ -199,7 +207,7 @@

  B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;

  std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
@ -211,6 +219,18 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
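A note for readers stepping through these convolution diffs: the two problem sizes are chained, so the second conv consumes exactly what the first produces (its NHWC input equals conv0's NPQK output, through a 1x1 filter). A minimal host-side sketch of that invariant follows; the helper name is illustrative and not part of the change.

#include <cassert>
#include "cutlass/conv/conv2d_problem_size.h"

// Hypothetical helper (not in the diff): verify conv1 consumes conv0's output.
void check_b2b_conv_chain(cutlass::conv::Conv2dProblemSize const &p0,
                          cutlass::conv::Conv2dProblemSize const &p1) {
  assert(p1.N == p0.N && p1.H == p0.P && p1.W == p0.Q && p1.C == p0.K);
  assert(p1.R == 1 && p1.S == 1);  // 1x1 filter: conv0's output tile can stay on-chip
}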
@ -0,0 +1,237 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
{32, 56, 56, 64}, // input size (NHWC)
{64, 3, 3, 64}, // filter size (KRSC)
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
{1, 1}, // stride (stride_h, stride_w)
{1, 1}, // dilation (dilation_h, dilation_w)
{32, 56, 56, 64} // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
{32, 56, 56, 64}, // input size (NHWC)
{256, 1, 1, 64}, // filter size (KRSC)
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
{1, 1}, // stride (stride_h, stride_w)
{1, 1}, // dilation (dilation_h, dilation_w)
{32, 56, 56, 256} // output size (NPQK)
);

bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {

using ElementA = int8_t;
using ElementB = int8_t;
using ElementC = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementA, cutlass::layout::TensorNCxHWx<32>,
ElementB, cutlass::layout::TensorCxRSKx<32>,
ElementC, cutlass::layout::TensorNCxHWx<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
WarpShape0,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;

using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;

using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementA, cutlass::layout::TensorNCxHWx<32>,
ElementB, cutlass::layout::TensorCxRSKx<32>,
ElementC, cutlass::layout::TensorNCxHWx<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape1,
WarpShape1,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;

using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;

B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;

std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
alpha0, beta0, alpha1, beta1);

if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return pass;
}

bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {

using ElementA = int8_t;
using ElementB = int8_t;
using ElementC = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementC,
8 * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;

using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = true;

using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
ElementA, cutlass::layout::TensorNCxHWx<32>,
ElementB, cutlass::layout::TensorCxRSKx<32>,
ElementC, cutlass::layout::TensorNCxHWx<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::conv::IteratorAlgorithm::kOptimized,
SmemAccumulator
>::Kernel;

using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;

std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
alpha0, beta0, alpha1, beta1);

if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return pass;
}

int main() {

std::vector<bool (*)()>funcs = {
&run_nonfused_conv2d_fprop_optimized_s8_sm80,
&run_fused_conv2d_fprop_optimized_s8_sm80_shmem
};

return testRun(80, funcs, "conv int8 shmem staging");

}

////////////////////////////////////////////////////////////////////////////////
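The shared-memory-staged variant above differs from the RF-resident one mainly in the SmemAccumulator knob and in the looser warp shape it permits for the first conv. A compile-time restatement of that difference, using the shapes from the file above (assumed reading of the constraint, not stated in the diff):

#include "cutlass/gemm/gemm.h"

// Shapes from the smem-staged conv above: the warp tile covers only half the
// threadblock's N, which the RF-resident variant could not allow because its
// stage-0 accumulators never leave each warp's registers.
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
static_assert(WarpShape0::kN < ThreadblockShape0::kN,
              "smem staging tolerates warp tiles narrower than the block");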
210
examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu
Normal file
@ -0,0 +1,210 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 128, 64);

bool run_nonfused_gemm_f16() {

using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using Gemm0 = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
WarpShape0,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
using Gemm1 = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape1,
WarpShape1,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;

B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;

std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return pass;
}

bool run_fused_gemm_f16_rf_res() {

using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;

using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

using B2bGemm = cutlass::gemm::device::B2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;

B2bFusedGemmRun<B2bGemm> fusedGemm;

std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF Residency...\n";
bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return passed;
}

int main() {

std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_f16,
&run_fused_gemm_f16_rf_res
};

return testRun(75, funcs, "gemm f16 RF residency");

}

///////////////////////////////////////////////////////////////////////////////
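For the RF-resident fusion just added, the shapes are not arbitrary: GEMM0's N (64) equals both ThreadblockShape0::kN and WarpShape0::kN, so each warp's accumulator tile covers the full row of the intermediate and can feed GEMM1 without leaving registers. That reading of the constraint (assumed invariant, mirroring the file's shapes) can be pinned down at compile time:

#include "cutlass/gemm/gemm.h"

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;

// Assumed RF-residency requirement: the warp tile spans the whole N of GEMM0,
// which in turn matches gemm_f16_sm75_problem_size_0's N of 64 above.
static_assert(WarpShape0::kN == ThreadblockShape0::kN,
              "GEMM0 warp tile must span the threadblock's full N for RF residency");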
@ -1,29 +1,33 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include <iostream>

#include "cutlass/cutlass.h"
@ -38,11 +42,12 @@

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*1600, 128, 64);
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 256, 64);

bool run_nonfused_gemm_f16() {

@ -50,13 +55,13 @@ bool run_nonfused_gemm_f16() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
@ -79,7 +84,7 @@ bool run_nonfused_gemm_f16() {
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -101,7 +106,8 @@ bool run_nonfused_gemm_f16() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -119,21 +125,22 @@ bool run_nonfused_gemm_f16() {
return pass;
}

bool run_fused_gemm_f16() {
bool run_fused_gemm_f16_shmem() {

using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using EpilogueOutputOp0 =
@ -150,10 +157,12 @@ bool run_fused_gemm_f16() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = true;

using B2bGemm = cutlass::gemm::device::B2bGemm<
cutlass::half_t,
@ -173,12 +182,13 @@ bool run_fused_gemm_f16() {
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
2,
SmemAccumulator
>;

B2bFusedGemmRun<B2bGemm> fusedGemm;

std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n";
bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
@ -187,4 +197,18 @@ bool run_fused_gemm_f16() {

return passed;
}

int main() {

std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_f16,
&run_fused_gemm_f16_shmem
};

return testRun(75, funcs, "gemm f16 shmem staging");

}

////////////////////////////////////////////////////////////////////////////////
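The alpha/beta rewrites in these hunks all follow one pattern: the non-fused runners now pass beta = 1 and select ScaleType::NoBetaScaling, so the C operand is added unscaled as a bias, while the fused runners set beta0 = 0 because the kernel applies GEMM0's bias internally. Per element, the epilogue these examples configure reduces to the following sketch (a plain scalar stand-in, not the CUTLASS functor itself):

// LinearCombinationRelu with ScaleType::NoBetaScaling, per element:
float epilogue(float alpha, float accumulator, float bias /* C operand */) {
  float d = alpha * accumulator + bias;  // bias enters unscaled; beta is ignored
  return d > 0.0f ? d : 0.0f;            // ReLU clamp
}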
@ -1,29 +1,33 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include <iostream>

#include "cutlass/cutlass.h"
@ -38,11 +42,12 @@

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*1600, 128, 64);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 128, 64);

bool run_nonfused_gemm_f16_sm80() {

@ -50,15 +55,15 @@ bool run_nonfused_gemm_f16_sm80() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using Gemm0 = cutlass::gemm::device::Gemm<
@ -79,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() {
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
@ -101,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
@ -119,21 +125,22 @@ bool run_nonfused_gemm_f16_sm80() {
return pass;
}

bool run_fused_gemm_f16_sm80() {
bool run_fused_gemm_f16_sm80_rf_res() {

using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using EpilogueOutputOp0 =
@ -150,11 +157,10 @@ bool run_fused_gemm_f16_sm80() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

using B2bGemm = cutlass::gemm::device::B2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
@ -178,7 +184,7 @@ bool run_fused_gemm_f16_sm80() {

B2bFusedGemmRun<B2bGemm> fusedGemm;

std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
@ -188,4 +194,20 @@ bool run_fused_gemm_f16_sm80() {
return passed;

}

int main() {

std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_f16_sm80,
&run_fused_gemm_f16_sm80_rf_res
};

return testRun(80, funcs, "gemm f16 RF residency");

}

////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,217 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 64);

bool run_nonfused_gemm_f16_sm80() {

using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using Gemm0 = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
WarpShape0,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
>;
using Gemm1 = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape1,
WarpShape1,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
>;

B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;

std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return pass;
}

bool run_fused_gemm_f16_sm80_shmem() {

using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;

using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = true;

using B2bGemm = cutlass::gemm::device::B2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
SmemAccumulator
>;

B2bFusedGemmRun<B2bGemm> fusedGemm;

std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n";
bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";

return passed;

}

int main() {

std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_f16_sm80,
&run_fused_gemm_f16_sm80_shmem
};

return testRun(80, funcs, "gemm f16 shmem staging");

}

////////////////////////////////////////////////////////////////////////////////
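Between the Sm75 and Sm80 variants of these examples, the trailing integer in the device templates changes from 2 to 3. This is, as far as this listing shows, the number of software-pipelined mainloop stages (the Stages parameter of cutlass::gemm::device::Gemm and friends); the extra stage suits Ampere's asynchronous copy path. A compile-time restatement under that assumption:

// Assumed meaning of the trailing template argument seen above:
constexpr int kStagesSm75 = 2;  // two-stage double buffering on Turing
constexpr int kStagesSm80 = 3;  // one extra stage on Ampere
static_assert(kStagesSm80 > kStagesSm75, "Sm80 examples pipeline one extra stage");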
@ -1,29 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -38,11 +42,12 @@
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*1600, 128, 64);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 128, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8() {
|
||||
|
||||
@ -50,15 +55,15 @@ bool run_nonfused_gemm_s8() {
|
||||
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

-  ElementCompute alpha0 = ElementCompute(2);
-  ElementCompute beta0 = ElementCompute(0);
-  ElementCompute alpha1 = ElementCompute(2);
-  ElementCompute beta1 = ElementCompute(1);
+  ElementCompute alpha0 = ElementCompute(1);
+  ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
+  ElementCompute alpha1 = ElementCompute(1);
+  ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
-  using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>;
+  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
+  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
@@ -79,7 +84,7 @@ bool run_nonfused_gemm_s8() {
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
-      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
@@ -101,7 +106,8 @@ bool run_nonfused_gemm_s8() {
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
-      ElementCompute
+      ElementCompute,
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
@@ -119,21 +125,23 @@ bool run_nonfused_gemm_s8() {
  return pass;
}

-bool run_fused_gemm_s8() {
+bool run_fused_gemm_s8_rf_res() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

-  ElementCompute alpha0 = ElementCompute(2);
+  ElementCompute alpha0 = ElementCompute(1);
+  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
-  ElementCompute alpha1 = ElementCompute(2);
-  ElementCompute beta1 = ElementCompute(1);
+  ElementCompute alpha1 = ElementCompute(1);
+  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

-  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
-  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
-  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
+  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
+  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
+  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using EpilogueOutputOp0 =
@@ -150,7 +158,8 @@ bool run_fused_gemm_s8() {
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
-      ElementCompute
+      ElementCompute,
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  using B2bGemm = cutlass::gemm::device::B2bGemm<
@@ -176,7 +185,7 @@ bool run_fused_gemm_s8() {

  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;

-  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
+  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF Residency...\n";
  bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
@@ -186,4 +195,18 @@ bool run_fused_gemm_s8() {

  return passed;

}

+int main() {
+
+  std::vector<bool (*)()>funcs = {
+    &run_nonfused_gemm_s8,
+    &run_fused_gemm_s8_rf_res
+  };
+
+  return testRun(75, funcs, "gemm int8 RF residency");
+
+}
+
+
+////////////////////////////////////////////////////////////////////////////////
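For reference while reading the scaling changes above: the three ScaleType variants these commits move between differ only in how the epilogue treats the source operand C. A minimal host-side sketch of the per-element math, assuming the usual CUTLASS linear-combination-plus-ReLU semantics (the helper name and boolean flags below are illustrative, not a CUTLASS API):

#include <algorithm>
#include <cstdint>

// Illustrative per-element epilogue math:
//   Default:          d = alpha * accum + beta * c
//   NoBetaScaling:    d = alpha * accum + c        (c carries the bias; beta is not applied)
//   OnlyAlphaScaling: d = alpha * accum            (no source/bias term at all)
inline int8_t epilogue_relu_s8(float alpha, int32_t accum, float beta, int8_t c,
                               bool no_beta_scaling, bool only_alpha_scaling) {
  float d = alpha * static_cast<float>(accum);
  if (!only_alpha_scaling)
    d += no_beta_scaling ? static_cast<float>(c) : beta * static_cast<float>(c);
  d = std::max(d, 0.0f);                            // ReLU
  return static_cast<int8_t>(std::min(d, 127.0f));  // saturate to the int8 range
}

This is why the non-fused paths pass beta = 1 together with a bias tensor (NoBetaScaling adds C unscaled), while the fused paths set beta0 = 0 and route the bias through the kernel's built-in scale/bias operands instead.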
@@ -0,0 +1,214 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 256, 64);

bool run_nonfused_gemm_s8() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t,
    cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput,
    cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t,
    cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput,
    cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
  >;

  B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
  bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

bool run_fused_gemm_s8_shmem() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  const bool SmemAccumulator = true;

  using B2bGemm = cutlass::gemm::device::B2bGemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t,
    cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput,
    cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    SmemAccumulator
  >;

  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;

  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n";
  bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;

}

int main() {

  std::vector<bool (*)()>funcs = {
    &run_nonfused_gemm_s8,
    &run_fused_gemm_s8_shmem
  };

  return testRun(75, funcs, "gemm int8 shmem staging");

}


////////////////////////////////////////////////////////////////////////////////
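The shared-memory-staged test above differs from the RF-resident one before it chiefly in the SmemAccumulator flag and the warp shapes that flag permits. A compile-time sketch of the constraint, assuming (as the shapes in these tests suggest) that RF residency requires each warp to cover the full N extent of the first GEMM so its accumulators never leave registers:

#include "cutlass/gemm/gemm.h"

// Shapes copied from the RF-resident SM75 test vs. the shmem-staged one above.
using RfThreadblockShape0   = cutlass::gemm::GemmShape<64, 64, 32>;
using RfWarpShape0          = cutlass::gemm::GemmShape<32, 64, 32>;  // kN == 64: whole tile width
using SmemThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using SmemWarpShape0        = cutlass::gemm::GemmShape<32, 32, 32>;  // kN == 32: allowed only with staging

// Assumption: RF-resident fusion needs warp N to equal threadblock N for GEMM0.
static_assert(RfWarpShape0::kN == RfThreadblockShape0::kN,
              "RF residency: each warp must own all of GEMM0's N dimension");

With staging enabled, GEMM0's accumulators take a round trip through shared memory between the two MMAs, so a warp may cover only a slice of N0 (32 of 64 here) at the cost of the extra shared-memory traffic.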
@@ -1,29 +1,33 @@
/***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
 *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright notice, this list of
- *       conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright notice, this list of
- *       conditions and the following disclaimer in the documentation and/or other materials
- *       provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- *       to endorse or promote products derived from this software without specific prior written
- *       permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <iostream>

#include "cutlass/cutlass.h"
@@ -38,11 +42,12 @@

#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

-cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*1600, 64, 576);
-cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*1600, 128, 64);
+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576);
+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 128, 64);

bool run_nonfused_gemm_s8_sm80() {

@@ -50,10 +55,10 @@ bool run_nonfused_gemm_s8_sm80() {
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

-  ElementCompute alpha0 = ElementCompute(2);
-  ElementCompute beta0 = ElementCompute(0);
-  ElementCompute alpha1 = ElementCompute(2);
-  ElementCompute beta1 = ElementCompute(0);
+  ElementCompute alpha0 = ElementCompute(1);
+  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
+  ElementCompute alpha1 = ElementCompute(1);
+  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
@@ -79,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() {
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
-      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
@@ -106,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() {
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
-      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
@@ -128,21 +133,23 @@ bool run_nonfused_gemm_s8_sm80() {
  return pass;
}

-bool run_fused_gemm_s8_sm80() {
+bool run_fused_gemm_s8_sm80_rf_res() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

-  ElementCompute alpha0 = ElementCompute(2);
+  ElementCompute alpha0 = ElementCompute(1);
+  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
-  ElementCompute alpha1 = ElementCompute(2);
-  ElementCompute beta1 = ElementCompute(0);
+  ElementCompute alpha1 = ElementCompute(1);
+  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+  using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
-  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
+  using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using EpilogueOutputOp0 =
@@ -160,9 +167,11 @@ bool run_fused_gemm_s8_sm80() {
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
-      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

+  const bool SmemAccumulator = false;
+
  using B2bGemm = cutlass::gemm::device::B2bGemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
@@ -182,6 +191,7 @@ bool run_fused_gemm_s8_sm80() {
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
+    SmemAccumulator,
    16,
    16,
    false,
@@ -190,7 +200,7 @@ bool run_fused_gemm_s8_sm80() {

  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;

-  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
+  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
  bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
@@ -199,4 +209,19 @@ bool run_fused_gemm_s8_sm80() {

  return passed;
}

+
+int main() {
+
+  std::vector<bool (*)()>funcs = {
+    &run_nonfused_gemm_s8_sm80,
+    &run_fused_gemm_s8_sm80_rf_res
+  };
+
+  return testRun(80, funcs, "gemm int8 RF residency");
+
+}
+
+
+////////////////////////////////////////////////////////////////////////////////
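Also worth pulling out of the fused configurations: the first epilogue's vector length is tied to the accumulator fragment a thread holds per tensor-core instruction, not to a global-memory access width as in the second epilogue. A back-of-the-envelope sketch (the per-thread interpretation is an assumption; the arithmetic itself matches the literals used in these tests):

// Illustrative arithmetic only.
constexpr int kWarpSize = 32;
constexpr int kFragSm75 = 8 * 8 / kWarpSize;  // SM75 test above: InstructionShape::kM * kN / 32 = 2
constexpr int kFragSm80 = 8 * 8 / kWarpSize;  // SM80 test below: 8 * InstructionShape::kN / 32 = 2
constexpr int kGmemVec  = 64 / 8;             // both: 64 / sizeof_bits<int8_t> = 8-wide stores

static_assert(kFragSm75 == 2 && kFragSm80 == 2 && kGmemVec == 8,
              "vector widths implied by the fused test configurations");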
@@ -0,0 +1,226 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 256, 64);

bool run_nonfused_gemm_s8_sm80() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t,
    cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput,
    cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    16,
    16,
    false,
    cutlass::arch::OpMultiplyAddSaturate
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t,
    cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput,
    cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    16,
    16,
    false,
    cutlass::arch::OpMultiplyAddSaturate
  >;

  B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
  bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

bool run_fused_gemm_s8_sm80_shmem() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      8 * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  const bool SmemAccumulator = true;

  using B2bGemm = cutlass::gemm::device::B2bGemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t,
    cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput,
    cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,
    SmemAccumulator,
    16,
    16,
    false,
    cutlass::arch::OpMultiplyAddSaturate
  >;

  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;

  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n";
  bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}


int main() {

  std::vector<bool (*)()>funcs = {
    &run_nonfused_gemm_s8_sm80,
    &run_fused_gemm_s8_sm80_shmem
  };

  return testRun(80, funcs, "gemm int8 shmem staging");

}


////////////////////////////////////////////////////////////////////////////////
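The problem sizes used throughout these tests chain in a fixed way: the first GEMM computes an (M x N0) intermediate that becomes the second GEMM's A operand, so N0 must equal K1 and M is shared. A small host-side sketch of that invariant, mirroring the fusion-size checks added to the B2bGemm kernel just below:

#include <cassert>
#include "cutlass/gemm_coord.h"

int main() {
  // Sizes from the SM80 tests above: GEMM0 = (M, N0, K0), GEMM1 = (M, N1, K1).
  cutlass::gemm::GemmCoord size0(128 * 640, 64, 576);
  cutlass::gemm::GemmCoord size1(128 * 640, 256, 64);

  assert(size0.m() == size1.m());  // the M extent is shared across the chain
  assert(size0.n() == size1.k());  // GEMM0's output width is GEMM1's K
  return 0;
}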
@@ -1,24 +1,30 @@
/***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
 *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright notice, this list of
- *       conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright notice, this list of
- *       conditions and the following disclaimer in the documentation and/or other materials
- *       provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- *       to endorse or promote products derived from this software without specific prior written
- *       permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
@@ -73,6 +79,8 @@ struct B2bGemm {
  typename B2bMma::IteratorB0::TensorRef ref_B0;
  typename Epilogue::OutputTileIterator::Params params_C0;
  typename Epilogue::OutputTileIterator::TensorRef ref_C0;
+  typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
+  typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
  typename B2bMma::IteratorB1::Params params_B1;
  typename B2bMma::IteratorB1::TensorRef ref_B1;
  typename Epilogue::OutputTileIterator::Params params_C1;
@@ -103,6 +111,8 @@ struct B2bGemm {
  typename B2bMma::IteratorA0::TensorRef ref_A0,
  typename B2bMma::IteratorB0::TensorRef ref_B0,
  typename Epilogue::OutputTileIterator::TensorRef ref_C0,
+  typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0,
+  typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0,
  typename B2bMma::IteratorB1::TensorRef ref_B1,
  typename Epilogue::OutputTileIterator::TensorRef ref_C1,
  typename Epilogue::OutputTileIterator::TensorRef ref_D1,
@@ -120,6 +130,8 @@ struct B2bGemm {
  ref_B0(ref_B0),
  params_C0(ref_C0.layout()),
  ref_C0(ref_C0),
+  ref_Scale0(ref_Scale0),
+  ref_Bias0(ref_Bias0),
  params_B1(ref_B1.layout()),
  ref_B1(ref_B1),
  params_C1(ref_C1.layout()),
@@ -202,6 +214,19 @@ struct B2bGemm {
      return Status::kErrorMisalignedOperand;
    }

+    // Determine if fusion sizes are valid
+    if(problem_size_0.m() != problem_size_1.m())
+      return Status::kErrorInvalidProblem;
+
+    if(problem_size_0.n() != problem_size_1.k())
+      return Status::kErrorInvalidProblem;
+
+    if(problem_size_0.n() > B2bMma::Shape0::kN)
+      return Status::kErrorInvalidProblem;
+
+    if(problem_size_1.n() > B2bMma::Shape1::kN)
+      return Status::kErrorInvalidProblem;
+
    return Status::kSuccess;
  }
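The two tile checks just added deserve a note: the fused kernel assigns each threadblock the entire N extent of both GEMMs, so N0 and N1 must each fit inside a single threadblock tile. A host-side pre-check in the same spirit, with tile widths assumed from the SM80 shmem test (kShape0N/kShape1N stand in for B2bMma::Shape0::kN and Shape1::kN):

// Sketch only; not part of the kernel.
constexpr int kShape0N = 64;   // ThreadblockShape0::kN in the SM80 shmem test
constexpr int kShape1N = 256;  // ThreadblockShape1::kN in the SM80 shmem test

inline bool fusion_sizes_ok(int m0, int n0, int m1, int n1, int k1) {
  return m0 == m1          // both GEMMs see the same rows
      && n0 == k1          // GEMM0's output width feeds GEMM1's K
      && n0 <= kShape0N    // the whole N0 extent must fit one threadblock tile
      && n1 <= kShape1N;   // likewise for N1
}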
@@ -286,6 +311,29 @@ struct B2bGemm {
    int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

+    // Construct iterators to accumulator scale/bias vector
+    typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
+      params.ref_Scale0.data(),
+      {1, params.problem_size_0.n()},
+      thread_idx,
+      warp_idx,
+      MatrixCoord(
+        0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
+      )
+    );
+
+    typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
+      params.ref_Bias0.data(),
+      {1, params.problem_size_0.n()},
+      thread_idx,
+      warp_idx,
+      MatrixCoord(
+        0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
+      )
+    );
+
    //
    // Main loop
    //
@@ -303,7 +351,8 @@ struct B2bGemm {

    if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
      // Compute threadblock-scoped matrix multiply-add
-      b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0);
+      b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
+        iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
    }

    //
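For context on what the new iterators feed into: iterator_Scale0 and iterator_Bias0 walk {1 x N0} vectors, so every column of the first GEMM's accumulator tile picks up a per-channel scale and bias before ReLU and before becoming GEMM1's A operand. A sketch of the per-element effect (illustrative, assuming the intermediate epilogue applies scale and bias element-wise):

#include <algorithm>
#include <cstdint>

// Illustrative only: column n of GEMM0's accumulators, scaled and biased by
// the n-th entries of the broadcast vectors, then clamped by ReLU.
inline float intermediate_value(int32_t accum, float scale_n, float bias_n) {
  return std::max(scale_n * static_cast<float>(accum) + bias_n, 0.0f);
}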
@@ -1,24 +1,30 @@
/***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
 *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright notice, this list of
- *       conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright notice, this list of
- *       conditions and the following disclaimer in the documentation and/or other materials
- *       provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- *       to endorse or promote products derived from this software without specific prior written
- *       permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

File diff suppressed because it is too large
@ -0,0 +1,749 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_pipelined.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// OpClassTensorOp convolutions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
false
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
/// and 2 stage pipeline.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    2, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
        ElementA, LayoutA,
        ThreadMapA0
      >
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
        ElementB, LayoutB,
        ThreadMapB0
      >
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  // Use fragment iterator for A operand
  using AccumulatorLayout = cutlass::layout::ColumnMajor;
  using FragmentIteratorA1 =
    cutlass::gemm::warp::MmaTensorOpFragmentIterator<
      cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
      cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
      MmaCore1::Shape::kK,                                                            // kBlocksColumn
      ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
    MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
    LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
        ElementB, LayoutB,
        ThreadMapB1
      >
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmPipelined<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    IteratorB0,
    SmemIteratorB0,
    ThreadblockShape1,
    FragmentIteratorA1,
    IteratorAccumulatorScaleBias,
    FragmentIteratorA1ScaleBias,
    IteratorB1,
    SmemIteratorB1,
    ElementC,
    LayoutC,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1
  >;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag,
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
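
A specialization like the one above is normally selected through the `DefaultB2bConv2dFprop` interface rather than instantiated piecewise. The following is a minimal sketch of such an instantiation, assuming the primary template's parameter order matches the partial specializations in this file; the element types, tile shapes, and epilogue functors shown are illustrative, not part of this diff.

```cpp
// Minimal sketch (assumed types/shapes): selects the 2-stage, Optimized-iterator
// specialization of DefaultB2bConv2dFprop defined above.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;

using DefaultKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,   // B: filters
    cutlass::half_t, cutlass::layout::TensorNHWC,   // C: output
    float,                                          // accumulator element
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 32>,           // ThreadblockShape0
    cutlass::gemm::GemmShape<64, 128, 32>,          // ThreadblockShape1
    cutlass::gemm::GemmShape<32, 64, 32>,           // WarpShape0
    cutlass::gemm::GemmShape<32, 128, 32>,          // WarpShape1
    cutlass::gemm::GemmShape<16, 8, 8>,             // InstructionShape
    EpilogueOp, EpilogueOp,                         // epilogues of conv0 and conv1
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                              // Stages: 2-stage pipeline
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized>;

// The fused kernel type assembled by the specialization.
using B2bFpropKernel = DefaultKernel::Kernel;
```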

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
/// pipeline with interleaved layout.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    2, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    2, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
        ElementA, layout::TensorNCxHWx<InterleavedK>,
        ThreadMapA0
      >
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
        ElementB, layout::TensorCxRSKx<InterleavedK>,
        ThreadMapB0
      >
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  // Use fragment iterator for A operand
  using AccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorA1 =
    cutlass::gemm::warp::MmaTensorOpFragmentIterator<
      cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
      cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
      MmaCore1::Shape::kK,                                                            // kBlocksColumn
      ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
    MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
    LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
        ElementB, layout::TensorCxRSKx<InterleavedK>,
        ThreadMapB1
      >
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmPipelined<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    IteratorB0,
    SmemIteratorB0,
    ThreadblockShape1,
    FragmentIteratorA1,
    IteratorAccumulatorScaleBias,
    FragmentIteratorA1ScaleBias,
    IteratorB1,
    SmemIteratorB1,
    ElementC,
    LayoutC,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
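
The interleaved specialization above is reached by passing `TensorNCxHWx` / `TensorCxRSKx` layouts, as is typical for int8 tensor-core convolutions. A minimal sketch of such an instantiation follows; the interleave factor, shapes, and epilogue parameters are assumed for illustration and are not part of this diff.

```cpp
// Minimal sketch (assumed types/shapes): the NC/32HW32-interleaved int8 path.
using EpilogueOpInt8 = cutlass::epilogue::thread::LinearCombinationRelu<
    int8_t, 8, int32_t, float>;

using DefaultInterleavedKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    int8_t, cutlass::layout::TensorNCxHWx<32>,      // A: interleaved activations
    int8_t, cutlass::layout::TensorCxRSKx<32>,      // B: interleaved filters
    int8_t, cutlass::layout::TensorNCxHWx<32>,      // C: interleaved output
    int32_t,                                        // accumulator element
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 64>,           // ThreadblockShape0
    cutlass::gemm::GemmShape<64, 128, 64>,          // ThreadblockShape1
    cutlass::gemm::GemmShape<32, 64, 64>,           // WarpShape0
    cutlass::gemm::GemmShape<32, 128, 64>,          // WarpShape1
    cutlass::gemm::GemmShape<8, 8, 16>,             // int8 tensor-op shape
    EpilogueOpInt8, EpilogueOpInt8,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                              // Stages
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized>;
```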

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,740 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"

#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/b2b_implicit_gemm_convolution.h"
#include "threadblock/b2b_implicit_gemm_multistage.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    Stages, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, LayoutA,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, LayoutB,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  // Use fragment iterator for A operand
  using AccumulatorLayout = cutlass::layout::ColumnMajor;
  using FragmentIteratorA1 =
    cutlass::gemm::warp::MmaTensorOpFragmentIterator<
      cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
      cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
      MmaCore1::Shape::kK,                                                            // kBlocksColumn
      ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
    MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
    LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, LayoutB,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistage<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    ThreadblockShape1,
    FragmentIteratorA1,
    IteratorAccumulatorScaleBias,
    FragmentIteratorA1ScaleBias,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
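
The multistage specializations in this file differ from the 2-stage pipelined ones mainly in the `Stages` template parameter and the `cp.async` cache hints threaded through `B2bImplicitGemmMultistage`. A minimal sketch of selecting this path follows, assuming the same primary-template parameter order as above; the SM80 shapes and types are illustrative.

```cpp
// Minimal sketch (assumed types/shapes): Stages > 2 routes to the multistage
// specializations in this file (cp.async-based software pipelining on SM80).
using EpilogueOpMs = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;

using DefaultMultistageKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,
    cutlass::half_t, cutlass::layout::TensorNHWC,
    cutlass::half_t, cutlass::layout::TensorNHWC,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 64, 32>,           // ThreadblockShape0
    cutlass::gemm::GemmShape<64, 128, 32>,          // ThreadblockShape1
    cutlass::gemm::GemmShape<32, 64, 32>,           // WarpShape0
    cutlass::gemm::GemmShape<32, 128, 32>,          // WarpShape1
    cutlass::gemm::GemmShape<16, 8, 16>,            // SM80 f16 tensor-op shape
    EpilogueOpMs, EpilogueOpMs,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // Stages: multistage pipeline
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic>;
```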

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline with interleaved layout.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    Stages, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    Stages, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, layout::TensorNCxHWx<InterleavedK>,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  // Use fragment iterator for A operand
  using AccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorA1 =
    cutlass::gemm::warp::MmaTensorOpFragmentIterator<
      cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
      cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
      MmaCore1::Shape::kK,                                                            // kBlocksColumn
      ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
    MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
    LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistage<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    ThreadblockShape1,
    FragmentIteratorA1,
    IteratorAccumulatorScaleBias,
    FragmentIteratorA1ScaleBias,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    Stages, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, LayoutA,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, LayoutB,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  // Use fragment iterator for A operand
  using AccumulatorLayout = cutlass::layout::ColumnMajor;
  using FragmentIteratorA1 =
    cutlass::gemm::warp::MmaTensorOpFragmentIterator<
      cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
      cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
      MmaCore1::Shape::kK,                                                            // kBlocksColumn
      ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
    MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
    LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, LayoutB,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistage<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    ThreadblockShape1,
    FragmentIteratorA1,
    IteratorAccumulatorScaleBias,
    FragmentIteratorA1ScaleBias,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline with interleaved layout.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    Stages, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    Stages, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, layout::TensorNCxHWx<InterleavedK>,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  // Use fragment iterator for A operand
  using AccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorA1 =
    cutlass::gemm::warp::MmaTensorOpFragmentIterator<
      cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
      cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
      MmaCore1::Shape::kK,                                                            // kBlocksColumn
      ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
    MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
    LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistage<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    ThreadblockShape1,
    FragmentIteratorA1,
    IteratorAccumulatorScaleBias,
    FragmentIteratorA1ScaleBias,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,817 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"

#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/b2b_implicit_gemm_convolution.h"
#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
/// and 2 stage pipeline.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    2, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
        ElementA, LayoutA,
        ThreadMapA0
      >
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
        ElementB, LayoutB,
        ThreadMapB0
      >
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
        ElementB, LayoutB,
        ThreadMapB1
      >
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
    WarpShape0, InstructionShape,
    ElementAccumulator,
    typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
    SmemAccumulatorLayout
  >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
    cutlass::epilogue::warp::TileIteratorTensorOp<
      WarpShape0,
      InstructionShape,
      ElementC,
      SmemAccumulatorLayout
    >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
    MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
    ElementA, SmemAccumulatorLayout,
    MatrixShape<InstructionShape::kM, InstructionShape::kK>,
    WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    IteratorB0,
    SmemIteratorB0,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    ElementC,
    LayoutC,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1
  >;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag,
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
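
The specializations in this third file carry a trailing boolean, set to `true` in the partial specialization above, that stages the first convolution's accumulator tile in shared memory (via `SmemIteratorD0` and `WarpIteratorA1`) instead of keeping it in register-file fragments. A minimal sketch of selecting this variant follows, assuming that trailing parameter defaults to `false` in the primary template; the shapes and types are illustrative.

```cpp
// Minimal sketch (assumed types/shapes): the trailing 'true' selects the
// shared-memory-staged accumulator variant defined above.
using EpilogueOpSmem = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;

using DefaultSmemStagedKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,
    cutlass::half_t, cutlass::layout::TensorNHWC,
    cutlass::half_t, cutlass::layout::TensorNHWC,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<64, 128, 32>,
    cutlass::gemm::GemmShape<32, 64, 32>,
    cutlass::gemm::GemmShape<32, 128, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    EpilogueOpSmem, EpilogueOpSmem,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                              // Stages: 2-stage pipeline
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    true>;                                          // stage accumulator in smem
```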

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
/// pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    2, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    ElementB, layout::RowMajorInterleaved<InterleavedK>,
    ElementAccumulator, LayoutC, arch::OpClassTensorOp,
    2, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
        ElementA, layout::TensorNCxHWx<InterleavedK>,
        ThreadMapA0
      >
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
        ElementB, layout::TensorCxRSKx<InterleavedK>,
        ThreadMapB0
      >
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 4;  // For interleaved layout
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
        cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
        ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
        ElementB, layout::TensorCxRSKx<InterleavedK>,
        ThreadMapB1
      >
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
    WarpShape0, InstructionShape,
    ElementAccumulator,
    typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
    SmemAccumulatorLayout
  >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
    cutlass::epilogue::warp::TileIteratorTensorOp<
      WarpShape0,
      InstructionShape,
      ElementC,
      SmemAccumulatorLayout
    >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
    MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
    ElementA, SmemAccumulatorLayout,
    MatrixShape<InstructionShape::kM, InstructionShape::kK>,
    WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    IteratorB0,
    SmemIteratorB0,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    ElementC,
    LayoutC,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
/// and 2 stage pipeline.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    2, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
        ElementA, LayoutA,
        ThreadMapA0
      >
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
        ElementB, LayoutB,
        ThreadMapB0
      >
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;  // vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
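To make the template plumbing above concrete, here is a minimal instantiation sketch of the 2-stage, optimized-iterator specialization. Every concrete choice below — element types, NHWC layouts, tile shapes, the ReLU epilogue, the identity swizzle — is an illustrative assumption, not a value fixed by this diff:

// Hypothetical instantiation following the parameter order of the
// specializations above; all concrete types are example choices.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;

using B2bFpropKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,   // B: filters
    cutlass::half_t, cutlass::layout::TensorNHWC,   // C/D: output
    float,                                          // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 32>,           // threadblock tile, conv 0
    cutlass::gemm::GemmShape<64, 64, 32>,           // threadblock tile, conv 1
    cutlass::gemm::GemmShape<32, 32, 32>,           // warp tile, conv 0
    cutlass::gemm::GemmShape<32, 32, 32>,           // warp tile, conv 1
    cutlass::gemm::GemmShape<16, 8, 8>,             // Sm75 tensor op shape
    EpilogueOp, EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                              // Stages = 2 selects this pipelined path
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    true                                            // stage accumulator in shared memory
>::Kernel;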

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
/// pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
      ElementB, layout::RowMajorInterleaved<InterleavedK>,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
      2, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
      ElementB, layout::RowMajorInterleaved<InterleavedK>,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
      2, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
        ElementA, layout::TensorNCxHWx<InterleavedK>,
        ThreadMapA0
      >
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
        ElementB, layout::TensorCxRSKx<InterleavedK>,
        ThreadMapB0
      >
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4; //For interleaved layout
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
          cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
          ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
        cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
        ElementB, layout::TensorCxRSKx<InterleavedK>,
        ThreadMapB1
      >
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;


  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
        WarpShape0,
        InstructionShape,
        ElementC,
        SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
        MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
        ElementA, SmemAccumulatorLayout,
        MatrixShape<InstructionShape::kM, InstructionShape::kK>,
        WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    IteratorB0,
    SmemIteratorB0,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    ElementC,
    LayoutC,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
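The interleaved variants above pair NC/xHWx activations with CxRSKx filters, as in CUTLASS's integer tensor-core convolutions, with InterleavedK deduced from the layouts. A hypothetical int8 instantiation — again, every concrete value is an assumption for illustration only:

// Hypothetical int8 interleaved instantiation (NC/32HW32 activations,
// C/32RSK32 filters); nothing below is fixed by this diff.
using EpilogueOpI8 = cutlass::epilogue::thread::LinearCombinationRelu<
    int8_t, 8, int32_t, float>;

using B2bInterleavedKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    int8_t, cutlass::layout::TensorNCxHWx<32>,   // InterleavedK = 32 deduced here
    int8_t, cutlass::layout::TensorCxRSKx<32>,
    int8_t, cutlass::layout::TensorNCxHWx<32>,
    int32_t,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<32, 32, 64>,
    cutlass::gemm::GemmShape<32, 32, 64>,
    cutlass::gemm::GemmShape<8, 8, 16>,          // Sm75 int8 tensor op shape
    EpilogueOpI8, EpilogueOpI8,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2, cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    true
>::Kernel;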

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -0,0 +1,804 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"

#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/b2b_implicit_gemm_convolution.h"
#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, LayoutA,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, LayoutB,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
          cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
          ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, LayoutB,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
        WarpShape0,
        InstructionShape,
        ElementC,
        SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
        MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
        ElementA, SmemAccumulatorLayout,
        MatrixShape<InstructionShape::kM, InstructionShape::kK>,
        WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
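Unlike the earlier specializations that fix the stage count to the literal 2 and build the pipelined mainloop, the specializations in this file leave `Stages` as a free template parameter and build the `B2bImplicitGemmMultistageSmemAccumulator` mainloop with per-operand cache-operation hints. A sketch of selecting a 3-stage Sm80 kernel; the types and shapes here are illustrative assumptions:

// Hypothetical Sm80 multistage selection; a stage count other than 2
// resolves to this multistage specialization.
using EpilogueOpF16 = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;

using MultistageFpropKernel = cutlass::conv::kernel::DefaultB2bConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,
    cutlass::half_t, cutlass::layout::TensorNHWC,
    cutlass::half_t, cutlass::layout::TensorNHWC,
    float,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<32, 32, 32>,
    cutlass::gemm::GemmShape<32, 32, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,   // Sm80 fp16 tensor op shape
    EpilogueOpF16, EpilogueOpF16,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                     // Stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    true
>::Kernel;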

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
      ElementB, layout::RowMajorInterleaved<InterleavedK>,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
      Stages, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
      ElementB, layout::RowMajorInterleaved<InterleavedK>,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
      Stages, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, layout::TensorNCxHWx<InterleavedK>,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
          cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
          ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;


  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
        WarpShape0,
        InstructionShape,
        ElementC,
        SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
        MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
        ElementA, SmemAccumulatorLayout,
        MatrixShape<InstructionShape::kM, InstructionShape::kK>,
        WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};
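Across all of these specializations, `IteratorAccumulatorScaleBias` and `EpilogueOutputOp0` apply a per-channel scale and bias plus activation to the first convolution's accumulator before it is consumed, from shared memory, as the A operand of the second convolution. As a reference for what the fused kernel computes, here is a scalar sketch for a GEMM-shaped (1x1-convolution) problem; the ReLU choice and row-major shapes are illustrative assumptions, not the kernel's implementation:

#include <algorithm>
#include <vector>

// Scalar reference for the fused back-to-back semantics: GEMM0's accumulator
// is scaled/biased per output channel, activated, then fed to GEMM1 without
// a round trip through global memory.
void b2b_reference(int M, int N0, int K0, int N1,
                   const float* A,       // M x K0
                   const float* B0,      // K0 x N0
                   const float* scale0,  // N0, per-channel scale
                   const float* bias0,   // N0, per-channel bias
                   const float* B1,      // N0 x N1
                   float* D1) {          // M x N1
  std::vector<float> D0(static_cast<size_t>(M) * N0);  // stands in for the smem-staged accumulator
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N0; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K0; ++k) acc += A[m * K0 + k] * B0[k * N0 + n];
      D0[m * N0 + n] = std::max(0.f, scale0[n] * acc + bias0[n]);  // EpilogueOutputOp0
    }
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N1; ++n) {
      float acc = 0.f;
      for (int k = 0; k < N0; ++k) acc += D0[m * N0 + k] * B1[k * N1 + n];
      D1[m * N1 + n] = acc;  // EpilogueOutputOp1 (identity in this sketch)
    }
}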
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultB2bConv2dFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, LayoutA,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, LayoutB,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
          cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
          ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, LayoutB,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
        WarpShape0,
        InstructionShape,
        ElementC,
        SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
        MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
        ElementA, SmemAccumulatorLayout,
        MatrixShape<InstructionShape::kM, InstructionShape::kK>,
        WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
template <
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  int InterleavedK
>
struct DefaultB2bConv2dFprop <
  ElementA,
  layout::TensorNCxHWx<InterleavedK>,
  ElementB,
  layout::TensorCxRSKx<InterleavedK>,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape0,
  ThreadblockShape1,
  WarpShape0,
  WarpShape1,
  InstructionShape,
  EpilogueOutputOp0,
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized,
  true
> {

  // Define the core components from GEMM
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
      ElementB, layout::RowMajorInterleaved<InterleavedK>,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
      Stages, MathOperatorTag, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
      ElementB, layout::RowMajorInterleaved<InterleavedK>,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
      Stages, MathOperatorTag, true>;

  // Define iterators over tiles from the A operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
  using IteratorA0 =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
      ElementA, layout::TensorNCxHWx<InterleavedK>,
      ThreadMapA0
    >;

  using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;

  // Define iterators over tiles from the B operand
  // Note GEMM shared memory threadmap is used here because conv global memory
  // layout needs to be mapped to fprop which is similar to the crosswise
  // layout which is used by the interleaved GEMM shared memory threadmap.
  // The Interleaved GEMM global memory layout is similar to the congruous
  // layout.
  using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
  using IteratorB0 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB0
    >;

  using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
    cutlass::transform::threadblock::VectorIterator<
      cutlass::transform::threadblock::PredicatedVectorAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
          cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
          ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;

  using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
  using IteratorB1 =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
      ElementB, layout::TensorCxRSKx<InterleavedK>,
      ThreadMapB1
    >;

  using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;


  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;


  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
        WarpShape0,
        InstructionShape,
        ElementC,
        SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
        MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
        ElementA, SmemAccumulatorLayout,
        MatrixShape<InstructionShape::kM, InstructionShape::kK>,
        WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the Mma
  using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
    ThreadblockShape0,
    IteratorA0,
    SmemIteratorA0,
    arch::CacheOperation::Always,
    IteratorB0,
    SmemIteratorB0,
    arch::CacheOperation::Global,
    IteratorAccumulatorScaleBias,
    FragmentIteratorAccumulator,
    SmemIteratorD0,
    ThreadblockShape1,
    WarpIteratorA1,
    IteratorB1,
    SmemIteratorB1,
    arch::CacheOperation::Global,
    EpilogueOutputOp0,
    MmaPolicy0,
    MmaPolicy1,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
    ThreadblockShape1,
    WarpMmaTensorOp1,
    1,
    EpilogueOutputOp1,
    EpilogueOutputOp1::kCount,
    InterleavedK
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
    B2bMma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -1,24 +1,30 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
@ -111,7 +117,9 @@ template <
  /// If true, kernel is configured to support serial reduction in the epilogue
  bool SplitKSerial,
  /// Operation performed by GEMM
  typename Operator
  typename Operator,
  /// Stage accumulator in shared memory
  bool SmemAccumulator = false
>
struct DefaultB2bGemm;
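The hunk above widens `DefaultB2bGemm`'s parameter list with a trailing `SmemAccumulator` flag that defaults to false, so every existing instantiation keeps compiling unchanged while new code opts into the shared-memory-staged kernels. The pattern, reduced to a self-contained sketch — the `DefaultGemmLike` name is a hypothetical stand-in, not CUTLASS code:

#include <type_traits>

// A defaulted trailing template parameter preserves source compatibility
// for existing users while routing new code to new partial specializations.
template <typename Operator, bool SmemAccumulator = false>
struct DefaultGemmLike;

template <typename Operator>
struct DefaultGemmLike<Operator, false> {};  // register-file accumulator path

template <typename Operator>
struct DefaultGemmLike<Operator, true> {};   // shared-memory staged accumulator path

using Old = DefaultGemmLike<float>;          // unchanged code: SmemAccumulator = false
using New = DefaultGemmLike<float, true>;    // explicitly selects smem staging

static_assert(std::is_same<Old, DefaultGemmLike<float, false>>::value,
              "the default preserves the old behavior");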


@ -0,0 +1,397 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
    the appropriate threadblock-scoped epilogue.

    Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
    accommodated by exchanging A and B operands and assuming transposed layouts. Partial
    specializations here choose 'device::GemmTransposed' to implement this functionality.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"

#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"

#include "kernel/b2b_gemm.h"
#include "threadblock/default_b2b_mma.h"
#include "threadblock/default_b2b_mma_smem_accumulator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp0,
    /// Epilogue output operator
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                      layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                      arch::Sm80, ThreadblockShape0, ThreadblockShape1,
                      WarpShape0, WarpShape1, InstructionShape,
                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
                      Operator, true> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;

  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          EpilogueOutputOp1::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
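A hypothetical use of the Sm80 specialization above; the alignment values, tile shapes, stage count, and epilogues are example choices only, not values fixed by this diff:

using EpilogueOp0 = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;
using EpilogueOp1 = cutlass::epilogue::thread::LinearCombination<
    cutlass::half_t, 8, float, float>;

using Sm80B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
    cutlass::half_t, cutlass::layout::RowMajor, 8,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,  // B
    cutlass::half_t, cutlass::layout::RowMajor,        // C/D
    float,                                             // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 64, 32>,              // threadblock tile, GEMM 0
    cutlass::gemm::GemmShape<64, 64, 32>,              // threadblock tile, GEMM 1
    cutlass::gemm::GemmShape<32, 32, 32>,              // warp tile, GEMM 0
    cutlass::gemm::GemmShape<32, 32, 32>,              // warp tile, GEMM 1
    cutlass::gemm::GemmShape<16, 8, 16>,               // Sm80 fp16 tensor op shape
    EpilogueOp0, EpilogueOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,       // Stages
    false,   // SplitKSerial
    cutlass::arch::OpMultiplyAdd,
    true     // SmemAccumulator: selects this partial specialization
>::B2bGemmKernel;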
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementC, layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
2,
|
||||
Operator,
|
||||
EpilogueOutputOp0,
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
typename B2bMma::Operator1,
|
||||
kPartitionsK1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size of the 1st GEMM (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size of the 2nd GEMM (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size of the 1st GEMM (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size of the 2nd GEMM (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator of the 1st GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator of the 2nd GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Number of interleaved K
    int InterleavedK,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<
    ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
    ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
    ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
    arch::OpClassTensorOp, arch::Sm80,
    ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
    InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
    ThreadblockSwizzle, Stages,
    SplitKSerial, Operator, true> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, Stages, Operator, EpilogueOutputOp0,
      true, true>::ThreadblockB2bMma;

  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Turing Integer Tensor Core Interleaved layout
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size of the 1st GEMM (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size of the 2nd GEMM (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size of the 1st GEMM (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size of the 2nd GEMM (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator of the 1st GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator of the 2nd GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of interleaved K
    int InterleavedK,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
                      kAlignmentA, ElementB,
                      layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
                      ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
                      int32_t, arch::OpClassTensorOp, arch::Sm75,
                      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
                      InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
                      ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;

  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue for the 2nd GEMM
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
@ -0,0 +1,275 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
   \brief Defines device-side elementwise operations on TensorView. Note, the operations defined
   in this header are not specialized for any particular data layout and are therefore not
   intended to offer the best possible performance. Rather, they are intended to be generic
   reference implementations to support the CUTLASS unit tests.
*/

#pragma once

// CUTLASS includes
#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"

#include "cutlass/gemm/gemm.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace reference {
namespace device {

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace kernel {

template <
  typename TensorRefIn,                   ///< Input TensorRef type
  typename TensorRefOut,                  ///< Output TensorRef type
  typename ScalarType,                    ///< alpha type
  typename TensorRefScalar,               ///< Scale/Bias TensorRef type
  typename OutputTile,                    ///< Per-thread output tile (concept: MatrixShape)
  typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>
>
__global__ void TensorScaleBiasGemm(
  gemm::GemmCoord problem_size,
  TensorRefIn tensor_in,                  ///< input tensor
  TensorRefOut tensor_out,                ///< output tensor
  ScalarType alpha,                       ///< alpha
  TensorRefScalar tensor_scale,           ///< scale tensor
  TensorRefScalar tensor_bias             ///< bias tensor
) {

  ConvertOp convert_op;

  MatrixCoord output_coord(
    MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
    MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
  );

  // Update the output tensor: out = scale[column] * in + bias[column]
  for (int j = 0; j < OutputTile::kRow; ++j) {
    for (int i = 0; i < OutputTile::kColumn; ++i) {
      // MatrixCoord is (row, column): j indexes rows, i indexes columns.
      // (The original passed (i, j), which only covers the tile when it is square.)
      MatrixCoord coord = output_coord + MatrixCoord(j, i);
      if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {

        ScalarType scale = alpha;
        if (tensor_scale.good())
          scale = tensor_scale.at({0, coord.column()});

        ScalarType bias = ScalarType(0);
        if (tensor_bias.good())
          bias = tensor_bias.at({0, coord.column()});

        tensor_out.at(coord) = convert_op(
          scale * ScalarType(tensor_in.at(coord)) + bias);
      }
    }
  }
}

template <
  typename TensorRefIn,                   ///< Input TensorRef type
  typename TensorRefOut,                  ///< Output TensorRef type
  typename ScalarType,                    ///< alpha type
  typename TensorRefScalar,               ///< Scale/Bias TensorRef type
  typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
  int kThreadM = 4,      // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,      // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 16,   // shape of a threadblock in units of threads
  int kCtaShapeN = 8     // shape of a threadblock in units of threads
>
__global__ void TensorScaleBiasConv2d(
  conv::Conv2dProblemSize problem_size,
  TensorRefIn tensor_in,                  ///< input tensor
  TensorRefOut tensor_out,                ///< output tensor
  ScalarType alpha,                       ///< alpha
  TensorRefScalar tensor_scale,           ///< scale tensor
  TensorRefScalar tensor_bias             ///< bias tensor
) {

  ConvertOp convert_op;

  int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_n[kThreadM];
  int thread_p[kThreadM];
  int thread_q[kThreadM];

  // Compute N, P, Q coordinates for each row of a thread's tile
  int64_t PQ = int64_t(problem_size.P) * problem_size.Q;

  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    int64_t npq = npq_start + m;

    thread_n[m] = int(npq / PQ);

    int64_t residual = npq % PQ;
    thread_p[m] = int(residual / problem_size.Q);
    thread_q[m] = int(residual % problem_size.Q);
  }

  // Write out the results
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {
        int thread_k = k_start + n;
        if (thread_k < problem_size.K) {

          ScalarType scale = alpha;
          if (tensor_scale.good())
            scale = tensor_scale.at({0, thread_k});

          ScalarType bias = ScalarType(0);
          if (tensor_bias.good())
            bias = tensor_bias.at({0, thread_k});

          tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
            scale * ScalarType(
              tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
            ) + bias);
        }
      }
    }
  }
}

} // namespace kernel

/// Apply scale and bias on a tensor
template <
  typename ElementIn,                     ///< Input type
  typename ElementOut,                    ///< Output type
  typename Layout,                        ///< Layout of input/output tensor
  typename ScalarType,                    ///< alpha type
  typename LayoutScaleBias,               ///< Layout of scale and bias
  typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemm(
  gemm::GemmCoord problem_size,
  TensorRef<ElementIn, Layout> tensor_in,                    ///< input tensor
  TensorRef<ElementOut, Layout> tensor_out,                  ///< output tensor
  ScalarType alpha,                                          ///< alpha
  TensorRef<ScalarType, LayoutScaleBias> tensor_scale,       ///< scale tensor
  TensorRef<ScalarType, LayoutScaleBias> tensor_bias         ///< bias tensor
) {

  using OutputTile = MatrixShape<4, 4>;

  dim3 block(16, 8);

  dim3 grid(
    (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
    (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
  );

  kernel::TensorScaleBiasGemm<
    TensorRef<ElementIn, Layout>,
    TensorRef<ElementOut, Layout>,
    ScalarType,
    TensorRef<ScalarType, LayoutScaleBias>,
    OutputTile,
    ConvertOp
  ><<< grid, block >>>(
    problem_size,
    tensor_in,
    tensor_out,
    alpha,
    tensor_scale,
    tensor_bias
  );
}
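
// Editor's note: a minimal usage sketch, assuming cutlass::HostTensor from
// cutlass/util/host_tensor.h and float data; the tensor names are hypothetical.
//
//   cutlass::gemm::GemmCoord problem(128, 256, 64);
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> in({128, 256});
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> out({128, 256});
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> scale({1, 256});
//   cutlass::HostTensor<float, cutlass::layout::RowMajor> bias({1, 256});
//   // ... fill host data, then in.sync_device(); scale.sync_device(); bias.sync_device();
//   cutlass::reference::device::TensorScaleBiasGemm(
//       problem, in.device_ref(), out.device_ref(), 1.0f,
//       scale.device_ref(), bias.device_ref());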

/// Apply scale and bias on a tensor
template <
  typename ElementIn,                     ///< Input type
  typename ElementOut,                    ///< Output type
  typename Layout,                        ///< Layout of input/output tensor
  typename ScalarType,                    ///< alpha type
  typename LayoutScaleBias,               ///< Layout of scale and bias
  typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasConv2d(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementIn, Layout> tensor_in,                    ///< input tensor
  TensorRef<ElementOut, Layout> tensor_out,                  ///< output tensor
  ScalarType alpha,                                          ///< alpha
  TensorRef<ScalarType, LayoutScaleBias> tensor_scale,       ///< scale tensor
  TensorRef<ScalarType, LayoutScaleBias> tensor_bias         ///< bias tensor
) {

  int const kThreadM = 4;     // shape of a thread's tile in the GEMM M dimension
  int const kThreadN = 4;     // shape of a thread's tile in the GEMM N dimension
  int const kCtaShapeM = 16;  // shape of a threadblock in units of threads
  int const kCtaShapeN = 8;   // shape of a threadblock in units of threads

  int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
  int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);

  dim3 block(kCtaShapeM, kCtaShapeN);
  dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));

  kernel::TensorScaleBiasConv2d<
    TensorRef<ElementIn, Layout>,
    TensorRef<ElementOut, Layout>,
    ScalarType,
    TensorRef<ScalarType, LayoutScaleBias>,
    ConvertOp,
    kThreadM,
    kThreadN,
    kCtaShapeM,
    kCtaShapeN
  ><<< grid, block >>>(
    problem_size,
    tensor_in,
    tensor_out,
    alpha,
    tensor_scale,
    tensor_bias
  );
}
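
// Editor's note: a worked example of the launch arithmetic above, assuming a
// hypothetical problem with N = 2, P = 56, Q = 56, K = 64:
//   npq      = 2 * 56 * 56 = 6272
//   blocks_m = ceil(6272 / (16 * 4)) = ceil(6272 / 64) = 98
//   grid     = (98, ceil(64 / (8 * 4))) = (98, 2), block = (16, 8)
// so each thread covers a 4x4 tile of the (N*P*Q) x K output view.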

///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace reference
} // namespace cutlass
95
examples/13_two_tensor_op_fusion/test_run.h
Normal file
@ -0,0 +1,95 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the header reproduced earlier in this diff.)
 **************************************************************************************************/

#include <iostream>
#include <string>   // added: std::string is used below
#include <vector>   // added: std::vector is used below

#include <cuda_runtime.h>

// Run tests on GPUs

int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {

  bool supported = false;

  int arch_major = arch / 10;
  int arch_minor = arch - arch / 10 * 10;

  if (arch_major >= 8) {
    // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
    //
    // CUTLASS must be compiled with a CUDA 11 Toolkit to run these examples.
    if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
      supported = true;
    }
  }
  else if (arch_major >= 7) {
    // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
    //
    // CUTLASS must be compiled with a CUDA 10.2 Toolkit to run these examples.
    if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
      supported = true;
    }
  }

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (!(props.major == arch_major && props.minor == arch_minor)) {
    supported = false;
  }

  if (!supported) {
    // Return zero so this test passes on older Toolkits and other architectures; it is a no-op.
    std::cout << "This example isn't supported on the current architecture" << std::endl;
    return 0;
  }

  bool pass = true;

  std::cout << "Device: " << props.name << std::endl;
  std::cout << "Arch: SM" << arch << std::endl;
  std::cout << "Test: " << test_name << std::endl;
  for (auto func : test_funcs) {
    pass &= func();
  }

  return pass ? 0 : -1;
}
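
// Editor's note: a minimal usage sketch for testRun(); the test function name
// below is a hypothetical placeholder.
//
//   bool run_fused_gemm_f16();   // defined elsewhere in the example
//
//   int main() {
//     std::vector<bool (*)()> funcs = { &run_fused_gemm_f16 };
//     return testRun(80, funcs, "fused back-to-back GEMM");  // expects an SM80 device
//   }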
@ -1,24 +1,30 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
  *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright notice, this list of
- *       conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright notice, this list of
- *       conditions and the following disclaimer in the documentation and/or other materials
- *       provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- *       to endorse or promote products derived from this software without specific prior written
- *       permission.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
 *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
@ -693,7 +699,7 @@ public:
     // Mainloop
     //
 
-    CUTLASS_GEMM_LOOP
+    CUTLASS_PRAGMA_UNROLL
     for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
         gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
       //
@ -739,7 +745,6 @@ public:
       this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
       ++this->warp_tile_iterator_B1_;
 
-
       if (warp_mma_k > 0)
         warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
                             warp_transformed_frag_B1[warp_mma_k % 2],
@ -0,0 +1,816 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (BSD-3-Clause license text identical to the header reproduced earlier in this diff.)
 **************************************************************************************************/
/*! \file
    \brief Template for a multistage threadblock-scoped fused back-to-back Implicit GEMM
    Convolution kernel, with the intermediate accumulator tile staged in shared memory.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "threadblock/b2b_mma_base_smem_accumulator.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute a threadblock-scoped fused back-to-back implicit GEMM,
/// staging the first GEMM's accumulator tile through shared memory.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape0_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA0_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA0_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA0,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB0_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB0_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB0,
    /// Iterates over vectors of the scale and bias vectors in global memory
    //  (concept: VectorIterator)
    typename IteratorAccumulatorScaleBias_,
    /// Iterates over accumulator tile
    typename FragmentIteratorAccumulator_,
    /// Iterates over accumulator tile in shared memory
    typename SmemIteratorD0_,
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape1_,
    /// Iterates over the intermediate accumulator tile
    //  (concept: MmaTensorOpFragmentIterator)
    typename WarpIteratorA1_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB1_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB1_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB1,
    /// Output operator for the 1st GEMM (concept: epilogue::thread::LinearCombinationClamp, etc.)
    typename OutputOp_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy0_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy1_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class B2bImplicitGemmMultistageSmemAccumulator :
  public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages> {
public:
  ///< Base class
  using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape0 = Shape0_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA0 = IteratorA0_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB0 = IteratorB0_;
  ///< Iterates over tiles of the scale and bias vectors in global memory
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
  ///< Policy describing tuning details
  using Policy0 = Policy0_;

  using SmemIteratorA0 = SmemIteratorA0_;
  using SmemIteratorB0 = SmemIteratorB0_;
  using SmemIteratorD0 = SmemIteratorD0_;    ///< Iterates over accumulator tile in shared memory

  using FragmentIteratorAccumulator = FragmentIteratorAccumulator_;    ///< Iterates over accumulator tile

  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape1 = Shape1_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB1 = IteratorB1_;
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;    ///< Iterates over the intermediate accumulator tile in shared memory

  ///< Epilogue after the 1st GEMM
  using OutputOp = OutputOp_;

  static const bool PerChannelScale = (OutputOp::kScale ==
      epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);

  static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;

  //
  // Dependent types
  //

  using ElementC = typename Policy0::Operator::ElementC;

  /// Fragment of accumulator tile
  using FragmentC0 = typename Policy0::Operator::FragmentC;

  /// Warp-level Mma
  using Operator0 = typename Policy0::Operator;

  /// Fragment of Scale and Bias loaded from global memory
  using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;

  /// Warp-level Mma
  using Operator1 = typename Policy1::Operator;

  /// Epilogue in shared memory
  using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
      SmemIteratorD0,                  ///< SmemTileIterator
      FragmentIteratorAccumulator,     ///< AccumulatorFragmentIterator
      IteratorAccumulatorScaleBias,    ///< ScaleBiasIterator
      OutputOp>;                       ///< Output operator

  /// Internal structure exposed for introspection.
  struct Detail {

    static_assert(Base::kWarpGemmIterations0 > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");
    static_assert(Base::kWarpGemmIterations1 > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA0 =
        IteratorA0::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B0
    static int const AsyncCopyIterationsPerStageB0 =
        IteratorB0::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B1
    static int const AsyncCopyIterationsPerStageB1 =
        IteratorB1::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand A
    static int const kAccessesPerGroupA0 =
        (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

    /// Number of cp.async instructions to load one group of operand B0
    static int const kAccessesPerGroupB0 =
        (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

    /// Number of cp.async instructions to load one group of operand B1
    static int const kAccessesPerGroupB1 =
        (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
  };
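
  // Editor's note: a worked example of the group arithmetic above, assuming a
  // hypothetical stage that needs AsyncCopyIterationsPerStageA0 = 8 cp.async
  // instructions and kWarpGemmIterations0 = 4 warp-level MMAs:
  //   kAccessesPerGroupA0 = (8 + 4 - 1) / 4 = 2
  // so each warp-level MMA iteration overlaps with 2 cp.async issues, spreading
  // a stage's global->shared traffic evenly across the mainloop.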

private:

  using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
  using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
  using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
  using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
  using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
  using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
  using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
  using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;

private:

  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA0 smem_iterator_A0_;

  /// Iterator to write threadblock-scoped tile of B0 operand to shared memory
  SmemIteratorB0 smem_iterator_B0_;

  /// Shared memory iterator to store the accumulator tile of the 1st GEMM
  SmemIteratorD0 smem_iterator_D0_;

  /// Iterator to load a warp-scoped tile of A1 operand from the intermediate accumulator tile
  WarpIteratorA1 warp_tile_iterator_A1_;

  /// Iterator to write threadblock-scoped tile of B1 operand to shared memory
  SmemIteratorB1 smem_iterator_B1_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  B2bImplicitGemmMultistageSmemAccumulator(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::B2bMmaSharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
    smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
    smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
    smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
  {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
    int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);

    int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
    int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;

    int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
    int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);

    int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
    int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A0_.add_tile_offset(
        {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0});
    this->warp_tile_iterator_B0_.add_tile_offset(
        {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0});
    warp_tile_iterator_A1_.add_tile_offset(
        {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
    this->warp_tile_iterator_B1_.add_tile_offset(
        {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});

    // Add smem accumulator iterator warp offset
    smem_iterator_D0_.add_tile_offset(
        {warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
         warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
  }
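
  // Editor's note: a worked example of the warp mapping above, assuming a
  // hypothetical WarpCount0 of 2x2x1 (kM = 2, kN = 2) and warp_idx = 3:
  //   warp_idx_mn_0 = 3 % (2 * 2) = 3     warp_idx_k_0 = 3 / (2 * 2) = 0
  //   warp_idx_m_0  = 3 % 2 = 1           warp_idx_n_0 = 3 / 2 = 1
  // i.e. warp 3 owns the (m, n, k) = (1, 1, 0) warp tile of the first GEMM.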

  CUTLASS_DEVICE
  void copy_tiles_and_advance_0(
      IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
      int group_start_A0 = 0, int group_start_B0 = 0) {

    iterator_A0.set_iteration_index(group_start_A0);
    this->smem_iterator_A0_.set_iteration_index(group_start_A0);

    // Async Copy for operand A
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {

      if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) {
        typename IteratorA0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorA0::AccessType *>(
                this->smem_iterator_A0_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
                              IteratorA0::ThreadMap::kElementsPerAccess / 8;

        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
            dst_ptr, iterator_A0.get(), iterator_A0.valid());

        ++iterator_A0;

        ++this->smem_iterator_A0_;
      }
    }

    iterator_B0.set_iteration_index(group_start_B0);

    this->smem_iterator_B0_.set_iteration_index(group_start_B0);

    // Async Copy for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
      if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) {
        typename IteratorB0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB0::AccessType *>(
                this->smem_iterator_B0_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
                              IteratorB0::ThreadMap::kElementsPerAccess / 8;

        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
            dst_ptr, iterator_B0.get(), iterator_B0.valid());

        ++iterator_B0;
        ++this->smem_iterator_B0_;
      }
    }
  }
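
  // Editor's note: a worked example of the kSrcBytes arithmetic above, assuming
  // a hypothetical half_t operand (16 bits) with 8 elements per access:
  //   kSrcBytes = 16 * 8 / 8 = 16 bytes
  // which maps to the 16-byte (128-bit) cp.async variant; cp_async_zfill fills
  // the shared-memory destination with zeros when the source predicate is false.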

  CUTLASS_DEVICE
  void copy_tiles_and_advance_1(
      IteratorB1 &iterator_B1,
      int group_start_B1 = 0) {

    iterator_B1.set_iteration_index(group_start_B1);

    this->smem_iterator_B1_.set_iteration_index(group_start_B1);

    // Async Copy for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
      if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) {
        typename IteratorB1::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB1::AccessType *>(
                this->smem_iterator_B1_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
                              IteratorB1::ThreadMap::kElementsPerAccess / 8;

        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
            dst_ptr, iterator_B1.get(), iterator_B1.valid());

        ++iterator_B1;
        ++this->smem_iterator_B1_;
      }
    }
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations_0,
      ///< destination accumulator tile
      FragmentC1 &accum,
      ///< iterator over A0 operand in global memory
      IteratorA0 iterator_A0,
      ///< iterator over B0 operand in global memory
      IteratorB0 iterator_B0,
      ///< iterator over A1 operand scale vector in global memory
      IteratorAccumulatorScaleBias iterator_accum0_scale,
      ///< iterator over A1 operand bias vector in global memory
      IteratorAccumulatorScaleBias iterator_accum0_bias,
      ///< iterator over B1 operand in global memory
      IteratorB1 iterator_B1,
      ///< initial value of accumulator
      FragmentC0 const &src_accum,
      ///< epilogue operation after the 1st GEMM
      OutputOp output_op_0,
      ///< Imaginary strides used for planar-complex only - ignored here
      int64_t imag_stride_A = 0,
      int64_t imag_stride_B = 0) {

    //
    // Prologue
    //

    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < Base::kStages - 1;
         ++stage, --gemm_k_iterations_0) {

      iterator_A0.set_iteration_index(0);
      this->smem_iterator_A0_.set_iteration_index(0);

      // Async Copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) {
        typename IteratorA0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorA0::AccessType *>(
                this->smem_iterator_A0_.get());

        int const kSrcBytes =
            sizeof_bits<typename IteratorA0::Element>::value *
            IteratorA0::ThreadMap::kElementsPerAccess / 8;

        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
            dst_ptr, iterator_A0.get(), iterator_A0.valid());

        ++iterator_A0;
        ++this->smem_iterator_A0_;
      }

      iterator_B0.set_iteration_index(0);
      this->smem_iterator_B0_.set_iteration_index(0);

      // Async Copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) {
        typename IteratorB0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB0::AccessType *>(
                this->smem_iterator_B0_.get());

        int const kSrcBytes =
            sizeof_bits<typename IteratorB0::Element>::value *
            IteratorB0::ThreadMap::kElementsPerAccess / 8;

        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
            dst_ptr, iterator_B0.get(), iterator_B0.valid());

        ++iterator_B0;
        ++this->smem_iterator_B0_;
      }

      // Move to the next stage
      iterator_A0.advance();
      iterator_B0.advance();

      this->smem_iterator_A0_.add_tile_offset({0, 1});
      this->smem_iterator_B0_.add_tile_offset({1, 0});

      // Inserts a fence to group cp.async instructions into stages.
      cutlass::arch::cp_async_fence();
    }
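
    // Editor's note: a worked example of the prologue above, assuming a
    // hypothetical kStages = 3: the loop issues 2 complete stages of cp.async
    // traffic before any math, so the mainloop always has one stage in flight
    // while it computes on another, and gemm_k_iterations_0 has already been
    // decremented by 2 when the mainloop begins.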

    // Perform accumulation in the 'd' output operand
    FragmentC0 accum0 = src_accum;

    // Waits until kStages-2 stages have committed.
    cutlass::arch::cp_async_wait<Base::kStages - 2>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
    WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
    WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
    WarpTransformedFragmentB0 warp_transformed_frag_B0[2];

    Operator0 warp_mma0;

    this->warp_tile_iterator_A0_.set_kgroup_index(0);
    this->warp_tile_iterator_B0_.set_kgroup_index(0);

    this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
    this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);

    ++this->warp_tile_iterator_A0_;
    ++this->warp_tile_iterator_B0_;

    // Start issuing the first group of the next stage outside of the mainloop
    copy_tiles_and_advance_0(iterator_A0, iterator_B0);

    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

    warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
                        warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);

    //
    // Mainloop
    //

    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {

      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
           ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.

        this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
        this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);

        this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A0_;
        ++this->warp_tile_iterator_B0_;

        if (warp_mma_k > 0)
          warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
                              warp_transformed_frag_B0[warp_mma_k % 2],
                              warp_loaded_frag_A0[warp_mma_k % 2],
                              warp_loaded_frag_B0[warp_mma_k % 2]);

        // Issue global->shared copies for the next stage
        int group_start_iteration_A0, group_start_iteration_B0;

        if (warp_mma_k + 1 == Base::kWarpGemmIterations0) {
          group_start_iteration_A0 = 0;
          group_start_iteration_B0 = 0;
        } else {
          group_start_iteration_A0 =
              (warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
          group_start_iteration_B0 =
              (warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
        }

        copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
                                 group_start_iteration_B0);

        warp_mma0(
          accum0,
          warp_transformed_frag_A0[warp_mma_k % 2],
          warp_transformed_frag_B0[warp_mma_k % 2],
          accum0
        );

        if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
          warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
                              warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
                              warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
                              warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);

        if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
          // Inserts a fence to group cp.async instructions into stages.
          cutlass::arch::cp_async_fence();

          // Waits until kStages-2 stages of cp.async have committed
          arch::cp_async_wait<Base::kStages - 2>();
          __syncthreads();

          // Move to the next stage
          iterator_A0.advance();
          iterator_B0.advance();

          this->smem_iterator_A0_.add_tile_offset({0, 1});
          this->smem_iterator_B0_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A0_.add_tile_offset(
                {0, -Base::kStages * Policy0::kPartitionsK *
                        Base::kWarpGemmIterations0});
            this->warp_tile_iterator_B0_.add_tile_offset(
                {-Base::kStages * Policy0::kPartitionsK *
                     Base::kWarpGemmIterations0,
                 0});
            smem_read_stage_idx = 0;
          } else {
            ++smem_read_stage_idx;
          }

          --gemm_k_iterations_0;
        }
      }

    }
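
    // Editor's note: the loop bound above drains the software pipeline. With a
    // hypothetical kStages = 3, iterations continue while
    // gemm_k_iterations_0 > -2, i.e. two extra trips beyond the last real
    // k-tile, consuming the stages that the prologue and mainloop prefetched.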

    // Insert fence and wait for all outstanding cp.async operations to commit.
    cutlass::arch::cp_async_fence();
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();

    /// Epilogue for the first Implicit GEMM
    Epilogue0 epilogue0;

    epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);

    __syncthreads();

    // 2nd Implicit GEMM

    //
    // Prologue
    //
    int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;

    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < Base::kStages - 1;
         ++stage, --gemm_k_iterations_1) {

      iterator_B1.set_iteration_index(0);
      this->smem_iterator_B1_.set_iteration_index(0);

      // Async Copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) {
        typename IteratorB1::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB1::AccessType *>(
                this->smem_iterator_B1_.get());

        int const kSrcBytes =
            sizeof_bits<typename IteratorB1::Element>::value *
            IteratorB1::ThreadMap::kElementsPerAccess / 8;

        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
            dst_ptr, iterator_B1.get(), iterator_B1.valid());

        ++iterator_B1;
        ++this->smem_iterator_B1_;
      }

      // Move to the next stage
      iterator_B1.advance();

      this->smem_iterator_B1_.add_tile_offset({1, 0});

      // Inserts a fence to group cp.async instructions into stages.
      cutlass::arch::cp_async_fence();
    }

    // Waits until kStages-2 stages have committed.
    cutlass::arch::cp_async_wait<Base::kStages - 2>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
    WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
    WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
    WarpTransformedFragmentB1 warp_transformed_frag_B1[2];

    Operator1 warp_mma1;

    warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
    ++warp_tile_iterator_A1_;

    this->warp_tile_iterator_B1_.set_kgroup_index(0);
    this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
    ++this->warp_tile_iterator_B1_;

    // Start issuing the first group of the next stage outside of the mainloop
    copy_tiles_and_advance_1(iterator_B1);

    smem_write_stage_idx = Base::kStages - 1;
    smem_read_stage_idx = 0;

    warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
                        warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_B1;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
|
||||
group_start_iteration_B1 = 0;
|
||||
} else {
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1,
|
||||
group_start_iteration_B1);
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
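// Note on the mainloop above: it follows the standard CUTLASS multistage
// software-pipelining pattern. A sketch of the synchronization schedule, with
// kStages circular shared-memory buffers (illustrative, not additional code):
//
//   cutlass::arch::cp_async_fence();              // close the group of cp.async issued for one stage
//   cutlass::arch::cp_async_wait<kStages - 2>();  // block until at most kStages-2 groups remain in
//                                                 // flight, i.e. the oldest stage has landed in smem
//   __syncthreads();                              // make that stage visible to every warp
//
// The smem_write_stage_idx / smem_read_stage_idx counters walk the circular
// buffer: stepping past stage kStages-1 wraps back to stage 0 by applying a
// negative tile offset of kStages, as in the code above.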
@ -1,24 +1,30 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
@ -67,8 +73,7 @@ template <
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB0_,
    /// Iterates over vectors of scale and bias vector in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    //  (concept: VectorIterator)
    typename IteratorAccumulatorScaleBias_,
    /// Iterates over accumulator tile
    typename FragmentIteratorAccumulator_,
@ -94,19 +99,19 @@ template <
    typename Policy0_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy1_,
    /// Transformation applied to A operand
    /// Transformation applied to A0 operand
    typename TransformA0_ = NumericArrayConverter<
        typename SmemIteratorA0_::Element,
        typename IteratorA0_::Element,
        IteratorA0_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    /// Transformation applied to B0 operand
    typename TransformB0_ = NumericArrayConverter<
        typename SmemIteratorB0_::Element,
        typename IteratorB0_::Element,
        IteratorB0_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    /// Transformation applied to B1 operand
    typename TransformB1_ = NumericArrayConverter<
        typename SmemIteratorB1_::Element,
        typename IteratorB1_::Element,
@ -396,7 +401,6 @@ public:

    iterator_A.load(tb_frag_A);
    iterator_B0.load(tb_frag_B0);

    ++iterator_A;
    ++iterator_B0;
  }
@ -452,11 +456,8 @@ public:

    smem_write_stage_idx = 1;

    // int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
    int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
@ -477,6 +478,7 @@ public:

    if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {

      // Write fragments to shared memory
      this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));

      __syncthreads();
@ -489,7 +491,8 @@ public:
    }
    else {
      this->warp_tile_iterator_B1_.add_tile_offset(
          {-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1,
          {-Base::kStages * Policy1::kPartitionsK *
               Base::kWarpGemmIterations1,
           0});
    }

@ -501,7 +504,7 @@ public:

    // skip warp tile loading for the last kgroup
    if (gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1)
      warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
    warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
    this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);

    ++warp_tile_iterator_A1_;

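// Note on the hunk above: the second GEMM consumes the first GEMM's output as
// its A operand, so its K extent is the first GEMM's N extent, giving
// gemm_k_iterations_1 = Shape0::kN / Shape1::kK. For example (illustrative
// shapes), Shape0 = GemmShape<128, 64, 32> and Shape1 = GemmShape<128, 128, 32>
// yield 64 / 32 = 2 iterations over the second GEMM's K dimension.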
@ -64,11 +70,11 @@ template <
    /// Used for partial specialization
    typename Enable = bool>
class B2bMmaBaseSmemAccumulator :
  public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
  public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {

 public:
  ///< Base class
  using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;
  using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;

  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape0 = Shape0_;
@ -146,7 +152,6 @@ class B2bMmaBaseSmemAccumulator :
    AccumulatorSharedStorage0 accumulator_shared_storage0;
  };


 public:

  /// Construct from tensor references

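// Note on the hunk above: the base class previously hard-coded a two-stage
// (double-buffered) pipeline; it is now parameterized on the Stages template
// argument, so, e.g., a three-stage instantiation (illustrative value) derives
// from B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 3>.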
@ -76,6 +82,11 @@ template <
    /// Iterates over the intermediate accumulator tile
    //  (concept::MmaTensorOpFragmentIterator)
    typename FragmentIteratorA1_,
    /// Iterates over vectors of scale and bias vector in global memory
    //  (concept: VectorIterator)
    typename IteratorAccumulatorScaleBias_,
    /// WarpIterator to load Scale or Bias vector from threadblock fragment
    typename FragmentIteratorA1ScaleBias_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
@ -120,6 +131,10 @@ public:
  using Shape1 = Shape1_;
  ///< Iterates over intermediate accumulator tile
  using FragmentIteratorA1 = FragmentIteratorA1_;
  ///< Iterates over tiles of the scale and bias vectors in global memory
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
  ///< WarpIterator to load Scale or Bias vector from threadblock fragment
  using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB1 = IteratorB1_;
  ///< Policy describing tuning details
@ -134,6 +149,9 @@ public:

  ///< Epilogue after 1st Gemm
  using OutputOp = OutputOp_;

  static const bool PerChannelScale = (OutputOp::kScale ==
      epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);

  static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
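// Note on the hunk above: PerChannelScale selects, at compile time, whether the
// first GEMM's output operator applies one scale per output channel instead of
// a single scalar alpha. Conceptually (a sketch, not the library's code):
//
//   d[m][n] = PerChannelScale ? scale[n] * accum[m][n] + bias[n]
//                             : alpha    * accum[m][n] + bias[n];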
@ -148,6 +166,9 @@ public:

  /// Warp-level Mma
  using Operator0 = typename Policy0::Operator;

  /// Fragment of Scale and Bias loaded from global memory
  using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;
@ -211,6 +232,8 @@ public:
  using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
  /// Warp Fragment of operand A1 loaded from accumulator tile
  using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
  using WarpLoadedFragmentA1ScaleBias =
      typename FragmentIteratorA1ScaleBias::Fragment;
  using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
  using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
  using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
@ -375,11 +398,15 @@ public:
      int gemm_k_iterations_0,
      ///< destination accumulator tile
      FragmentC1 &accum,
      ///< iterator over A operand in global memory
      ///< iterator over A0 operand in global memory
      IteratorA0 iterator_A0,
      ///< iterator over B operand in global memory
      ///< iterator over B0 operand in global memory
      IteratorB0 iterator_B0,
      ///< iterator over B operand in global memory
      ///< iterator over A1 operand scale vector in global memory
      IteratorAccumulatorScaleBias iterator_A1_scale,
      ///< iterator over A1 operand bias vector in global memory
      IteratorAccumulatorScaleBias iterator_A1_bias,
      ///< iterator over B1 operand in global memory
      IteratorB1 iterator_B1,
      ///< initial value of accumulator
      FragmentC0 const &src_accum,
@ -617,6 +644,20 @@ public:

    /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
    FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
    FragmentA1ScaleBias tb_frag_A1_scale;
    FragmentA1ScaleBias tb_frag_A1_bias;
    FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
    FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);

    if (PerChannelScale) {
      tb_frag_A1_scale.clear();
      iterator_A1_scale.load(tb_frag_A1_scale);
      ++iterator_A1_scale;
    }
    tb_frag_A1_bias.clear();
    iterator_A1_bias.load(tb_frag_A1_bias);
    ++iterator_A1_bias;


    //
    // Prologue
@ -672,18 +713,29 @@ public:
    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
    WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2];
    WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2];
    WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
    WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
    WarpTransformedFragmentB1 warp_transformed_frag_B1[2];

    Operator1 warp_mma1;

    this->warp_tile_iterator_B1_.set_kgroup_index(0);

    warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0);
    this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
    if (PerChannelScale) {
      warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
      ++warp_tile_iterator_A1_scale_;
    }
    warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]);
    ++warp_tile_iterator_A1_bias_;

    warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0],
                                warp_loaded_frag_A1_scale[0],
                                warp_loaded_frag_A1_bias[0],
                                output_op_0);
    ++warp_tile_iterator_A1_;

    this->warp_tile_iterator_B1_.set_kgroup_index(0);
    this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
    ++this->warp_tile_iterator_B1_;

    iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
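// Note on the hunk above: the scale and bias fragments are double-buffered the
// same way as the operand fragments: buffer [(warp_mma_k + 1) % 2] is being
// loaded while buffer [warp_mma_k % 2] feeds the current warp-level MMA, so the
// scale/bias vector loads overlap the math instructions.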
@ -711,15 +763,37 @@ public:
    for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
         ++warp_mma_k) {

      // Load threadblock-level scale/bias vector from global memory
      if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
        if (PerChannelScale) {
          tb_frag_A1_scale.clear();
          iterator_A1_scale.load(tb_frag_A1_scale);
          ++iterator_A1_scale;
        }
        tb_frag_A1_bias.clear();
        iterator_A1_bias.load(tb_frag_A1_bias);
        ++iterator_A1_bias;
      }

      // Load warp-level scale bias fragment from threadblock scale/bias vector
      if (PerChannelScale) {
        warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
        ++warp_tile_iterator_A1_scale_;
      }
      warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]);
      ++warp_tile_iterator_A1_bias_;

      // Load warp-level tile from accumulator fragment
      warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
                                  warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2],
                                  warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2],
                                  output_op_0);
      ++warp_tile_iterator_A1_;

      // Load warp-level tiles from shared memory, wrapping to k offset if
      // this is the last group as the case may be.

      this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);

      warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
      this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);

      ++warp_tile_iterator_A1_;
      ++this->warp_tile_iterator_B1_;

      if (warp_mma_k > 0)

@ -0,0 +1,867 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "threadblock/b2b_mma_base_smem_accumulator.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape0_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA0_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA0_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA0,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB0_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB0_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB0,
    /// Iterates over vectors of scale and bias vector in global memory
    //  (concept: VectorIterator)
    typename IteratorAccumulatorScaleBias_,
    /// Iterates over accumulator tile
    typename FragmentIteratorAccumulator_,
    /// Iterates over accumulator tile in shared memory
    typename SmemIteratorD0_,
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape1_,
    /// Iterates over the intermediate accumulator tile in shared memory
    typename WarpIteratorA1_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB1_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB1_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB1,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
    typename OutputOp_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy0_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy1_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class B2bMmaMultistageSmemAccumulator :
  public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages> {
 public:
  ///< Base class
  using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape0 = Shape0_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA0 = IteratorA0_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB0 = IteratorB0_;
  ///< Iterates over tiles of the scale and bias vectors in global memory
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
  ///< Policy describing tuning details
  using Policy0 = Policy0_;

  using SmemIteratorA0 = SmemIteratorA0_;
  using SmemIteratorB0 = SmemIteratorB0_;
  using SmemIteratorD0 = SmemIteratorD0_;   ///< Iterates over accumulator tile in shared memory

  using FragmentIteratorAccumulator = FragmentIteratorAccumulator_;   ///< Iterates over accumulator tile

  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape1 = Shape1_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB1 = IteratorB1_;
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;   ///< Iterates over the intermediate accumulator tile in shared memory

  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;

  ///< Epilogue after 1st Gemm
  using OutputOp = OutputOp_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC0 = typename Policy0::Operator::FragmentC;

  /// Warp-level Mma
  using Operator0 = typename Policy0::Operator;

  /// Fragment of Scale and Bias loaded from global memory
  using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;

  /// Warp-level Mma
  using Operator1 = typename Policy1::Operator;

  /// Epilogue in shared memory
  using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
      SmemIteratorD0,                 ///< SmemTileIterator
      FragmentIteratorAccumulator,    ///< AccumulatorFragmentIterator
      IteratorAccumulatorScaleBias,   ///< ScaleBiasIterator
      OutputOp>;                      ///< Output operator

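// Note: unlike the register-file variant, this class stages the first GEMM's
// accumulators through shared memory: Epilogue0 above writes them out through
// SmemIteratorD0 (applying the scale/bias output operator once, at store time),
// and WarpIteratorA1 then reads the same shared-memory tile back as the A
// operand of the second GEMM.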
  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA0 = Operator0::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB0 = Operator0::kTransformB;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Internal structure exposed for introspection.
  struct Detail {

    static_assert(Base::kWarpGemmIterations0 > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");
    static_assert(Base::kWarpGemmIterations1 > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const TBLDGSTSIterationsA0 =
        IteratorA0::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B
    static int const TBLDGSTSIterationsB0 =
        IteratorB0::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B
    static int const TBLDGSTSIterationsB1 =
        IteratorB1::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand A
    static int const kAccessesPerGroupA0 =
        (TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB0 =
        (TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB1 =
        (TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
  };
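// Note: each kAccessesPerGroup constant above is the ceiling division
// ceil(TBLDGSTSIterations / kWarpGemmIterations), written with the usual
// (a + b - 1) / b idiom, so one stage's cp.async instructions are spread evenly
// across the warp-level MMA iterations. E.g. (illustrative values) 10 LDGSTS
// iterations over 4 warp MMA iterations give (10 + 4 - 1) / 4 = 3 per group.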

 private:

  using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
  using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
  using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
  using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
  using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
  using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
  using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
  using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;

 private:

  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA0 smem_iterator_A0_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB0 smem_iterator_B0_;

  /// Shared Memory Iterator to store accumulator tile
  SmemIteratorD0 smem_iterator_D0_;

  /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
  WarpIteratorA1 warp_tile_iterator_A1_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB1 smem_iterator_B1_;

 public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  B2bMmaMultistageSmemAccumulator(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::B2bMmaSharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx
    ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
    smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
    smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
    smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
  {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
    int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);

    int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
    int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;

    int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
    int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);

    int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
    int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A0_.add_tile_offset(
        {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0});
    this->warp_tile_iterator_B0_.add_tile_offset(
        {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0});
    warp_tile_iterator_A1_.add_tile_offset(
        {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
    this->warp_tile_iterator_B1_.add_tile_offset(
        {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});

    // Add smem accumulator iterator warp offset
    smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
                                        warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
  }
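// Note on the constructor above: the flat warp_idx is decomposed into (m, n, k)
// coordinates via
//   warp_idx_mn = warp_idx % (WarpCount::kM * WarpCount::kN)
//   warp_idx_k  = warp_idx / (WarpCount::kM * WarpCount::kN)
//   warp_idx_m  = warp_idx_mn % WarpCount::kM
//   warp_idx_n  = warp_idx_mn / WarpCount::kM
// E.g. (illustrative values) with WarpCount0 kM = 2, kN = 2, warp_idx 3 maps to
// (m, n, k) = (1, 1, 0).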

  CUTLASS_DEVICE
  void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
                                int group_start_A0 = 0, int group_start_B0 = 0) {
    iterator_A0.set_iteration_index(group_start_A0 *
                                    IteratorA0::kAccessesPerVector);
    this->smem_iterator_A0_.set_iteration_index(group_start_A0);

    // LDGSTS for operand A
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
      if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) {
        typename IteratorA0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorA0::AccessType *>(
                this->smem_iterator_A0_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
                              IteratorA0::ThreadMap::kElementsPerAccess /
                              IteratorA0::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_A0.get();

          cutlass::arch::cp_async<kSrcBytes, kCacheOpA0>(
              dst_ptr + v, gmem_ptr, iterator_A0.valid());

          ++iterator_A0;
        }

        ++this->smem_iterator_A0_;
      }
    }

    iterator_B0.set_iteration_index(group_start_B0 *
                                    IteratorB0::kAccessesPerVector);
    this->smem_iterator_B0_.set_iteration_index(group_start_B0);

    // LDGSTS for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
      if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) {
        typename IteratorB0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB0::AccessType *>(
                this->smem_iterator_B0_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
                              IteratorB0::ThreadMap::kElementsPerAccess /
                              IteratorB0::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B0.get();

          cutlass::arch::cp_async<kSrcBytes, kCacheOpB0>(
              dst_ptr + v, gmem_ptr, iterator_B0.valid());

          ++iterator_B0;
        }
        ++this->smem_iterator_B0_;
      }
    }
  }
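// Note: kSrcBytes above is the size in bytes of a single cp.async transfer,
// bits-per-element * elements-per-access / accesses-per-vector / 8. E.g.
// (illustrative values) a 16-bit element with 8 elements per access and one
// access per vector gives 16 * 8 / 1 / 8 = 16 bytes per cp.async instruction.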

  CUTLASS_DEVICE
  void copy_tiles_and_advance_1(IteratorB1 &iterator_B1,
                                int group_start_B1 = 0) {
    iterator_B1.set_iteration_index(group_start_B1 *
                                    IteratorB1::kAccessesPerVector);
    this->smem_iterator_B1_.set_iteration_index(group_start_B1);

    // LDGSTS for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
      if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
        typename IteratorB1::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB1::AccessType *>(
                this->smem_iterator_B1_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
                              IteratorB1::ThreadMap::kElementsPerAccess /
                              IteratorB1::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B1.get();

          cutlass::arch::cp_async<kSrcBytes, kCacheOpB1>(
              dst_ptr + v, gmem_ptr, iterator_B1.valid());

          ++iterator_B1;
        }
        ++this->smem_iterator_B1_;
      }
    }
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations_0,
      ///< destination accumulator tile
      FragmentC1 &accum,
      ///< iterator over A0 operand in global memory
      IteratorA0 iterator_A0,
      ///< iterator over B0 operand in global memory
      IteratorB0 iterator_B0,
      ///< iterator over A1 operand scale vector in global memory
      IteratorAccumulatorScaleBias iterator_accum0_scale,
      ///< iterator over A1 operand bias vector in global memory
      IteratorAccumulatorScaleBias iterator_accum0_bias,
      ///< iterator over B1 operand in global memory
      IteratorB1 iterator_B1,
      ///< initial value of accumulator
      FragmentC0 const &src_accum,
      ///< epilogue operation after 1st Gemm
      OutputOp output_op_0)
  {
    //
    // Prologue
    //

    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < Base::kStages - 1;
         ++stage, --gemm_k_iterations_0) {

      iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
      iterator_B0.clear_mask(gemm_k_iterations_0 == 0);

      iterator_A0.set_iteration_index(0);
      this->smem_iterator_A0_.set_iteration_index(0);

      // LDGSTS for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) {
        typename IteratorA0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorA0::AccessType *>(
                this->smem_iterator_A0_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorA0::Element>::value *
              IteratorA0::ThreadMap::kElementsPerAccess /
              IteratorA0::kAccessesPerVector / 8;

          // (src_bytes is unused here; cp_async_zfill below predicates on valid() directly.)
          int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
              dst_ptr + v, iterator_A0.get(), iterator_A0.valid());

          ++iterator_A0;
        }

        ++this->smem_iterator_A0_;
      }

      iterator_B0.set_iteration_index(0);
      this->smem_iterator_B0_.set_iteration_index(0);

      // LDGSTS for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) {
        typename IteratorB0::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB0::AccessType *>(
                this->smem_iterator_B0_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorB0::Element>::value *
              IteratorB0::ThreadMap::kElementsPerAccess /
              IteratorB0::kAccessesPerVector / 8;

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
              dst_ptr + v, iterator_B0.get(), iterator_B0.valid());

          ++iterator_B0;
        }

        ++this->smem_iterator_B0_;
      }

      // Move to the next stage
      iterator_A0.add_tile_offset({0, 1});
      iterator_B0.add_tile_offset({1, 0});

      this->smem_iterator_A0_.add_tile_offset({0, 1});
      this->smem_iterator_B0_.add_tile_offset({1, 0});

      // Defines the boundary of a stage of cp.async.
      cutlass::arch::cp_async_fence();
    }
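// Note on the prologue above: cp_async_zfill issues the copy unconditionally,
// but when the guard predicate (the iterator's valid() bit) is false it fills
// the shared-memory destination with zeros instead of reading global memory,
// so out-of-bounds accesses along the K dimension contribute nothing to the
// accumulators.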
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
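
    // The write index starts at the last stage because the prologue has
    // already filled stages [0, kStages - 2]; the read index trails from
    // stage 0. Both wrap by applying negative tile offsets below, which keeps
    // the shared-memory circular buffer addressing in range.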

    warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
                        warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);

    //
    // Mainloop
    //

    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
           ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.

        this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
        this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);

        this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A0_;
        ++this->warp_tile_iterator_B0_;

        if (warp_mma_k > 0)
          warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
                              warp_transformed_frag_B0[warp_mma_k % 2],
                              warp_loaded_frag_A0[warp_mma_k % 2],
                              warp_loaded_frag_B0[warp_mma_k % 2]);

        warp_mma0(
          accum0,
          warp_transformed_frag_A0[warp_mma_k % 2],
          warp_transformed_frag_B0[warp_mma_k % 2],
          accum0
        );
        // Issue global->shared copies for this stage
        if (warp_mma_k < Base::kWarpGemmIterations0 - 1) {
          int group_start_iteration_A0, group_start_iteration_B0;

          group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0;
          group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0;

          copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
                                   group_start_iteration_B0);
        }

        if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
          int group_start_iteration_A0, group_start_iteration_B0;
          group_start_iteration_A0 =
              (warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
          group_start_iteration_B0 =
              (warp_mma_k + 1) * Detail::kAccessesPerGroupB0;

          copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
                                   group_start_iteration_B0);

          // Inserts a memory fence between stages of cp.async instructions.
          cutlass::arch::cp_async_fence();

          // Waits until kStages-2 stages have committed.
          arch::cp_async_wait<Base::kStages - 2>();
          __syncthreads();

          // Move to the next stage
          iterator_A0.add_tile_offset({0, 1});
          iterator_B0.add_tile_offset({1, 0});

          this->smem_iterator_A0_.add_tile_offset({0, 1});
          this->smem_iterator_B0_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A0_.add_tile_offset(
                {0, -Base::kStages * Policy0::kPartitionsK *
                        Base::kWarpGemmIterations0});
            this->warp_tile_iterator_B0_.add_tile_offset(
                {-Base::kStages * Policy0::kPartitionsK *
                     Base::kWarpGemmIterations0,
                 0});
            smem_read_stage_idx = 0;
          } else {
            ++smem_read_stage_idx;
          }

          --gemm_k_iterations_0;
          iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
          iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
        }

        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
          warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
                              warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
                              warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
                              warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
      }

    }

    /// Epilogue for the first Implicit Gemm
    Epilogue0 epilogue0;

    epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
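
    // The shared-memory epilogue folds output_op_0 (with its scale and bias
    // vectors) into the accumulators and stores the result through
    // smem_iterator_D0_, so the second GEMM reads its A operand directly from
    // shared memory instead of a global-memory round trip.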

    __syncthreads();

    // 2nd Gemm

    //
    // Prologue
    //
    int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
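
    // Fusion invariant: the first GEMM's N extent is consumed as the second
    // GEMM's K dimension, so Shape0::kN must be a whole multiple of
    // Shape1::kK. For example (illustrative values), Shape0::kN == 128 with
    // Shape1::kK == 32 yields gemm_k_iterations_1 == 4.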

    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < Base::kStages - 1;
         ++stage, --gemm_k_iterations_1) {

      iterator_B1.clear_mask(gemm_k_iterations_1 == 0);

      iterator_B1.set_iteration_index(0);
      this->smem_iterator_B1_.set_iteration_index(0);

      // LDGSTS for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
        typename IteratorB1::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorB1::AccessType *>(
                this->smem_iterator_B1_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorB1::Element>::value *
              IteratorB1::ThreadMap::kElementsPerAccess /
              IteratorB1::kAccessesPerVector / 8;

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
              dst_ptr + v, iterator_B1.get(), iterator_B1.valid());

          ++iterator_B1;
        }

        ++this->smem_iterator_B1_;
      }

      // Move to the next stage
      iterator_B1.add_tile_offset({1, 0});

      this->smem_iterator_B1_.add_tile_offset({1, 0});

      // Defines the boundary of a stage of cp.async.
      cutlass::arch::cp_async_fence();
    }

    // DEPBAR+SYNC
    cutlass::arch::cp_async_wait<Base::kStages - 2>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
    WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
    WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
    WarpTransformedFragmentB1 warp_transformed_frag_B1[2];

    Operator1 warp_mma1;

    warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
    ++warp_tile_iterator_A1_;

    this->warp_tile_iterator_B1_.set_kgroup_index(0);
    this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
    ++this->warp_tile_iterator_B1_;

    iterator_B1.clear_mask(gemm_k_iterations_1 == 0);

    smem_write_stage_idx = Base::kStages - 1;
    smem_read_stage_idx = 0;

    warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
                        warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);

    //
    // Mainloop
    //

    CUTLASS_PRAGMA_UNROLL
    for (gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1);
         gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
           ++warp_mma_k) {

        // Load warp-level tile from accumulator fragment
        // skip warp tile loading for the last kgroup
        if (gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) {
          warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
        }
        ++warp_tile_iterator_A1_;

        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.
        this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
        this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
        ++this->warp_tile_iterator_B1_;

        if (warp_mma_k > 0)
          warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
                              warp_transformed_frag_B1[warp_mma_k % 2],
                              warp_loaded_frag_A1[warp_mma_k % 2],
                              warp_loaded_frag_B1[warp_mma_k % 2]);

        warp_mma1(
          accum,
          warp_transformed_frag_A1[warp_mma_k % 2],
          warp_transformed_frag_B1[warp_mma_k % 2],
          accum
        );
        // Issue global->shared copies for this stage
        if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
          int group_start_iteration_B1;

          group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;

          copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
        }

        if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
          int group_start_iteration_B1;
          group_start_iteration_B1 =
              (warp_mma_k + 1) * Detail::kAccessesPerGroupB1;

          copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);

          // Inserts a memory fence between stages of cp.async instructions.
          cutlass::arch::cp_async_fence();

          // Waits until kStages-2 stages have committed.
          arch::cp_async_wait<Base::kStages - 2>();
          __syncthreads();

          // Move to the next stage
          iterator_B1.add_tile_offset({1, 0});

          this->smem_iterator_B1_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_B1_.add_tile_offset(
                {-Base::kStages * Policy1::kPartitionsK *
                     Base::kWarpGemmIterations1,
                 0});
            smem_read_stage_idx = 0;
          } else {
            ++smem_read_stage_idx;
          }

          iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
        }

        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
          warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
                              warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
                              warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
                              warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
      }

    }

  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
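
// ---------------------------------------------------------------------------
// A minimal host-side sketch (not CUTLASS code; all values are illustrative)
// of the loop accounting used by the multistage mainloops above: kStages - 1
// k-tiles are consumed by the prologue, and the mainloop counter then runs
// down past zero so that every k-tile is visited exactly once.
// ---------------------------------------------------------------------------

#include <cassert>

int main() {
  constexpr int kShape0N = 128;  // N extent of the first GEMM's tile (assumed)
  constexpr int kShape1K = 32;   // K extent of the second GEMM's tile (assumed)
  constexpr int kStages  = 3;    // cp.async pipeline depth (assumed)

  int gemm_k_iterations_1 = kShape0N / kShape1K;  // 4, as computed above

  // Mirror the kernel's mainloop bounds: start after the prologue's
  // kStages - 1 stages, exit once the counter drops to -(kStages - 2).
  int mainloop_trips = 0;
  for (int it = gemm_k_iterations_1 - (kStages - 1); it > (-kStages + 1); --it)
    ++mainloop_trips;

  assert(mainloop_trips == gemm_k_iterations_1);  // every k-tile visited once
  return 0;
}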
@@ -1,24 +1,30 @@
 (BSD-3-Clause license header updated: copyright 2017-2021 NVIDIA CORPORATION
  becomes 2017 - 2022 NVIDIA CORPORATION & AFFILIATES, with an SPDX identifier
  added; the terms are otherwise the standard BSD 3-clause text)
@@ -70,6 +76,11 @@ template <
     /// Iterates over the intermediate accumulator tile
     //  (concept::MmaTensorOpFragmentIterator)
     typename FragmentIteratorA1_,
+    /// Iterates over vectors of scale and bias vector in global memory
+    //  (concept: VectorIterator)
+    typename IteratorAccumulatorScaleBias_,
+    /// FragmentIterator to load Scale or Bias vector from threadblock fragment
+    typename FragmentIteratorA1ScaleBias_,
     /// Iterates over tiles of B operand in global memory
     //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
     typename IteratorB1_,
@@ -106,7 +117,8 @@ template <
     /// Used for partial specialization
     typename Enable = bool
 >
-class B2bMmaPipelined : public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
+class B2bMmaPipelined :
+  public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
 public:
 
   ///< Base class
@@ -122,6 +134,9 @@ public:
 
   using Shape1 = Shape1_;                           ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   using FragmentIteratorA1 = FragmentIteratorA1_;   ///< Iterates over intermediate accumulator tile
+  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;   ///< Iterates over tiles of the scale and bias vectors in global memory
+  using FragmentIteratorA1ScaleBias =
+      FragmentIteratorA1ScaleBias_;                 ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
   using IteratorB1 = IteratorB1_;                   ///< Iterates over tiles of B operand in global memory
   using Policy1 = Policy1_;                         ///< Policy describing tuning details
 
@@ -130,9 +145,12 @@ public:
 
   using ElementC = ElementC_;                       ///< Data type of accumulator matrix
   using LayoutC = LayoutC_;                         ///< Layout of accumulator matrix
 
   using OutputOp = OutputOp_;                       ///< Epilogue after 1st Gemm
 
+  static const bool PerChannelScale = (OutputOp::kScale ==
+      epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
+
   using TransformA0 = TransformA0_;
   using TransformB0 = TransformB0_;
   using TransformB1 = TransformB1_;
@@ -152,6 +170,9 @@ public:
 
   /// Warp-level Mma
   using Operator0 = typename Policy0::Operator;
 
+  /// Fragment of Scale and Bias loaded from global memory
+  using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
+
   /// Fragment of operand B loaded from global memory
   using FragmentB1 = typename IteratorB1::Fragment;
@@ -161,7 +182,7 @@ public:
 
   /// Warp-level Mma
   using Operator1 = typename Policy1::Operator;
 
   /// Obtain the arch tag from the warp-level operator
   using ArchTag = typename Policy0::Operator::ArchTag;
 
@@ -174,7 +195,7 @@ public:
   /// Complex transform on B1 operand
   static ComplexTransform const kTransformB1 = Operator1::kTransformB;
 
-  // Statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
+  /// Statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
   static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
 
 private:
@@ -183,6 +204,9 @@ private:
   using WarpFragmentB0 = typename Operator0::FragmentB;
   /// Warp Fragment of operand A1 loaded from accumulator tile
   using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
+  /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment
+  using WarpFragmentA1ScaleBias =
+      typename FragmentIteratorA1ScaleBias::Fragment;
   using WarpFragmentB1 = typename Operator1::FragmentB;
 
 protected:
@@ -237,16 +261,18 @@ public:
   /// Perform a threadblock-scoped matrix multiply-accumulate
   CUTLASS_DEVICE
   void operator()(
     int gemm_k_iterations_0,                            ///< number of iterations of the mainloop
     FragmentC1 &accum,                                  ///< destination accumulator tile
     IteratorA0 iterator_A,                              ///< iterator over A operand in global memory
     IteratorB0 iterator_B0,                             ///< iterator over B0 operand in global memory
+    IteratorAccumulatorScaleBias iterator_A1_scale,     ///< iterator over A1 operand scale vectors in global memory
+    IteratorAccumulatorScaleBias iterator_A1_bias,      ///< iterator over A1 operand bias vectors in global memory
     IteratorB1 iterator_B1,                             ///< iterator over B1 operand in global memory
     FragmentC0 const &src_accum,                        ///< source accumulator tile
     OutputOp output_op_0,                               ///< epilogue operation after 1st Gemm
     TransformA0 transform_A0 = TransformA0(),           ///< transformation applied to A0 fragment
     TransformB0 transform_B0 = TransformB0(),           ///< transformation applied to B0 fragment
     TransformB1 transform_B1 = TransformB1()) {         ///< transformation applied to B1 fragment
 
     //
     // Prologue
@@ -268,8 +294,8 @@ public:
     ++iterator_A;
     ++iterator_B0;
 
-    this->smem_iterator_A_.store(tb_frag_A);
-    this->smem_iterator_B0_.store(tb_frag_B0);
+    this->smem_iterator_A_.store(transform_A0(tb_frag_A));
+    this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
 
     ++this->smem_iterator_A_;
     ++this->smem_iterator_B0_;
@@ -294,23 +320,19 @@ public:
     int smem_write_stage_idx = 1;
 
     // Avoid reading out of bounds
-    if (gemm_k_iterations_0 <= 1) {
-      iterator_A.clear_mask();
-      iterator_B0.clear_mask();
-    }
+    iterator_A.clear_mask(gemm_k_iterations_0 <= 1);
+    iterator_B0.clear_mask(gemm_k_iterations_0 <= 1);
 
     // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
     // shared memory loads (which have the tightest latency requirement).
     iterator_A.load(tb_frag_A);
 
     //
     // Mainloop
     //
 
-    // Note: The main loop does not support Base::WarpGemmIterations == 2.
+    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
     CUTLASS_GEMM_LOOP
     for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
 
       //
       // Loop over GEMM K dimension
       //
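
// Note on the change above: clear_mask(bool) folds the old divergent if-guard
// into a predicated call, so the out-of-bounds masking behavior is preserved
// without an extra branch in the pipelined prologue and mainloop.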
@@ -324,19 +346,14 @@ public:
       if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
 
         // Write fragments to shared memory
-        this->smem_iterator_A_.store(tb_frag_A);
+        this->smem_iterator_A_.store(transform_A0(tb_frag_A));
 
-        this->smem_iterator_B0_.store(tb_frag_B0);
+        this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
 
         __syncthreads();
 
-        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
-        // shared memory loads (which have the tightest latency requirement).
-        iterator_A.load(tb_frag_A);
-
-        ++this->smem_iterator_B0_;
         ++this->smem_iterator_A_;
+        ++this->smem_iterator_B0_;
 
         // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
         if (smem_write_stage_idx == 1) {
@@ -365,19 +382,18 @@ public:
       if (warp_mma_k == 0) {
 
         iterator_A.load(tb_frag_A);
         iterator_B0.load(tb_frag_B0);
 
         ++iterator_A;
         ++iterator_B0;
 
         // Avoid reading out of bounds if this was the last loop iteration
-        if (gemm_k_iterations_0 <= 2) {
-          iterator_A.clear_mask();
-          iterator_B0.clear_mask();
-        }
+        iterator_A.clear_mask(gemm_k_iterations_0 <= 2);
+        iterator_B0.clear_mask(gemm_k_iterations_0 <= 2);
       }
 
-      warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
+      warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
+                warp_frag_B0[warp_mma_k % 2], accum0);
     }
   }
@@ -390,32 +406,53 @@ public:
     // Prologue
     //
 
+    FragmentA1ScaleBias tb_frag_A1_scale;
+    FragmentA1ScaleBias tb_frag_A1_bias;
+    FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
+    FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
     FragmentB1 tb_frag_B1;
 
+    if(PerChannelScale)
+      tb_frag_A1_scale.clear();
+    tb_frag_A1_bias.clear();
     tb_frag_B1.clear();
 
     // The last kblock is loaded in the prolog
+    if(PerChannelScale)
+      iterator_A1_scale.load(tb_frag_A1_scale);
+    iterator_A1_bias.load(tb_frag_A1_bias);
     iterator_B1.load(tb_frag_B1);
 
+    if(PerChannelScale)
+      ++iterator_A1_scale;
+    ++iterator_A1_bias;
     ++iterator_B1;
 
-    this->smem_iterator_B1_.store(tb_frag_B1);
+    this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
 
     ++this->smem_iterator_B1_;
 
     __syncthreads();
 
     // Pair of fragments used to overlap shared memory loads and math instructions
+    WarpFragmentA1ScaleBias warp_frag_A1_scale[2];
+    WarpFragmentA1ScaleBias warp_frag_A1_bias[2];
     WarpFragmentA1 warp_frag_A1[2];
     WarpFragmentB1 warp_frag_B1[2];
 
     //warp_tile_iterator_A1_.set_kgroup_index(0);
     this->warp_tile_iterator_B1_.set_kgroup_index(0);
 
-    warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
+    if(PerChannelScale)
+      warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]);
+    warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]);
+    warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0],
+                                warp_frag_A1_bias[0], output_op_0);
     this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
 
     ++warp_tile_iterator_A1_;
+    if(PerChannelScale)
+      ++warp_tile_iterator_A1_scale_;
+    ++warp_tile_iterator_A1_bias_;
     ++this->warp_tile_iterator_B1_;
 
     Operator1 warp_mma1;
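
// Note: PerChannelScale (derived above from OutputOp::kScale) gates only the
// scale-vector traffic; the bias vector is loaded and advanced
// unconditionally, so scalar-alpha configurations skip just the redundant
// scale loads.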
@@ -425,9 +462,7 @@ public:
     int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
 
     // Avoid reading out of bounds
-    if (gemm_k_iterations_1 <= 1) {
-      iterator_B1.clear_mask();
-    }
+    iterator_B1.clear_mask(gemm_k_iterations_1 <= 1);
 
     //
     // Mainloop
@@ -450,8 +485,7 @@ public:
       if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
 
         // Write fragments to shared memory
-
-        this->smem_iterator_B1_.store(tb_frag_B1);
+        this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
 
         __syncthreads();
         ++this->smem_iterator_B1_;
@@ -468,14 +502,31 @@ public:
         }
 
         smem_write_stage_idx ^= 1;
+
+        if(PerChannelScale) {
+          tb_frag_A1_scale.clear();
+          iterator_A1_scale.load(tb_frag_A1_scale);
+          ++iterator_A1_scale;
+        }
+        tb_frag_A1_bias.clear();
+        iterator_A1_bias.load(tb_frag_A1_bias);
+        ++iterator_A1_bias;
       }
 
       this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
 
-      warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
-
+      if(PerChannelScale)
+        warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]);
+      warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]);
+      warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2],
+                                  warp_frag_A1_scale[(warp_mma_k + 1) % 2],
+                                  warp_frag_A1_bias[(warp_mma_k + 1) % 2],
+                                  output_op_0);
       this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
 
+      if(PerChannelScale)
+        ++warp_tile_iterator_A1_scale_;
+      ++warp_tile_iterator_A1_bias_;
       ++warp_tile_iterator_A1_;
       ++this->warp_tile_iterator_B1_;
@@ -484,17 +535,14 @@ public:
        iterator_B1.load(tb_frag_B1);
        ++iterator_B1;
 
        // Avoid reading out of bounds if this was the last loop iteration
-       if (gemm_k_iterations_1 <= 2) {
-         iterator_B1.clear_mask();
-       }
+       iterator_B1.clear_mask(gemm_k_iterations_1 <= 2);
      }
 
-     warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
+     warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
+               warp_frag_B1[warp_mma_k % 2], accum);
    }
  }
 
 }
};
@@ -0,0 +1,543 @@
(standard CUTLASS BSD-3-Clause license header: Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES, SPDX-License-Identifier: BSD-3-Clause)
/*! \file
    \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "threadblock/b2b_mma_base_smem_accumulator.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape0_,
  /// Iterates over tiles of A operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorA0_,
  /// Iterates over tiles of A operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorA0_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB0_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB0_,
  /// Iterates over vectors of scale and bias vector in global memory
  //  (concept: VectorIterator)
  typename IteratorAccumulatorScaleBias_,
  /// Iterates over accumulator tile
  typename FragmentIteratorAccumulator_,
  /// Iterates over accumulator tile in shared memory
  typename SmemIteratorD0_,
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape1_,
  /// Iterates over the intermediate accumulator tile in shared memory
  typename WarpIteratorA1_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB1_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB1_,
  /// Data type of accumulator matrix
  typename ElementC_,
  /// Layout of accumulator matrix
  typename LayoutC_,
  /// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
  typename OutputOp_,
  /// Policy describing tuning details (concept: MmaPipelinedPolicy)
  typename Policy0_,
  /// Policy describing tuning details (concept: MmaPipelinedPolicy)
  typename Policy1_,
  /// Transformation applied to A0 operand
  typename TransformA0_ = NumericArrayConverter<
    typename SmemIteratorA0_::Element,
    typename IteratorA0_::Element,
    IteratorA0_::Fragment::kElements>,
  ///
  /// Transformation applied to B0 operand
  typename TransformB0_ = NumericArrayConverter<
    typename SmemIteratorB0_::Element,
    typename IteratorB0_::Element,
    IteratorB0_::Fragment::kElements>,
  ///
  /// Transformation applied to B1 operand
  typename TransformB1_ = NumericArrayConverter<
    typename SmemIteratorB1_::Element,
    typename IteratorB1_::Element,
    IteratorB1_::Fragment::kElements>,
  /// Used for partial specialization
  typename Enable = bool
>
class B2bMmaPipelinedSmemAccumulator :
  public B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2> {
public:

  ///< Base class
  using Base = B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2>;

  using Shape0 = Shape0_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA0 = IteratorA0_;     ///< Iterates over tiles of A operand in global memory
  using IteratorB0 = IteratorB0_;     ///< Iterates over tiles of B operand in global memory
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;   ///< Iterates over tiles of the scale and bias vectors in global memory
  using Policy0 = Policy0_;           ///< Policy0 describing tuning details

  using SmemIteratorA0 = SmemIteratorA0_;
  using SmemIteratorB0 = SmemIteratorB0_;
  using SmemIteratorD0 = SmemIteratorD0_;   ///< Iterates over accumulator tile in shared memory

  using FragmentIteratorAccumulator = FragmentIteratorAccumulator_;     ///< Iterates over accumulator tile

  using Shape1 = Shape1_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorB1 = IteratorB1_;     ///< Iterates over tiles of B operand in global memory
  using Policy1 = Policy1_;           ///< Policy1 describing tuning details

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;   ///< Iterates over the intermediate accumulator tile in shared memory

  using ElementC = ElementC_;         ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;           ///< Layout of accumulator matrix

  using OutputOp = OutputOp_;         ///< Epilogue after 1st Gemm

  using TransformA0 = TransformA0_;
  using TransformB0 = TransformB0_;
  using TransformB1 = TransformB1_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA0 = typename IteratorA0::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB0 = typename IteratorB0::Fragment;

  /// Fragment of accumulator tile
  using FragmentC0 = typename Policy0::Operator::FragmentC;

  /// Warp-level Mma
  using Operator0 = typename Policy0::Operator;

  /// Fragment of operand B loaded from global memory
  using FragmentB1 = typename IteratorB1::Fragment;

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;

  /// Warp-level Mma
  using Operator1 = typename Policy1::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy0::Operator::ArchTag;

  /// Complex transform on A0 operand
  static ComplexTransform const kTransformA0 = Operator0::kTransformA;

  /// Complex transform on B0 operand
  static ComplexTransform const kTransformB0 = Operator0::kTransformB;

  /// Complex transform on B1 operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

  /// Epilog in shared memory
  using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
    SmemIteratorD0,                 ///< SmemTileIterator
    FragmentIteratorAccumulator,    ///< AccumulatorFragmentIterator
    IteratorAccumulatorScaleBias,   ///< ScaleBiasIterator
    OutputOp>;                      ///< Output operator

private:

  using WarpFragmentA0 = typename Operator0::FragmentA;
  using WarpFragmentB0 = typename Operator0::FragmentB;
  using WarpFragmentA1 = typename Operator1::FragmentA;
  using WarpFragmentB1 = typename Operator1::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA0 smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B0 operand to shared memory
  SmemIteratorB0 smem_iterator_B0_;

  /// Shared Memory Iterator to store accumulator tile
  SmemIteratorD0 smem_iterator_D0_;

  /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
  WarpIteratorA1 warp_tile_iterator_A1_;

  /// Iterator to write threadblock-scoped tile of B1 operand to shared memory
  SmemIteratorB1 smem_iterator_B1_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  B2bMmaPipelinedSmemAccumulator(
    typename Base::B2bMmaSharedStorage &shared_storage,   ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                                       ///< ID within the threadblock
    int warp_idx,                                         ///< ID of warp
    int lane_idx                                          ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
    smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
    smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
    warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
    smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
    int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);

    int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
    int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;

    int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0;

    int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
    int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);

    int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
    int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;

    int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0});
    this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0});
    warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1});
    this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1});

    // Add smem accumulator iterator warp offset
    smem_iterator_D0_.add_tile_offset({warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
                                       warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
  }
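
  // Worked example (illustrative values): with WarpCount0 of 2x2 in M x N,
  // warp_idx == 5 gives warp_idx_mn_0 == 1 and warp_idx_k_0 == 1, so the warp
  // sits at (m, n) == (1, 0) within k-partition 1, and tile_offset_k_0 then
  // skips one group of kWarpGemmIterations0 warp tiles along K.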

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    int gemm_k_iterations_0,                            ///< number of iterations of the mainloop
    FragmentC1 &accum,                                  ///< destination accumulator tile
    IteratorA0 iterator_A,                              ///< iterator over A operand in global memory
    IteratorB0 iterator_B0,                             ///< iterator over B0 operand in global memory
    IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
    IteratorAccumulatorScaleBias iterator_accum0_bias,  ///< iterator over D0 bias vector in global memory
    IteratorB1 iterator_B1,                             ///< iterator over B1 operand in global memory
    FragmentC0 const &src_accum,                        ///< source accumulator tile
    OutputOp output_op_0,                               ///< epilogue operation after 1st Gemm
    TransformA0 transform_A0 = TransformA0(),           ///< transformation applied to A0 fragment
    TransformB0 transform_B0 = TransformB0(),           ///< transformation applied to B0 fragment
    TransformB1 transform_B1 = TransformB1()) {         ///< transformation applied to B1 fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    FragmentC0 accum0 = src_accum;

    FragmentA0 tb_frag_A;
    FragmentB0 tb_frag_B0;

    tb_frag_A.clear();
    tb_frag_B0.clear();

    // The last kblock is loaded in the prolog
    iterator_A.load(tb_frag_A);
    iterator_B0.load(tb_frag_B0);

    ++iterator_A;
    ++iterator_B0;

    this->smem_iterator_A_.store(transform_A0(tb_frag_A));
    this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B0_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA0 warp_frag_A0[2];
    WarpFragmentB0 warp_frag_B0[2];

    this->warp_tile_iterator_A0_.set_kgroup_index(0);
    this->warp_tile_iterator_B0_.set_kgroup_index(0);

    this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
    this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);

    ++this->warp_tile_iterator_A0_;
    ++this->warp_tile_iterator_B0_;

    Operator0 warp_mma0;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    iterator_A.clear_mask(gemm_k_iterations_0 <= 1);
    iterator_B0.clear_mask(gemm_k_iterations_0 <= 1);

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {

          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A0(tb_frag_A));

          this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));

          __syncthreads();

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B0_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_A0_.add_tile_offset(
                {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
            this->warp_tile_iterator_B0_.add_tile_offset(
                {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
                 0});
          }

          smem_write_stage_idx ^= 1;
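
          // With kStages fixed at 2 by the static_assert above, the write
          // stage index only alternates between 0 and 1, so an XOR replaces
          // the wrap-around arithmetic of the multistage (kStages > 2)
          // variant.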
        }

        this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
        this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);

        this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A0_;
        ++this->warp_tile_iterator_B0_;

        if (warp_mma_k == 0) {

          iterator_A.load(tb_frag_A);
          iterator_B0.load(tb_frag_B0);
          ++iterator_A;
          ++iterator_B0;

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_A.clear_mask(gemm_k_iterations_0 <= 2);
          iterator_B0.clear_mask(gemm_k_iterations_0 <= 2);
        }

        warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
                  warp_frag_B0[warp_mma_k % 2], accum0);
      }
    }

    /// Epilogue for the first Implicit Gemm
    Epilogue0 epilogue0;

    epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);

    __syncthreads();

    // 2nd Gemm

    //
    // Prologue
    //

    FragmentB1 tb_frag_B1;

    tb_frag_B1.clear();

    // The last kblock is loaded in the prolog
    iterator_B1.load(tb_frag_B1);

    ++iterator_B1;

    this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));

    ++this->smem_iterator_B1_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA1 warp_frag_A1[2];
    WarpFragmentB1 warp_frag_B1[2];

    this->warp_tile_iterator_B1_.set_kgroup_index(0);

    warp_tile_iterator_A1_.load(warp_frag_A1[0]);
    this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);

    ++warp_tile_iterator_A1_;
    ++this->warp_tile_iterator_B1_;

    Operator1 warp_mma1;

    smem_write_stage_idx = 1;

    int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
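
    // As in the multistage variant, the second GEMM consumes the first GEMM's
    // N extent as its K dimension; the integer division assumes Shape0::kN is
    // a whole multiple of Shape1::kK.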

    // Avoid reading out of bounds
    iterator_B1.clear_mask(gemm_k_iterations_1 <= 1);

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_PRAGMA_UNROLL
    for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {

          // Write fragments to shared memory
          this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));

          __syncthreads();

          ++this->smem_iterator_B1_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_B1_.add_tile_offset(
                {-Base::kStages * Policy1::kPartitionsK *
                     Base::kWarpGemmIterations1,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);

        // skip warp tile loading for the last kgroup
        if (gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1)
          warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);

        ++warp_tile_iterator_A1_;
        ++this->warp_tile_iterator_B1_;

        if (warp_mma_k == 0) {

          iterator_B1.load(tb_frag_B1);

          ++iterator_B1;

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_B1.clear_mask(gemm_k_iterations_1 <= 2);
        }

        warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
                  warp_frag_B1[warp_mma_k % 2], accum);
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
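
// ---------------------------------------------------------------------------
// A distilled host-side sketch (not CUTLASS code) of the two-slot double
// buffering the 2-stage pipelined mainloops above rely on: the fragment for
// step k+1 is fetched into the idle slot ((k + 1) % 2) while step k consumes
// slot (k % 2).
// ---------------------------------------------------------------------------

#include <cstdio>

int main() {
  int frag[2] = {0, 0};   // two fragment slots; slot 0 preloaded in a "prologue"
  int next_value = 1;     // stand-in for the next tile fetched from memory

  for (int k = 0; k < 4; ++k) {
    frag[(k + 1) % 2] = next_value++;   // prefetch into the idle slot
    std::printf("step %d consumes slot %d (value %d)\n", k, k % 2, frag[k % 2]);
  }
  return 0;
}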
@@ -1,24 +1,30 @@
 (BSD-3-Clause license header updated: copyright 2017-2021 NVIDIA CORPORATION
  becomes 2017 - 2022 NVIDIA CORPORATION & AFFILIATES, with an SPDX identifier
  added; the terms are otherwise the standard BSD 3-clause text)
@@ -34,6 +40,10 @@
 
 #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
 #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
+#include "cutlass/transform/threadblock/vector_iterator.h"
+#include "cutlass/transform/warp/vector_fragment_iterator.h"
 
 #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
 #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
 #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
@@ -89,7 +99,9 @@ template <
     typename EpilogueOutputOp,
     /// Store the accumulators in row major or column major.  Row major is used
     /// when output layout is interleaved.
-    bool AccumulatorsInRowMajor = false>
+    bool AccumulatorsInRowMajor = false,
+    /// Staging the accumulators in shared memory.
+    bool SmemAccumulator = false>
 struct DefaultB2bMma;
 
 ////////////////////////////////////////////////////////////////////////////////
@ -162,6 +174,22 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;
|
||||
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
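The scale and bias vectors drive a per-column affine transform applied to the first GEMM's accumulators before they become the A operand of the second GEMM. A plain reference sketch of that computation, as far as it can be read from the iterator types (the activation function and float element types are illustrative assumptions; the real kernel never materializes a1 in global memory):

#include <functional>
#include <vector>

// Reference semantics of the fused B2B GEMM with per-column scale/bias.
void b2b_gemm_reference(int M, int N, int K, int N1,
                        std::vector<float> const &A,      // M x K
                        std::vector<float> const &B0,     // K x N
                        std::vector<float> const &scale,  // N, per-column scale
                        std::vector<float> const &bias,   // N, per-column bias
                        std::vector<float> const &B1,     // N x N1
                        std::function<float(float)> act,  // assumed activation
                        std::vector<float> &D1) {         // M x N1
  std::vector<float> a1(M * N);
  // First GEMM with fused epilogue: a1 = act(scale * (A @ B0) + bias)
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc0 = 0.f;
      for (int k = 0; k < K; ++k) acc0 += A[m * K + k] * B0[k * N + n];
      a1[m * N + n] = act(scale[n] * acc0 + bias[n]);
    }
  // Second GEMM consumes a1 as its A operand: D1 = a1 @ B1
  for (int m = 0; m < M; ++m)
    for (int n1 = 0; n1 < N1; ++n1) {
      float acc1 = 0.f;
      for (int n = 0; n < N; ++n) acc1 += a1[m * N + n] * B1[n * N1 + n1];
      D1[m * N1 + n1] = acc1;
    }
}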

  // Define iterators over tiles from the B operand
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileIterator<
@@ -173,6 +201,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      IteratorB0, typename MmaCore0::SmemIteratorB,
      typename MmaCore1::Shape, FragmentIteratorA1,
      IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
      IteratorB1, typename MmaCore1::SmemIteratorB,
      ElementAccumulator, layout::RowMajor,
      EpilogueOutputOp,
@@ -268,6 +297,24 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
          MmaCore1::Shape::kK, //kBlocksColumn
          ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
      MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
      LayoutScaleBias, InstructionShape, kElementsPerAccess>;


  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;
@@ -282,6 +329,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      MmaCore0::kCacheOpA,
      IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
      typename MmaCore1::Shape, FragmentIteratorA1,
      IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
      IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
      ElementAccumulator, layout::RowMajor,
      EpilogueOutputOp,
@@ -369,6 +417,22 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      ElementAccumulator, ElementA, AccumulatorLayout,
      InstructionShape, EpilogueOutputOp>;

  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
      MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
      LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  // Define iterators over tiles from the B operand
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileIterator<
@@ -376,12 +440,12 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
          ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;



  // Define the threadblock-scoped pipelined matrix multiply
  using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      IteratorB0, typename MmaCore0::SmemIteratorB,
      typename MmaCore1::Shape, FragmentIteratorA1,
      IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
      IteratorB1, typename MmaCore1::SmemIteratorB,
      ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
      EpilogueOutputOp,
@@ -471,6 +535,23 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      ElementAccumulator, ElementA, AccumulatorLayout,
      InstructionShape, EpilogueOutputOp>;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;

  // Warp-level iterators to load scale and bias vectors
  using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
      MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
      LayoutScaleBias, InstructionShape, kElementsPerAccess>;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
@@ -486,6 +567,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      MmaCore0::kCacheOpA,
      IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
      typename MmaCore1::Shape, FragmentIteratorA1,
      IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
      IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
      ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
      EpilogueOutputOp,

@@ -0,0 +1,605 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"

#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
#include "threadblock/b2b_mma_multistage_smem_accumulator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output with 2-stage pipeline
/// Accumulator will be staged in shared memory.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Epilogue output operator
    typename EpilogueOutputOp>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
                     kAlignmentB, ElementAccumulator, layout::RowMajor,
                     arch::OpClassTensorOp, ArchTag,
                     ThreadblockShape0, ThreadblockShape1,
                     WarpShape0, WarpShape1,
                     InstructionShape, 2, Operator, EpilogueOutputOp, false, true> {
  // Define the MmaCore components
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, 2, Operator>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, 2, Operator>;

  // Define iterators over tiles from the A operand
  using IteratorA0 =
      cutlass::transform::threadblock::PredicatedTileIterator<
          cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
          ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;

  // Define iterators over tiles from the B operand
  using IteratorB0 =
      cutlass::transform::threadblock::PredicatedTileIterator<
          cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
          ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;

  // Define iterators over tiles from the B operand
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileIterator<
          cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
          ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
          WarpShape0,
          InstructionShape,
          typename EpilogueOutputOp::ElementOutput,
          SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
      ElementA, SmemAccumulatorLayout,
      MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the threadblock-scoped pipelined matrix multiply
  using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      IteratorB0, typename MmaCore0::SmemIteratorB,
      IteratorAccumulatorScaleBias,
      FragmentIteratorAccumulator, SmemIteratorD0,
      typename MmaCore1::Shape, WarpIteratorA1,
      IteratorB1, typename MmaCore1::SmemIteratorB,
      ElementAccumulator, layout::RowMajor,
      EpilogueOutputOp,
      typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;

};
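Worth noting in this specialization: the first GEMM's accumulator tile round-trips through shared memory (SmemIteratorD0 stores it, WarpIteratorA1 reloads it as the A operand of the second GEMM), whereas the non-SmemAccumulator specializations earlier in this diff keep it in registers via FragmentIteratorA1. Staging in shared memory trades extra shared-memory traffic for freedom from the register-fragment layout coupling between the two warp-level MMAs; that reading is inferred from the iterator wiring shown here.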

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output for multi-stage
/// Accumulator will be staged in shared memory.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Number of stages used in the multistage mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    /// Epilogue output operator
    typename EpilogueOutputOp>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
                     kAlignmentB, ElementAccumulator, layout::RowMajor,
                     arch::OpClassTensorOp, ArchTag,
                     ThreadblockShape0, ThreadblockShape1,
                     WarpShape0, WarpShape1,
                     InstructionShape, Stages, Operator, EpilogueOutputOp, false, true> {

  static cutlass::arch::CacheOperation::Kind const CacheOpA =
      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;

  static cutlass::arch::CacheOperation::Kind const CacheOpB =
      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;
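A concrete reading of this predicate (illustrative, assuming ElementA = cutlass::half_t with kAlignmentA = 8, neither of which is fixed by the diff): 16 bits x 8 elements is exactly one 128-bit access, so CacheOpA resolves to the streaming CacheOperation::Global policy, while any narrower access falls back to CacheOperation::Always.

#include "cutlass/numeric_types.h"

// Sketch only: assumed element type and alignment, mirroring the predicate above.
static int const kAssumedAlignmentA = 8;
static_assert(cutlass::sizeof_bits<cutlass::half_t>::value  // 16 bits per half_t
                  * kAssumedAlignmentA == 128,
              "half_t x 8 is one 128-bit access, selecting CacheOperation::Global");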


  // Define the MmaCore components
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, Operator, false, CacheOpA, CacheOpB>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, Operator, false, CacheOpA, CacheOpB>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using AccessTypeA0 = cutlass::Array<ElementA, kAlignmentA>;
  using IteratorA0 =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
          ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using AccessTypeB0 = cutlass::Array<ElementB, kAlignmentB>;
  using IteratorB0 =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
          ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
          ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 2;
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;


  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
          WarpShape0,
          InstructionShape,
          typename EpilogueOutputOp::ElementOutput,
          SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
      ElementA, SmemAccumulatorLayout,
      MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the threadblock-scoped pipelined matrix multiply
  using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      MmaCore0::kCacheOpA,
      IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
      IteratorAccumulatorScaleBias,
      FragmentIteratorAccumulator, SmemIteratorD0,
      typename MmaCore1::Shape, WarpIteratorA1,
      IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
      ElementAccumulator, layout::RowMajor,
      EpilogueOutputOp,
      typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;

};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for column-major-interleaved output with 2-stage pipeline
/// Accumulator will be staged in shared memory.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename OperatorClass,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Number of Interleaved K
    int InterleavedK>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
                     kAlignmentB, ElementAccumulator,
                     layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, arch::Sm75,
                     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
                     InstructionShape, 2, Operator, EpilogueOutputOp, true, true> {
  // Define the MmaCore components
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator,
      layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
      true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator,
      layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
      true>;

  static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
                "Alignment must match thread data map's vector length");

  static_assert(kAlignmentB == 128 / sizeof_bits<ElementB>::value,
                "Alignment must match thread data map's vector length");

  // Define iterators over tiles from the A operand
  using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
      cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
      LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;

  // Define iterators over tiles from the B operand
  using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
      cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
      LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;

  // Define iterators over tiles from the B operand
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileIterator<
          cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
          ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4; //For interleaved layout
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
          WarpShape0,
          InstructionShape,
          typename EpilogueOutputOp::ElementOutput,
          SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
      ElementA, SmemAccumulatorLayout,
      MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;

  // Define the threadblock-scoped pipelined matrix multiply
  using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      IteratorB0, typename MmaCore0::SmemIteratorB,
      IteratorAccumulatorScaleBias,
      FragmentIteratorAccumulator, SmemIteratorD0,
      typename MmaCore1::Shape, WarpIteratorA1,
      IteratorB1, typename MmaCore1::SmemIteratorB,
      ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
      EpilogueOutputOp,
      typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;

};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for column-major-interleaved output with multi-stage
/// Accumulator will be staged in shared memory.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Number of stages used in the multistage mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Number of Interleaved K
    int InterleavedK>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
                     kAlignmentB, ElementAccumulator,
                     layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
                     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
                     InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> {
  // Define the MmaCore components
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator,
      layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
      Operator, true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator,
      layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
      Operator, true>;

  // Define iterators over tiles from the A operand
  using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
  using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
  using IteratorA0 =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
          ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>;

  // Define iterators over tiles from the B operand
  using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
  using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
  using IteratorB0 =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
          ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>;

  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
          ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>;

  // Warp-level GEMM components
  using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
  using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
  using MmaPolicy0 = typename MmaCore0::MmaPolicy;
  using MmaPolicy1 = typename MmaCore1::MmaPolicy;

  // Use fragment iterator for the accumulator
  using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
  using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
      WarpShape0, InstructionShape,
      ElementAccumulator,
      typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
      SmemAccumulatorLayout
  >;

  /// Define iterators over tiles from scale/bias vectors
  using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
  using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
  static int const kElementsPerAccess = 4;
  using IteratorAccumulatorScaleBias =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
              cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
              ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
      >;

  // Store Accumulator tiles to Shared Memory
  using SmemIteratorD0 =
      cutlass::epilogue::warp::TileIteratorTensorOp<
          WarpShape0,
          InstructionShape,
          typename EpilogueOutputOp::ElementOutput,
          SmemAccumulatorLayout
      >;

  static int const kThreadCount = 32;
  // load warp tile from Shared Memory accumulator
  using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
      MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
      ElementA, SmemAccumulatorLayout,
      MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;


  // Define the threadblock-scoped multistage matrix multiply
  using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      MmaCore0::kCacheOpA,
      IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
      IteratorAccumulatorScaleBias,
      FragmentIteratorAccumulator, SmemIteratorD0,
      typename MmaCore1::Shape, WarpIteratorA1,
      IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
      ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
      EpilogueOutputOp,
      typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
};

////////////////////////////////////////////////////////////////////////////////


} // namespace threadblock
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////

@@ -1,25 +1,33 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


cutlass_example_add_executable(
  14_ampere_tf32_tensorop_gemm
  ampere_tf32_tensorop_gemm.cu

@@ -1,24 +1,30 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
@@ -132,12 +138,12 @@ struct Options {
      << "  This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --m <int>                   GEMM M dimension\n"
      << "  --n <int>                   GEMM N dimension\n"
      << "  --k <int>                   GEMM K dimension\n"
      << "  --alpha <f32>               Epilogue scalar alpha\n"
      << "  --beta <f32>                Epilogue scalar beta\n\n"
      << "  --iterations <int>          Number of profiling iterations to perform.\n\n";
      << "  --m=<int>                   GEMM M dimension\n"
      << "  --n=<int>                   GEMM N dimension\n"
      << "  --k=<int>                   GEMM K dimension\n"
      << "  --alpha=<f32>               Epilogue scalar alpha\n"
      << "  --beta=<f32>                Epilogue scalar beta\n\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n"

@@ -1,25 +1,33 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


cutlass_example_add_executable(
  15_ampere_sparse_tensorop_gemm
  ampere_sparse_tensorop_gemm.cu