CUTLASS 2.9 (#468)
This commit is contained in:
23
.github/ISSUE_TEMPLATE/bug_report.md
vendored
23
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@ -1,23 +0,0 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a bug report to help us improve CUTLASS
|
||||
title: "[BUG]"
|
||||
labels: "? - Needs Triage, bug"
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Steps/Code to reproduce bug**
|
||||
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Environment details (please complete the following information):**
|
||||
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
35
.github/ISSUE_TEMPLATE/documentation_request.md
vendored
35
.github/ISSUE_TEMPLATE/documentation_request.md
vendored
@ -1,35 +0,0 @@
|
||||
---
|
||||
name: Documentation request
|
||||
about: Report incorrect or needed documentation to improve CUTLASS
|
||||
title: "[DOC]"
|
||||
labels: "? - Needs Triage, documentation"
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Report incorrect documentation
|
||||
|
||||
**Location of incorrect documentation**
|
||||
Provide links and line numbers if applicable.
|
||||
|
||||
**Describe the problems or issues found in the documentation**
|
||||
A clear and concise description of what you found to be incorrect.
|
||||
|
||||
**Steps taken to verify documentation is incorrect**
|
||||
List any steps you have taken:
|
||||
|
||||
**Suggested fix for documentation**
|
||||
Detail proposed changes to fix the documentation if you have any.
|
||||
|
||||
---
|
||||
|
||||
## Report needed documentation
|
||||
|
||||
**Report needed documentation**
|
||||
A clear and concise description of what documentation you believe it is needed and why.
|
||||
|
||||
**Describe the documentation you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Steps taken to search for needed documentation**
|
||||
List any steps you have taken:
|
||||
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@ -1,20 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for CUTLASS
|
||||
title: "[FEA]"
|
||||
labels: "? - Needs Triage, feature request"
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context, code examples, or references to existing implementations about the feature request here.
|
||||
10
.github/ISSUE_TEMPLATE/submit_question.md
vendored
10
.github/ISSUE_TEMPLATE/submit_question.md
vendored
@ -1,10 +0,0 @@
|
||||
---
|
||||
name: Submit question
|
||||
about: Ask a general question about CUTLASS
|
||||
title: "[QST]"
|
||||
labels: "? - Needs Triage, question"
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**What is your question?**
|
||||
11
.github/workflows/labeler.yml
vendored
11
.github/workflows/labeler.yml
vendored
@ -1,11 +0,0 @@
|
||||
name: "Pull Request Labeler"
|
||||
on:
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@main
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
@ -1,35 +0,0 @@
|
||||
name: Auto Assign New Issues to Triage Project
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
jobs:
|
||||
assign_one_project:
|
||||
runs-on: ubuntu-latest
|
||||
name: Assign to New Issues to Triage Project
|
||||
steps:
|
||||
- name: Process bug issues
|
||||
uses: docker://takanabe/github-actions-automate-projects:v0.0.1
|
||||
if: contains(github.event.issue.labels.*.name, 'bug') && contains(github.event.issue.labels.*.name, '? - Needs Triage')
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
|
||||
GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
|
||||
- name: Process feature issues
|
||||
uses: docker://takanabe/github-actions-automate-projects:v0.0.1
|
||||
if: contains(github.event.issue.labels.*.name, 'feature request') && contains(github.event.issue.labels.*.name, '? - Needs Triage')
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
|
||||
GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
|
||||
- name: Process other issues
|
||||
uses: docker://takanabe/github-actions-automate-projects:v0.0.1
|
||||
if: contains(github.event.issue.labels.*.name, '? - Needs Triage') && (!contains(github.event.issue.labels.*.name, 'bug') && !contains(github.event.issue.labels.*.name, 'feature request'))
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
|
||||
GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
|
||||
57
.github/workflows/stale.yml
vendored
57
.github/workflows/stale.yml
vendored
@ -1,57 +0,0 @@
|
||||
name: Mark inactive issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 * * * *"
|
||||
|
||||
jobs:
|
||||
mark-inactive-30d:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Mark 30 day inactive issues and pull requests
|
||||
uses: actions/stale@v3
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: >
|
||||
This issue has been labeled `inactive-30d` due to no recent activity in the past 30 days.
|
||||
Please close this issue if no further response or action is needed.
|
||||
Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
|
||||
This issue will be labeled `inactive-90d` if there is no activity in the next 60 days.
|
||||
stale-issue-label: "inactive-30d"
|
||||
exempt-issue-labels: "0 - Blocked,0 - Backlog,good first issue"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: -1
|
||||
stale-pr-message: >
|
||||
This PR has been labeled `inactive-30d` due to no recent activity in the past 30 days.
|
||||
Please close this PR if it is no longer required.
|
||||
Otherwise, please respond with a comment indicating any updates.
|
||||
This PR will be labeled `inactive-90d` if there is no activity in the next 60 days.
|
||||
stale-pr-label: "inactive-30d"
|
||||
exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
|
||||
days-before-pr-stale: 30
|
||||
days-before-pr-close: -1
|
||||
operations-per-run: 50
|
||||
mark-inactive-90d:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Mark 90 day inactive issues and pull requests
|
||||
uses: actions/stale@v3
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: >
|
||||
This issue has been labeled `inactive-90d` due to no recent activity in the past 90 days.
|
||||
Please close this issue if no further response or action is needed.
|
||||
Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
|
||||
stale-issue-label: "inactive-90d"
|
||||
exempt-issue-labels: "0 - Blocked,0 - Backlog,good first issue"
|
||||
days-before-issue-stale: 90
|
||||
days-before-issue-close: -1
|
||||
stale-pr-message: >
|
||||
This PR has been labeled `inactive-90d` due to no recent activity in the past 90 days.
|
||||
Please close this PR if it is no longer required.
|
||||
Otherwise, please respond with a comment indicating any updates.
|
||||
stale-pr-label: "inactive-90d"
|
||||
exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
|
||||
days-before-pr-stale: 90
|
||||
days-before-pr-close: -1
|
||||
operations-per-run: 50
|
||||
64
CHANGELOG.md
64
CHANGELOG.md
@ -1,5 +1,26 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
|
||||
|
||||
* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
* Supported types: f32, cf32, f64, cf64
|
||||
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
|
||||
* [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
|
||||
* [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/symm_operation.py)
|
||||
* [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/trmm_operation.py)
|
||||
* [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](/example/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](/examples/35_gemm_softmax)
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
|
||||
## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)
|
||||
|
||||
* **TF32x3:** emulated single-precision using Tensor Cores
|
||||
@ -23,7 +44,6 @@
|
||||
* Ubuntu 16.04
|
||||
* CUDA 10.2
|
||||
|
||||
|
||||
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
@ -210,27 +230,33 @@
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -1,23 +1,29 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
|
||||
@ -32,7 +38,7 @@ endif()
|
||||
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
|
||||
project(CUTLASS VERSION 2.8.0 LANGUAGES CXX)
|
||||
project(CUTLASS VERSION 2.9.0 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 10.2)
|
||||
@ -83,7 +89,7 @@ set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUT
|
||||
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
|
||||
|
||||
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}})
|
||||
else()
|
||||
set(CUTLASS_ENABLE_TESTS_INIT OFF)
|
||||
endif()
|
||||
@ -187,11 +193,9 @@ set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of opera
|
||||
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
|
||||
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
|
||||
|
||||
|
||||
# Test Levels L0, L1, L2
|
||||
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
|
||||
|
||||
set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
|
||||
|
||||
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
|
||||
@ -203,7 +207,6 @@ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
|
||||
endif()
|
||||
|
||||
|
||||
set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests")
|
||||
|
||||
if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
|
||||
@ -211,7 +214,6 @@ if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
|
||||
#
|
||||
@ -748,7 +750,7 @@ if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
install(
|
||||
FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
|
||||
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
|
||||
DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
45
CUDA.cmake
45
CUDA.cmake
@ -1,23 +1,29 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
@ -213,8 +219,7 @@ function(cutlass_correct_source_file_language_property)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# If building with all kernels, set UNITY build on by default.
|
||||
if (CUTLASS_LIBRARY_KERNELS MATCHES "all")
|
||||
if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all")
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
|
||||
else()
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF)
|
||||
|
||||
42
LICENSE.txt
42
LICENSE.txt
@ -1,23 +1,27 @@
|
||||
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
91
README.md
91
README.md
@ -1,8 +1,8 @@
|
||||

|
||||
|
||||
# CUTLASS 2.8
|
||||
# CUTLASS 2.9
|
||||
|
||||
_CUTLASS 2.8 - November 2021_
|
||||
_CUTLASS 2.9 - April 2022_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) and related computations at all levels
|
||||
@ -34,18 +34,23 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
See the [functionality listing](/media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
# What's New in CUTLASS 2.8
|
||||
CUTLASS 2.8 is an update to CUTLASS adding:
|
||||
- [TF32x3:](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm) emulated single-precision using Tensor Cores; 45+ TFLOPs on NVIDIA A100
|
||||
- [Mainloop fusion for Convolution:](/examples/25_ampere_fprop_mainloop_fusion) convolution with fused per-channel bias-add
|
||||
- [Grouped GEMM:](/examples/24_gemm_grouped) similar to batched GEMM with distinct problem size per group
|
||||
- [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing.
|
||||
- Optimal performance using [CUDA 11.5](https://developer.nvidia.com/cuda-downloads)
|
||||
# What's New in CUTLASS 2.9
|
||||
|
||||
CUTLASS 2.9 is an update to CUTLASS adding:
|
||||
- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
- [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
|
||||
- [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
|
||||
- [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and
|
||||
- [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu)
|
||||
- [CUTLASS Python](/example/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
- [GEMM + Softmax example](/examples/35_gemm_softmax)
|
||||
- Optimal performance using [CUDA 11.6u2](https://developer.nvidia.com/cuda-downloads)
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
- **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
- Maxwell and Pascal GPU architectures
|
||||
- Ubuntu 16.04
|
||||
- CUDA 10.2
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
|
||||
|
||||
@ -66,8 +71,8 @@ compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-download
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires a C++11 host compiler and
|
||||
performs best when built with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, and CUDA 11.4.
|
||||
performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5.
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
@ -75,8 +80,10 @@ We have tested the following environments.
|
||||
|-----------------|----------|
|
||||
| Windows 10 | Microsoft Visual Studio 2015|
|
||||
| | Microsoft Visual Studio 2017|
|
||||
| | Microsoft Visual Studio 2019|
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
| Ubuntu 21.04 | GCC 11.2.0 |
|
||||
|
||||
Additionally, CUTLASS may be built with clang.
|
||||
See [these instructions](media/docs/quickstart.md#clang) for more details.
|
||||
@ -84,10 +91,7 @@ See [these instructions](media/docs/quickstart.md#clang) for more details.
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
|
||||
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|
||||
|
||||
For all GPUs, we recommend compiling with the [**CUDA 11.5 Toolkit**](https://developer.nvidia.com/cuda-toolkit)
|
||||
for best performance.
|
||||
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**|
|
||||
|---|---|---|---|
|
||||
|NVIDIA Tesla V100|7.0|9.2|10.1|
|
||||
|NVIDIA TitanV|7.0|9.2|10.1|
|
||||
@ -97,6 +101,9 @@ for best performance.
|
||||
|NVIDIA A10 |8.6|11.1|11.1|
|
||||
|NVIDIA GeForce 3090|8.6|11.1|11.1|
|
||||
|
||||
For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
for best performance.
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
@ -230,6 +237,16 @@ examples/
|
||||
13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel
|
||||
|
||||
22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
|
||||
|
||||
31_basic_syrk # example demonstrating Symetric rank-K update
|
||||
|
||||
32_basic_trmm #
|
||||
|
||||
33_ampere_3xtf32_tensorop_symm #
|
||||
|
||||
35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
|
||||
|
||||
40_cutlass_py # example demonstrating CUTLASS with CUDA Python
|
||||
```
|
||||
|
||||
### Tools
|
||||
@ -485,27 +502,33 @@ The official list of CUTLASS developers and contributors is available here: [CON
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
42
cmake/nop.cu
42
cmake/nop.cu
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
42
cuBLAS.cmake
42
cuBLAS.cmake
@ -1,23 +1,29 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
message(STATUS "Configuring cublas ...")
|
||||
|
||||
43
cuDNN.cmake
43
cuDNN.cmake
@ -1,24 +1,29 @@
|
||||
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(DEFINED CUDNN_ENABLED)
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
00_basic_gemm
|
||||
basic_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
01_cutlass_utilities
|
||||
cutlass_utilities.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
02_dump_reg_shmem
|
||||
dump_reg_shmem.cu
|
||||
|
||||
@ -1,27 +1,31 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
*notice, this list of conditions and the following disclaimer in the
|
||||
*documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its
|
||||
*contributors may be used to endorse or promote products derived from this
|
||||
*software without specific prior written permission.
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
|
||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
|
||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(TEST_COMMAND_00 RowMajor --extent=16,16)
|
||||
set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)
|
||||
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -55,27 +61,49 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
|
||||
new VisualizeLayout<cutlass::layout::ColumnMajorInterleaved<4>>},
|
||||
{"RowMajorInterleaved<4>",
|
||||
new VisualizeLayout<cutlass::layout::RowMajorInterleaved<4>>},
|
||||
// All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling
|
||||
// layout implementation with different templates.
|
||||
//
|
||||
// BMMA 88128 Interleaved-256
|
||||
// BMMA 168256 Interleaved-256
|
||||
{"TensorOpMultiplicand<1,256>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 256>>},
|
||||
// BMMA 88128 TN kblock512
|
||||
// BMMA 168256 TN kblock512
|
||||
{"TensorOpMultiplicand<1,512>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 512>>},
|
||||
// BMMA 168256 TN kblock1024
|
||||
{"TensorOpMultiplicand<1,1024>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 1024>>},
|
||||
// Integer matrix multiply.int4 8832 Interleaved-64
|
||||
// Integer matrix multiply.int4 16864 Interleaved-64
|
||||
{"TensorOpMultiplicand<4,64>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 64>>},
|
||||
// Integer matrix multiply.int4 8832 TN kblock128
|
||||
// Integer matrix multiply.int4 16864 TN kblock128
|
||||
{"TensorOpMultiplicand<4,128>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 128>>},
|
||||
// Integer matrix multiply.int4 16864 TN kblock256
|
||||
{"TensorOpMultiplicand<4,256>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 256>>},
|
||||
// Integer matrix multiply 8816 Interleaved-32
|
||||
// Integer matrix multiply 16832 Interleaved-32
|
||||
{"TensorOpMultiplicand<8,32>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 32>>},
|
||||
// Integer matrix multiply 8816 TN kblock64
|
||||
// Integer matrix multiply 16832 TN kblock64
|
||||
{"TensorOpMultiplicand<8,64>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 64>>},
|
||||
// Integer matrix multiply 16832 TN kblock128
|
||||
{"TensorOpMultiplicand<8,128>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 128>>},
|
||||
// Matrix Multiply 1688 TN kblock32
|
||||
// Matrix multiply 16816 TN kblock32
|
||||
{"TensorOpMultiplicand<16,32>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 32>>},
|
||||
// Matrix multiply 1688 NT
|
||||
// Matrix multiply 16816 NT
|
||||
// Matrix multiply 16816 TN kblock64
|
||||
{"TensorOpMultiplicand<16,64>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 64>>},
|
||||
// Matrix multiply 1688.TF32 TN kblock16
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
04_tile_iterator
|
||||
tile_iterator.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
05_batched_gemm
|
||||
batched_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -34,7 +40,7 @@
|
||||
#pragma warning( disable : 4503)
|
||||
|
||||
/*
|
||||
This example demonstrates how to use cutlass to compute a batched gemm in two different ways:
|
||||
This example demonstrates how to use cutlass to compute a batched strided gemm in two different ways:
|
||||
1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
|
||||
matrices of the batch (this is called a strided batched gemm).
|
||||
2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
|
||||
@ -231,6 +237,7 @@ cudaError_t strided_batched_gemm_nn_reference(
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
cudaError_t run_batched_gemm(bool use_array) {
|
||||
|
||||
const char* gemm_desc = use_array ? "array" : "strided batched";
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
06_splitK_gemm
|
||||
splitk_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
07_volta_tensorop_gemm
|
||||
volta_tensorop_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
08_turing_tensorop_gemm
|
||||
turing_tensorop_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
09_turing_tensorop_conv2dfprop
|
||||
turing_tensorop_conv2dfprop.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -333,21 +339,21 @@ struct Options {
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --n=<int> Input tensor extent N\n"
|
||||
<< " --h=<int> Input tensor extent H\n"
|
||||
<< " --w=<int> Input tensor extent W\n"
|
||||
<< " --c=<int> Input tensor extent C\n"
|
||||
<< " --k=<int> Filter extent K\n"
|
||||
<< " --r=<int> Filter extent R\n"
|
||||
<< " --s=<int> Filter extent S\n\n"
|
||||
<< " --alpha=<float> Epilogue scalar alpha\n"
|
||||
<< " --beta=<float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
<< " --tag=<string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
|
||||
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
# Planar Complex GEMM example
|
||||
cutlass_example_add_executable(
|
||||
10_planar_complex
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -167,15 +173,15 @@ struct Options {
|
||||
<< " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --batch <int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n\n";
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --batch=<int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i=<f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i=<f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n"
|
||||
|
||||
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
# Planar Complex Array GEMM example
|
||||
cutlass_example_add_executable(
|
||||
11_planar_complex_array
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -165,15 +171,15 @@ struct Options {
|
||||
<< " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --batch <int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n";
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --batch=<int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i=<f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i=<f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n";
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
12_gemm_bias_relu
|
||||
gemm_bias_relu.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -294,4 +300,3 @@ int main() {
|
||||
|
||||
return run();
|
||||
}
|
||||
|
||||
|
||||
@ -1,45 +1,82 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
13_fused_two_gemms
|
||||
fused_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
13_fused_two_convs
|
||||
fused_conv2d.cu
|
||||
)
|
||||
|
||||
|
||||
target_include_directories(
|
||||
13_fused_two_gemms
|
||||
PRIVATE
|
||||
include_directories(
|
||||
.
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
add_custom_target(13_fused_two_gemms)
|
||||
|
||||
add_custom_target(13_fused_two_convs)
|
||||
|
||||
add_custom_target(13_two_tensor_op_fusion
|
||||
DEPENDS 13_fused_two_gemms
|
||||
13_fused_two_convs
|
||||
PRIVATE
|
||||
.
|
||||
)
|
||||
|
||||
foreach(FUSION_CONV_EXAMPLE
|
||||
fused_two_convs_f16_sm75_rf
|
||||
fused_two_convs_f16_sm75_shmem
|
||||
fused_two_convs_f16_sm80_rf
|
||||
fused_two_convs_f16_sm80_shmem
|
||||
fused_two_convs_s8_sm75_rf
|
||||
fused_two_convs_s8_sm75_shmem
|
||||
fused_two_convs_s8_sm80_rf
|
||||
fused_two_convs_s8_sm80_shmem
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
13_${FUSION_CONV_EXAMPLE}
|
||||
${FUSION_CONV_EXAMPLE}.cu
|
||||
)
|
||||
|
||||
add_dependencies(13_fused_two_convs 13_${FUSION_CONV_EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
foreach(FUSION_GEMM_EXAMPLE
|
||||
fused_two_gemms_f16_sm75_rf
|
||||
fused_two_gemms_f16_sm75_shmem
|
||||
fused_two_gemms_f16_sm80_rf
|
||||
fused_two_gemms_f16_sm80_shmem
|
||||
fused_two_gemms_s8_sm75_rf
|
||||
fused_two_gemms_s8_sm75_shmem
|
||||
fused_two_gemms_s8_sm80_rf
|
||||
fused_two_gemms_s8_sm80_shmem
|
||||
)
|
||||
cutlass_example_add_executable(
|
||||
13_${FUSION_GEMM_EXAMPLE}
|
||||
${FUSION_GEMM_EXAMPLE}.cu
|
||||
)
|
||||
|
||||
add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -48,33 +48,48 @@ addition to its own input activation tile. Therefore the input activation warp t
|
||||
2nd GEMM/Conv only depends on the output warp accumulator of the 1st GEMM/Conv in the
|
||||
register file, and the operation can be fully register-file-resident.
|
||||
|
||||
On the other hand, this constraint can be relaxed if the output accumulator of the 1st GEMM/CONV
|
||||
is staged in the shared memory and then used as input for the 2nd GEMM/CONV. In this case, the
|
||||
input of each warp tile can be loaded from the shared memory so they do not need to be RF-resident,
|
||||
therefore each warp does not need to store the entire input matrix of 2nd GEMM in its RF. This is
|
||||
illustrated in the diagram below.
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_shmem_resident_fusion.png></p>
|
||||
|
||||
|
||||
When applying the above constraint to convolutions, it is required that the 2nd Convolution
|
||||
kernel doesn't have halos such that data used by each threadblock doesn't depend on any other
|
||||
threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without any paddings.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -1,30 +1,34 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implicit GEMM testbed
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
@ -275,7 +279,7 @@ public:
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
|
||||
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
|
||||
std::cout << "total time " << totalTime / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0_computed.sync_host();
|
||||
tensor_D1_computed.sync_host();
|
||||
@ -592,7 +596,7 @@ public:
|
||||
cudaDeviceSynchronize();
|
||||
float conv2dTime;
|
||||
cudaEventElapsedTime(&conv2dTime, start, stop);
|
||||
std::cout << "time " << conv2dTime / (float)runs << " ms\n";
|
||||
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -264,7 +270,7 @@ struct B2bNonFusedGemmRun
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "total time " << totalTime / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
@ -507,7 +513,19 @@ struct B2bFusedGemmRun
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.initialize(arguments);
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.M = problem_size_1.M\n"
|
||||
<< " problem_size_0.N = problem_size_1.K\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
@ -536,7 +554,7 @@ struct B2bFusedGemmRun
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / (float)runs << " ms\n";
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1.sync_host();
|
||||
|
||||
|
||||
@ -1,30 +1,34 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implicit GEMM testbed
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
@ -286,7 +290,7 @@ public:
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
|
||||
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
|
||||
std::cout << "total time " << totalTime / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0_computed.sync_host();
|
||||
tensor_D1_computed.sync_host();
|
||||
@ -617,7 +621,7 @@ public:
|
||||
cudaDeviceSynchronize();
|
||||
float conv2dTime;
|
||||
cudaEventElapsedTime(&conv2dTime, start, stop);
|
||||
std::cout << "time " << conv2dTime / (float)runs << " ms\n";
|
||||
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -275,11 +281,12 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "total time " << totalTime / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
@ -334,7 +341,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
@ -360,7 +367,6 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
@ -531,7 +537,18 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.initialize(arguments);
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.M = problem_size_1.M\n"
|
||||
<< " problem_size_0.N = problem_size_1.K\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
||||
}
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
@ -560,10 +577,11 @@ struct B2bInterleavedFusedGemmRun
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / (float)runs << " ms\n";
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
@ -611,7 +629,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
@ -636,7 +654,6 @@ struct B2bInterleavedFusedGemmRun
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -40,6 +46,7 @@
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm_smem_accumulator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -102,6 +109,8 @@ template <
|
||||
int Stages =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kStages,
|
||||
/// Stage accumulator in shared memory
|
||||
bool SmemAccumulator = false,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
@ -172,7 +181,8 @@ class B2bGemm {
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator
|
||||
Operator,
|
||||
SmemAccumulator
|
||||
>::B2bGemmKernel;
|
||||
|
||||
/// Argument structure
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -36,6 +42,10 @@
|
||||
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_sm75.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_sm80.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
|
||||
@ -1,138 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.h"
|
||||
#include "b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.h"
|
||||
#include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h"
|
||||
#include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h"
|
||||
|
||||
int run_sm75() {
|
||||
bool notSupported = false;
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == 7 && props.minor >= 5)) {
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = 1;
|
||||
|
||||
std::cout << "Running on SM75" << std::endl;
|
||||
pass &= run_nonfused_conv2d_fprop_optimized_f16_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_f16_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_f16_sm75_rf_res();
|
||||
pass &= run_nonfused_conv2d_fprop_optimized_s8_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_s8_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_s8_sm75_rf_res();
|
||||
|
||||
if(pass)
|
||||
return 1;
|
||||
else
|
||||
return -1;
|
||||
|
||||
}
|
||||
|
||||
int run_sm80() {
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == 8 && props.minor >= 0)) {
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = 1;
|
||||
|
||||
std::cout << "Running on SM80" << std::endl;
|
||||
pass &= run_nonfused_conv2d_fprop_optimized_f16_sm80();
|
||||
pass &= run_fused_conv2d_fprop_optimized_f16_sm80();
|
||||
pass &= run_nonfused_conv2d_fprop_optimized_s8_sm80();
|
||||
pass &= run_fused_conv2d_fprop_optimized_s8_sm80();
|
||||
|
||||
if(pass)
|
||||
return 1;
|
||||
else
|
||||
return -1;
|
||||
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
int result = 0;
|
||||
|
||||
result = run_sm80();
|
||||
|
||||
if(!result) { // not supported
|
||||
result = run_sm75();
|
||||
|
||||
if(!result) {
|
||||
std::cout << "This example isn't supported on current architecture" << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(result >= 0)
|
||||
return 0;
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
|
||||
@ -1,141 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
|
||||
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm80.h"
|
||||
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
|
||||
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h"
|
||||
|
||||
int run_sm75() {
|
||||
bool notSupported = false;
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
notSupported = true;
|
||||
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == 7 && props.minor >= 5)) {
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::cout << "Running on SM75" << std::endl;
|
||||
pass &= run_nonfused_gemm_f16();
|
||||
pass &= run_fused_gemm_f16();
|
||||
pass &= run_nonfused_gemm_s8();
|
||||
pass &= run_fused_gemm_s8();
|
||||
|
||||
if(pass)
|
||||
return 1;
|
||||
else
|
||||
return -1;
|
||||
|
||||
|
||||
}
|
||||
|
||||
int run_sm80() {
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
notSupported = true;
|
||||
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == 8 && props.minor >= 0)) {
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::cout << "Running on SM80" << std::endl;
|
||||
pass &= run_nonfused_gemm_f16_sm80();
|
||||
pass &= run_fused_gemm_f16_sm80();
|
||||
pass &= run_nonfused_gemm_s8_sm80();
|
||||
pass &= run_fused_gemm_s8_sm80();
|
||||
|
||||
if(pass)
|
||||
return 1;
|
||||
else
|
||||
return -1;
|
||||
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
int result = 0;
|
||||
|
||||
result = run_sm80();
|
||||
|
||||
if(!result) { // not supported
|
||||
result = run_sm75();
|
||||
|
||||
if(!result) {
|
||||
std::cout << "This example isn't supported on current architecture" << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(result >= 0)
|
||||
return 0;
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,30 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@ -35,24 +38,25 @@
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 256} // output size (NPQK)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
@ -64,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(1); //use beta for bias
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
@ -89,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
@ -138,83 +142,6 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
@ -230,8 +157,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 256, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -292,5 +219,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm75_rf_res
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv f16 RF residency");
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm75_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv f16 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,30 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@ -35,24 +38,25 @@
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 1, 1, 64}, // filter size (KRSC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
|
||||
@ -69,10 +73,10 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
@ -138,8 +142,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm80() {
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
@ -151,10 +154,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80() {
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -198,7 +201,7 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80() {
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
@ -208,7 +211,20 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80() {
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
return true;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv f16 RF residency");
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm80_shmem
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv f16 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,30 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@ -35,24 +38,25 @@
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 256} // output size (NPQK)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
@ -137,81 +141,6 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
|
||||
@ -228,8 +157,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 256, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -289,5 +218,18 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm75_rf_res
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 RF residency");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,235 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm75_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,30 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@ -35,24 +38,25 @@
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 1, 1, 64}, // filter size (KRSC)
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
|
||||
@ -69,8 +73,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
@ -137,7 +141,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm80() {
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
@ -152,8 +157,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80() {
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -199,7 +204,7 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80() {
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
@ -211,6 +216,18 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,234 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm80_shmem
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv int8 shmem staging");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,29 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -38,11 +42,12 @@
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*1600, 128, 64);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 128, 64);
|
||||
|
||||
bool run_nonfused_gemm_f16() {
|
||||
|
||||
@ -55,10 +60,10 @@ bool run_nonfused_gemm_f16() {
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
@ -119,7 +124,8 @@ bool run_nonfused_gemm_f16() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_f16() {
|
||||
|
||||
bool run_fused_gemm_f16_rf_res() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
@ -130,9 +136,9 @@ bool run_fused_gemm_f16() {
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
@ -153,8 +159,6 @@ bool run_fused_gemm_f16() {
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
@ -178,7 +182,7 @@ bool run_fused_gemm_f16() {
|
||||
|
||||
B2bFusedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF Residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
@ -187,4 +191,17 @@ bool run_fused_gemm_f16() {
|
||||
|
||||
return passed;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_f16,
|
||||
&run_fused_gemm_f16_rf_res
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm f16 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,211 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 256, 64);
|
||||
|
||||
bool run_nonfused_gemm_f16() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
|
||||
B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_f16_shmem() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
SmemAccumulator
|
||||
>;
|
||||
|
||||
B2bFusedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n";
|
||||
bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_f16,
|
||||
&run_fused_gemm_f16_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm f16 shmem staging");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,29 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -38,11 +42,12 @@
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*1600, 128, 64);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 128, 64);
|
||||
|
||||
bool run_nonfused_gemm_f16_sm80() {
|
||||
|
||||
@ -119,7 +124,7 @@ bool run_nonfused_gemm_f16_sm80() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_f16_sm80() {
|
||||
bool run_fused_gemm_f16_sm80_rf_res() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
@ -178,7 +183,7 @@ bool run_fused_gemm_f16_sm80() {
|
||||
|
||||
B2bFusedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
@ -188,4 +193,20 @@ bool run_fused_gemm_f16_sm80() {
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_f16_sm80,
|
||||
&run_fused_gemm_f16_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm f16 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,214 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 64);
|
||||
|
||||
bool run_nonfused_gemm_f16_sm80() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3
|
||||
>;
|
||||
|
||||
B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_f16_sm80_shmem() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
SmemAccumulator
|
||||
>;
|
||||
|
||||
B2bFusedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n";
|
||||
bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_f16_sm80,
|
||||
&run_fused_gemm_f16_sm80_shmem
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm f16 shmem staging");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,29 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -38,11 +42,12 @@
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*1600, 128, 64);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 128, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8() {
|
||||
|
||||
@ -57,8 +62,8 @@ bool run_nonfused_gemm_s8() {
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
@ -119,7 +124,8 @@ bool run_nonfused_gemm_s8() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8() {
|
||||
|
||||
bool run_fused_gemm_s8_rf_res() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
@ -130,10 +136,10 @@ bool run_fused_gemm_s8() {
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -176,7 +182,7 @@ bool run_fused_gemm_s8() {
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF Residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
@ -186,4 +192,18 @@ bool run_fused_gemm_s8() {
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8,
|
||||
&run_fused_gemm_s8_rf_res
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm f16 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,211 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 256, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
|
||||
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_shmem() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
SmemAccumulator
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8,
|
||||
&run_fused_gemm_s8_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm s8 shmem staing");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,29 +1,33 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -38,11 +42,12 @@
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*1600, 128, 64);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 128, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8_sm80() {
|
||||
|
||||
@ -55,10 +60,10 @@ bool run_nonfused_gemm_s8_sm80() {
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
@ -128,7 +133,8 @@ bool run_nonfused_gemm_s8_sm80() {
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_sm80() {
|
||||
|
||||
bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
@ -163,6 +169,8 @@ bool run_fused_gemm_s8_sm80() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
@ -182,6 +190,7 @@ bool run_fused_gemm_s8_sm80() {
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
@ -190,7 +199,7 @@ bool run_fused_gemm_s8_sm80() {
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
@ -199,4 +208,19 @@ bool run_fused_gemm_s8_sm80() {
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8_sm80,
|
||||
&run_fused_gemm_s8_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 256, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8_sm80() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_sm80_shmem() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8_sm80,
|
||||
&run_fused_gemm_s8_sm80_shmem
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm int8 shmem staging");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -202,6 +208,19 @@ struct B2bGemm {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
// Determine if fusion sizes are valid
|
||||
if(problem_size_0.m() != problem_size_1.m())
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(problem_size_0.n() != problem_size_1.k())
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(problem_size_0.n() > B2bMma::Shape0::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(problem_size_1.n() > B2bMma::Shape1::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,749 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_pipelined.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// OpClassTensorOp convolutions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
false
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,740 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_multistage.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// OpClassTensorOp convolutions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,817 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4; //For interleaved layout
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4; //For interleaved layout
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,804 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -111,7 +117,9 @@ template <
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
typename Operator,
|
||||
/// Stage accumulator in shared memory
|
||||
bool SmemAccumulator = false
|
||||
>
|
||||
struct DefaultB2bGemm;
|
||||
|
||||
|
||||
@ -0,0 +1,397 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "threadblock/default_b2b_mma.h"
|
||||
#include "threadblock/default_b2b_mma_smem_accumulator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
Operator, true> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementC, layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
2,
|
||||
Operator,
|
||||
EpilogueOutputOp0,
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
typename B2bMma::Operator1,
|
||||
kPartitionsK1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0,
|
||||
true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue for the 2nd Gemm
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
95
examples/13_two_tensor_op_fusion/test_run.h
Normal file
95
examples/13_two_tensor_op_fusion/test_run.h
Normal file
@ -0,0 +1,95 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
// Run tests on GPUs
|
||||
|
||||
int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
|
||||
|
||||
bool supported = false;
|
||||
|
||||
int arch_major = arch / 10;
|
||||
int arch_minor = arch - arch / 10 * 10;
|
||||
|
||||
if(arch_major >= 8) {
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
|
||||
supported = true;
|
||||
}
|
||||
}
|
||||
else if(arch_major >= 7) {
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
|
||||
supported = true;
|
||||
}
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == arch_major && props.minor == arch_minor)) {
|
||||
supported = false;
|
||||
}
|
||||
|
||||
if (!supported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
std::cout << "This example isn't supported on current architecture" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::cout << "Device: " << props.name << std::endl;
|
||||
std::cout << "Arch: SM" << arch << std::endl;
|
||||
std::cout << "Test: " << test_name << std::endl;
|
||||
for(auto func : test_funcs) {
|
||||
pass &= func();
|
||||
}
|
||||
|
||||
|
||||
if(pass)
|
||||
return 0;
|
||||
else
|
||||
return -1;
|
||||
|
||||
}
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -693,7 +699,7 @@ public:
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
|
||||
@ -0,0 +1,816 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bImplicitGemmMultistageSmemAccumulator :
|
||||
public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static const bool PerChannelScale = (OutputOp::kScale ==
|
||||
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
using ElementC = typename Policy0::Operator::ElementC;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of Scale and Bias loaded from global memory
|
||||
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Epilog in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations0 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
static_assert(Base::kWarpGemmIterations1 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA0 =
|
||||
IteratorA0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB0 =
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB1 =
|
||||
IteratorB1::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA0 =
|
||||
(AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB0 =
|
||||
(AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB1 =
|
||||
(AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bImplicitGemmMultistageSmemAccumulator(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset(
|
||||
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(
|
||||
IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
|
||||
iterator_A0.set_iteration_index(group_start_A0);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
|
||||
if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0);
|
||||
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(
|
||||
IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
|
||||
iterator_B1.set_iteration_index(group_start_B1);
|
||||
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A0 operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B0 operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over A1 operand scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_scale,
|
||||
///< iterator over A1 operand bias vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_bias,
|
||||
///< iterator over B1 operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0,
|
||||
///< Imaginary strides used for planar-complex only - ignored here
|
||||
int64_t imag_stride_A = 0,
|
||||
int64_t imag_stride_B = 0) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.advance();
|
||||
iterator_B0.advance();
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0) {
|
||||
group_start_iteration_A0 = 0;
|
||||
group_start_iteration_B0 = 0;
|
||||
} else {
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.advance();
|
||||
iterator_B0.advance();
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// 2nd Implicit Gemm
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance_1(iterator_B1);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_B1;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
|
||||
group_start_iteration_B1 = 0;
|
||||
} else {
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1,
|
||||
group_start_iteration_B1);
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -67,8 +73,7 @@ template <
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
@ -94,19 +99,19 @@ template <
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A operand
|
||||
/// Transformation applied to A0 operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B operand
|
||||
/// Transformation applied to B0 operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B operand
|
||||
/// Transformation applied to B1 operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
@ -396,7 +401,6 @@ public:
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
}
|
||||
@ -452,11 +456,8 @@ public:
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
// int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -477,6 +478,7 @@ public:
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
@ -489,7 +491,8 @@ public:
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1,
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -64,11 +70,11 @@ template <
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaBaseSmemAccumulator :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {
|
||||
|
||||
public:
|
||||
///< Base class
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
@ -146,7 +152,6 @@ class B2bMmaBaseSmemAccumulator :
|
||||
AccumulatorSharedStorage0 accumulator_shared_storage0;
|
||||
};
|
||||
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -0,0 +1,860 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaMultistageSmemAccumulator :
|
||||
public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Epilog in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations0 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
static_assert(Base::kWarpGemmIterations1 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const TBLDGSTSIterationsA0 =
|
||||
IteratorA0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB0 =
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB1 =
|
||||
IteratorB1::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA0 =
|
||||
(TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB0 =
|
||||
(TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB1 =
|
||||
(TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaMultistageSmemAccumulator(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset(
|
||||
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
iterator_A0.set_iteration_index(group_start_A0 *
|
||||
IteratorA0::kAccessesPerVector);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0 *
|
||||
IteratorB0::kAccessesPerVector);
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
iterator_B1.set_iteration_index(group_start_B1 *
|
||||
IteratorB1::kAccessesPerVector);
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B1.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0)
|
||||
{
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations0 - 1) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
int group_start_iteration_B1;
|
||||
|
||||
group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
int group_start_iteration_B1;
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -106,7 +112,8 @@ template <
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bMmaPipelined : public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
class B2bMmaPipelined :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
@ -174,7 +181,7 @@ public:
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
/// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
@ -268,8 +275,8 @@ public:
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(tb_frag_A);
|
||||
this->smem_iterator_B0_.store(tb_frag_B0);
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
@ -294,23 +301,19 @@ public:
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations_0 <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
iterator_A.load(tb_frag_A);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::WarpGemmIterations == 2.
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
@ -324,19 +327,14 @@ public:
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(tb_frag_A);
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(tb_frag_B0);
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
iterator_A.load(tb_frag_A);
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
++this->smem_iterator_A_;
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
@ -365,19 +363,18 @@ public:
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations_0 <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -399,7 +396,7 @@ public:
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(tb_frag_B1);
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
@ -409,7 +406,6 @@ public:
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
//warp_tile_iterator_A1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
|
||||
@ -425,9 +421,7 @@ public:
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations_1 <= 1) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 1);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -450,8 +444,7 @@ public:
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
|
||||
this->smem_iterator_B1_.store(tb_frag_B1);
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
++this->smem_iterator_B1_;
|
||||
@ -475,7 +468,6 @@ public:
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
@ -484,17 +476,14 @@ public:
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
++iterator_B1;
|
||||
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations_1 <= 2) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 2);
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -0,0 +1,541 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A0 operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B0 operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B1 operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
IteratorB1_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bMmaPipelinedSmemAccumulator :
|
||||
public B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2>;
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using Policy0 = Policy0_; ///< Policy0 describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA0 = typename IteratorA0::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB0 = typename IteratorB0::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A0 operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B0 operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
/// Epilog in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaPipelinedSmemAccumulator(
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
FragmentA0 tb_frag_A;
|
||||
FragmentB0 tb_frag_B0;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B0.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA0 warp_frag_A0[2];
|
||||
WarpFragmentB0 warp_frag_B0[2];
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//2nd Gemm
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 1);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1)
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 2);
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -89,7 +95,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false>
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Staging the accumulators in shared memory.
|
||||
bool SmemAccumulator = false>
|
||||
struct DefaultB2bMma;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,605 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
|
||||
#include "threadblock/b2b_mma_multistage_smem_accumulator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output with 2-stage pipeline
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, false, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, 2, Operator>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
|
||||
ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output for multi-stage
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, false, true> {
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA0 = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB0 = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output with 2-stage pipeline
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, true, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
|
||||
static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
|
||||
"Alignment must match thread data map's vector length");
|
||||
|
||||
static_assert(kAlignmentB ==128 / sizeof_bits<ElementB>::value,
|
||||
"Alignment must match thread data map's vector length");
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
|
||||
LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
|
||||
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4; //For interleaved layout
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for column-major-interleaved output with multi-stage
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
14_ampere_tf32_tensorop_gemm
|
||||
ampere_tf32_tensorop_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -132,12 +138,12 @@ struct Options {
|
||||
<< " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha\n"
|
||||
<< " --beta <f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n\n";
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n"
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
15_ampere_sparse_tensorop_gemm
|
||||
ampere_sparse_tensorop_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
16_ampere_tensorop_conv2dfprop
|
||||
ampere_tensorop_conv2dfprop.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -321,21 +327,21 @@ struct Options {
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --n=<int> Input tensor extent N\n"
|
||||
<< " --h=<int> Input tensor extent H\n"
|
||||
<< " --w=<int> Input tensor extent W\n"
|
||||
<< " --c=<int> Input tensor extent C\n"
|
||||
<< " --k=<int> Filter extent K\n"
|
||||
<< " --r=<int> Filter extent R\n"
|
||||
<< " --s=<int> Filter extent S\n\n"
|
||||
<< " --alpha=<float> Epilogue scalar alpha\n"
|
||||
<< " --beta=<float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
<< " --tag=<string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
@ -427,7 +433,8 @@ Result profile_convolution(Options const &options) {
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
@ -453,15 +460,20 @@ Result profile_convolution(Options const &options) {
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor C for reference on host with zeros
|
||||
// Fill tensor D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_c.host_view());
|
||||
tensor_d.host_view());
|
||||
|
||||
// Fill tensor D for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_ref_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
@ -491,7 +503,7 @@ Result profile_convolution(Options const &options) {
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_d.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
@ -542,17 +554,17 @@ Result profile_convolution(Options const &options) {
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_c.host_ref(),
|
||||
tensor_ref_d.host_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_c.sync_host();
|
||||
tensor_d.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_c.host_view(),
|
||||
tensor_ref_c.host_view());
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
@ -584,10 +596,10 @@ Result profile_convolution(Options const &options) {
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
|
||||
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
||||
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
17_fprop_per_channel_bias
|
||||
fprop_per_channel_bias.cu
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user