v3.9 (#2185)
* v3.8 update x
* fix blackwell gg
* doc change
* doc change
* doc change

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
@@ -115,4 +115,3 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
@@ -2,3 +2,35 @@

This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -165,3 +165,35 @@ Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -41,3 +41,35 @@ We are currently optimizing the following cases:
* Optimizations for memory bound cases.

* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.

## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -207,3 +207,35 @@ With this in mind, this example kernel has the following limitations:
- This example kernel only supports a dynamic image count; all other conv problem shapes must be defined as `cute::Constant<>`s
- Problem shapes (including the dynamic image count `N`) must be evenly divisible by the tile shape
- It does not perform fp32->tf32 numeric conversion; gmem inputs must already be rounded to tf32

## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -26,11 +26,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

include_directories(
.
)
set(TEST_PREFETCH_CASE --m=8192 --n=64 --k=8192 --iterations=0)

cutlass_example_add_executable(
63_hopper_gemm_with_weight_prefetch
63_hopper_gemm_with_weight_prefetch.cu
)
TEST_COMMAND_OPTIONS
TEST_PREFETCH_CASE
)

target_include_directories(63_hopper_gemm_with_weight_prefetch PUBLIC .)
@@ -74,9 +74,40 @@ echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
However, note that the example still runs a single GEMM, and most of the performance improvement
is expected in end-to-end applications.

## Limitations
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
memory barrier before issuing every single TMA load, and in many cases this will slow down
prefetching to the point of being almost ineffective.

## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -362,11 +362,11 @@ public:
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z,
/*reserved_named_barriers_*/ 14);
/*id*/ 0);
// Prefetcher warp doesn't arrive on this barrier.
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
/*reserved_named_barriers_*/ 15);
/*id*/ 1);

if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
__syncwarp();
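The hunk above changes the two barrier ids (14/15 vs. 0/1) while keeping their arrival counts apart: the second barrier counts every thread except the prefetcher warp, which is why the prefetcher never arrives on it. A host-side analogue of that structure, using `std::barrier` rather than CUTLASS's device-side `NamedBarrier` (thread counts here are made up), is sketched below.

```cpp
// Host-side analogue only (not the kernel code above): two barriers with
// different participant counts, so the "prefetcher" thread never has to
// arrive on the second one.
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int num_threads = 8;                     // stand-in for the thread block
  std::barrier all_threads(num_threads);             // analogue of barrier id 0
  std::barrier without_prefetcher(num_threads - 1);  // analogue of barrier id 1

  std::vector<std::thread> pool;
  for (int t = 0; t < num_threads; ++t) {
    pool.emplace_back([&, t] {
      all_threads.arrive_and_wait();                 // every thread syncs here
      if (t != 0) {                                  // thread 0 plays the prefetcher warp
        without_prefetcher.arrive_and_wait();        // prefetcher skips this barrier
      }
    });
  }
  for (auto &th : pool) th.join();
  std::printf("both barriers released\n");
  return 0;
}
```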
@@ -62,3 +62,36 @@ procedure is the same, simply modify the following line in the example:
```cpp
using TP = _8;
```

## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -84,3 +84,35 @@ GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
```
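If the matrix above reports device-to-device peer capability between the eight GPUs (an assumption; the README excerpt does not spell out what `OK` measures), a minimal host-side sketch that prints such a table is:

```cpp
// Hypothetical sketch: print an NxN peer-access table similar to the one above.
// Compile with nvcc; uses only the standard CUDA runtime API.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int count = 0;
  cudaGetDeviceCount(&count);
  for (int i = 0; i < count; ++i) {
    std::printf("GPU%d", i);
    for (int j = 0; j < count; ++j) {
      if (i == j) {
        std::printf("  X");                      // no self edge
        continue;
      }
      int can_access = 0;
      cudaDeviceCanAccessPeer(&can_access, i, j);
      std::printf(" %s", can_access ? "OK" : "--");
    }
    std::printf("\n");
  }
  return 0;
}
```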
## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
@ -251,93 +251,93 @@ struct Result
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
return true;
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, bits_input);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
@@ -438,14 +438,18 @@ void initialize(const Options<RasterOrderOptions> &options) {

if (IsDFp8 && options.save_amax) {
abs_max_D.resize(cutlass::make_Coord(1));
initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_D.sync_device();
reference_abs_max_D.resize(cutlass::make_Coord(1));
initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
}

if (IsAuxFp8 && options.save_aux && options.save_amax) {
abs_max_aux.resize(cutlass::make_Coord(1));
initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_aux.sync_device();
reference_abs_max_aux.resize(cutlass::make_Coord(1));
initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
}
}
@@ -517,10 +521,9 @@ bool verify(const Options<RasterOrderOptions> &options) {

// Block scaling tensors shapes based CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
auto blockscale_m = ceil_div(options.m, get<0>(TileShape{}));
auto blockscale_n = ceil_div(options.n, get<1>(TileShape{}));
auto blockscale_k = ceil_div(options.k, get<2>(TileShape{}));

// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
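The replacement lines above size the block-scale tensors with a plain ceiling division of the problem extents by the CTA tile extents, instead of building a CuTe layout and zipped-dividing it. A standalone sketch of that arithmetic (values are made up; the real code takes them from `options` and `TileShape`) is:

```cpp
// Minimal sketch of the ceil_div-based extents used above; a hypothetical
// helper is defined locally so the snippet compiles on its own.
#include <cstdio>

static int ceil_div_int(int a, int b) { return (a + b - 1) / b; }

int main() {
  int m = 1000, n = 512, k = 1024;               // assumed GEMM problem size
  int tile_m = 128, tile_n = 128, tile_k = 128;  // assumed CTA tile shape

  // One scale entry per CTA tile in each dimension; partial tiles still get one.
  int blockscale_m = ceil_div_int(m, tile_m);  // 8, not 7, since 1000 is not a multiple of 128
  int blockscale_n = ceil_div_int(n, tile_n);  // 4
  int blockscale_k = ceil_div_int(k, tile_k);  // 8

  std::printf("block scales: %d x %d x %d\n", blockscale_m, blockscale_n, blockscale_k);
  return 0;
}
```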
@@ -608,29 +611,40 @@ bool verify(const Options<RasterOrderOptions> &options) {
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);

// compare_reference
bool passed = true;
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;

if (false) {
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
}
#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
#endif

if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}

if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
}
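The bit-exact `TensorEquals` check above is replaced by a relative comparison driven by the new `options.epsilon` and `options.non_zero_floor` values. A simplified model of what such a comparison typically computes (an illustration, not CUTLASS's exact `relatively_equal` definition) is:

```cpp
// Simplified model of a relative comparison with an epsilon and a non-zero
// floor: magnitudes below the floor are compared against the floor itself,
// so tiny reference values do not inflate the relative error.
#include <algorithm>
#include <cmath>
#include <cstdio>

static bool roughly_equal(double got, double ref, double epsilon, double nonzero_floor) {
  double denom = std::max(std::fabs(ref), nonzero_floor);  // guard against dividing by ~0
  return std::fabs(got - ref) <= epsilon * denom;
}

int main() {
  std::printf("%d\n", roughly_equal(1.001, 1.0, 0.02, 1.0));  // within 2%       -> 1
  std::printf("%d\n", roughly_equal(1e-6, 0.0, 0.02, 1.0));   // below the floor -> 1
  std::printf("%d\n", roughly_equal(1.5, 1.0, 0.02, 1.0));    // 50% off         -> 0
  return 0;
}
```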
@@ -671,10 +685,9 @@ int run(Options<RasterOrderOptions> &options)

std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}

// if (!result.passed) {
// exit(-1);
// }
else {
result.passed = true;
}

// Run profiling loop
if (options.iterations > 0)
@@ -707,7 +720,7 @@ int run(Options<RasterOrderOptions> &options)
std::cout << " GFLOPS: " << result.gflops << std::endl;
}

return 0;
return result.passed;
}

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@@ -753,7 +766,9 @@ int main(int argc, char const **args) {
//

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
bool passed = run<Gemm>(options);
if (!passed)
return -1;
#endif

return 0;
@@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
@ -303,93 +303,93 @@ struct Result
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
return true;
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, bits_input);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename GroupScaleConfig>
|
||||
@@ -403,11 +403,9 @@ void initialize(const Options<RasterOrderOptions> &options) {
assert(options.n % ScaleGranularityN == 0);

// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM;
auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN;
auto blockscale_k = cute::get<2>(blockscale_shape);
auto groupscale_m = ceil_div(options.m, ScaleGranularityM);
auto groupscale_n = ceil_div(options.n, ScaleGranularityN);
auto blockscale_k = ceil_div(options.k, cute::get<2>(TileShape{}));

stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
@@ -582,13 +580,11 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile;

// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM;
auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN;
auto blockscale_m = ceil_div(options.m, get<0>(TileShape_{}));
auto blockscale_n = ceil_div(options.n, get<1>(TileShape_{}));
auto blockscale_k = ceil_div(options.k, get<2>(TileShape_{}));
auto groupscale_m = ceil_div(options.m, ScaleGranularityM);
auto groupscale_n = ceil_div(options.n, ScaleGranularityN);

// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
@@ -676,8 +672,13 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);

// compare_reference
bool passed = true;
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;

#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
@@ -690,15 +691,21 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile

if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}

if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
}
@@ -716,29 +723,29 @@ int run(Options<RasterOrderOptions> &options)
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;

bool skip = false;

if (options.m % ScaleGranularityM != 0) {
std::cout << "Skippig (m size: " << options.m << " less then ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}

if (options.n % ScaleGranularityN != 0) {
std::cout << "Skippig (n size: " << options.m << " less then ScaleGranularityN: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}

if (options.k % size<2>(TileShape{}) != 0) {
std::cout << "Skippig (k size: " << options.k << " less then TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
skip = true;
}

if (!skip) std::cout << "Running: " << std::endl;
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;

if (skip) return -1;

if (options.m < ScaleGranularityM) {
std::cout << " Skippig (m size: " << options.m << " less than ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}

if (options.n < ScaleGranularityN) {
std::cout << " Skippig (n size: " << options.n << " less than ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
skip = true;
}

if (options.k < size<2>(TileShape{})) {
std::cout << " Skippig (k size: " << options.k << " less than TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
skip = true;
}

if (!skip) std::cout << " Running... " << std::endl;
else return -1;

initialize<GroupScaleConfig>(options);
@@ -770,17 +777,17 @@ int run(Options<RasterOrderOptions> &options)

std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}

if (!result.passed) {
exit(-1);
else {
result.passed = true;
}

// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
if (iter == options.warmup)
timer.start();
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
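The loop above now runs `options.warmup` extra iterations and restarts the timer at `iter == options.warmup`, so only the steady-state `options.iterations` runs are measured. A standalone sketch of the same pattern (the kernel launch replaced by a placeholder) is:

```cpp
// Warmup-then-measure timing pattern: the timer is (re)started once the
// warmup iterations are done, so they never count toward the average.
#include <chrono>
#include <cstdio>

int main() {
  int warmup = 5, iterations = 100;
  auto start = std::chrono::steady_clock::now();
  for (int iter = 0; iter < warmup + iterations; ++iter) {
    if (iter == warmup) {
      start = std::chrono::steady_clock::now();  // discard warmup time
    }
    // launch_kernel();  // placeholder for gemm.initialize() + gemm.run()
  }
  auto stop = std::chrono::steady_clock::now();
  double ms = std::chrono::duration<double, std::milli>(stop - start).count();
  std::printf("avg runtime: %f ms\n", ms / iterations);
  return 0;
}
```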
@@ -806,7 +813,7 @@ int run(Options<RasterOrderOptions> &options)
fflush(stdout);
}

return 0;
return result.passed;
}

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@@ -852,27 +859,31 @@ int main(int argc, char const **args) {
//

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
bool passed = true;
std::cout << "Basic split-K GEMM kernel" << std::endl;
run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
passed &= run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
passed &= run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
passed &= run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
passed &= run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
std::cout << std::endl;

std::cout << std::endl;

std::cout << "StreamK GEMM kernel" << std::endl;
run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
passed &= run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
passed &= run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
passed &= run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
passed &= run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
std::cout << std::endl;

if (!passed)
return -1;
#endif

return 0;
@@ -46,6 +46,8 @@ struct Options {
int m = 1024, n = 512, k = 1024, l = 1;
RasterOrderOptions raster;
int swizzle;
float epsilon = 0.02f;
float non_zero_floor = 1.f;

// Parses the command line
void parse(int argc, char const **args) {
@@ -73,6 +75,8 @@ struct Options {
cmd.get_cmd_line_argument("warmup", warmup);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("verify", verify);
cmd.get_cmd_line_argument("epsilon", epsilon);
cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor);

char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
@@ -113,7 +117,10 @@ struct Options {
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --verify=<bool> Verify the results.\n\n"
<< " --epsilon=<float> The epsilon value for comparing the results.\n\n"
<< " --non-zero-floor=<float> The none zero floor for comparing the results.\n\n";

out
<< "\n\nExamples:\n\n"
@@ -221,9 +221,9 @@ void gett_mainloop(
const int N = cute::size<0>(mainloop_params.B.layout());
const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA);
const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB);
assert(ScaleGranularityM && M % ScaleGranularityM == 0
assert(ScaleGranularityM && M % ScaleGranularityM == 0
&& "ScaleGranularityM must divide M");
assert(ScaleGranularityN && N % ScaleGranularityN == 0
assert(ScaleGranularityN && N % ScaleGranularityN == 0
&& "ScaleGranularityN must divide N");

cute::Tensor blockscale_A = domain_offset(
@@ -12,3 +12,35 @@ Note that in Example 55, the argument `--g` is used to determine the block scale
## Upcoming features

Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.

## Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
@ -194,12 +194,14 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(8192), n(8192), k(8192),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -217,6 +219,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -231,6 +234,7 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
@ -331,6 +335,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -231,6 +231,7 @@ struct Options {
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
int swizzle = 0;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -256,6 +257,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -271,6 +273,7 @@ struct Options {
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --scale_a=<f32> Scaling factor for A\n"
|
||||
<< " --scale_b=<f32> Scaling factor for B\n"
|
||||
<< " --scale_c=<f32> Scaling factor for C\n"
|
||||
@ -476,6 +479,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@@ -28,14 +28,29 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

if (CUTLASS_NVCC_ARCHS MATCHES 100a)
set(TEST_SWIZZLE_1 --swizzle=1)
set(TEST_SWIZZLE_2 --swizzle=2)
set(TEST_SWIZZLE_5 --swizzle=5)
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)

if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
cutlass_example_add_executable(
70_blackwell_fp16_gemm
70_blackwell_fp16_gemm.cu
)
TEST_COMMAND_OPTIONS
TEST_SWIZZLE_1
TEST_SWIZZLE_2
TEST_SWIZZLE_5
TEST_SWIZZLE_5_UNEVEN
)

cutlass_example_add_executable(
70_blackwell_fp8_gemm
70_blackwell_fp8_gemm.cu
TEST_COMMAND_OPTIONS
TEST_SWIZZLE_1
TEST_SWIZZLE_2
TEST_SWIZZLE_5
TEST_SWIZZLE_5_UNEVEN
)
endif()
@ -74,12 +74,14 @@ struct Options {
|
||||
|
||||
int m, n, k, l;
|
||||
float alpha, beta;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
alpha(1.f), beta(0.f)
|
||||
alpha(1.f), beta(0.f),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -97,6 +99,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -112,7 +115,8 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n";
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -352,6 +356,8 @@ struct ExampleRunner {
|
||||
hw_info
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
// See example 48 for details on custom EVT construction
|
||||
if constexpr (UseCustomEVT) {
|
||||
arguments.epilogue.thread =
|
||||
|
||||
@ -211,12 +211,14 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -234,6 +236,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -247,7 +250,8 @@ struct Options {
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
@ -333,7 +337,7 @@ bool initialize_block(
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
@ -344,8 +348,8 @@ void initialize(const Options &options) {
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
@ -387,6 +391,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
@ -240,12 +240,14 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -263,6 +265,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -276,7 +279,8 @@ struct Options {
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
@ -362,9 +366,9 @@ bool initialize_block(
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm1xxBlockScaledOutputConfig= typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
@ -375,8 +379,8 @@ void initialize(const Options &options) {
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
@ -432,6 +436,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -212,12 +212,14 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -235,6 +237,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -248,7 +251,8 @@ struct Options {
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
@ -334,7 +338,7 @@ bool initialize_block(
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
@ -345,8 +349,8 @@ void initialize(const Options &options) {
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
@ -388,6 +392,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -214,7 +214,8 @@ struct Options {
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
|
||||
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(4096), n(4096), k(4096),
|
||||
@ -223,7 +224,8 @@ struct Options {
|
||||
preferred_cluster_m(4),
|
||||
preferred_cluster_n(4),
|
||||
fallback_cluster_m(2),
|
||||
fallback_cluster_n(1)
|
||||
fallback_cluster_n(1),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -245,6 +247,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
|
||||
if (!validate_cluster_shape()){
|
||||
std::cout << "--Invalid cluster shapes" << std::endl;
|
||||
@ -265,6 +268,7 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
|
||||
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
|
||||
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
|
||||
@ -384,7 +388,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
|
||||
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
|
||||
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
@ -242,6 +242,7 @@ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTil
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool use_pdl = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
@ -264,6 +265,9 @@ struct Options {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("use_pdl")) {
|
||||
use_pdl = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
@ -387,7 +391,8 @@ struct Options {
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n";
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
@ -711,7 +716,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));

// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
@ -730,7 +735,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
}
timer.stop();

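These hunks switch the warmup and profiling loops from the plain `gemm.run()` call to the three-argument overload so Programmatic Dependent Launch (PDL) can be toggled from the `--use_pdl` flag. The following is a minimal sketch of that call pattern, assuming the `run(stream, cuda_adapter, launch_with_pdl)` overload used in the hunks above; the surrounding argument and workspace setup is elided.

```cpp
#include "cutlass/cutlass.h"   // for cutlass::Status

// Illustrative sketch (not part of this commit): launch the GEMM with or without
// Programmatic Dependent Launch, following the overload used in these hunks.
template <class Gemm>
cutlass::Status run_once(Gemm &gemm, bool use_pdl, cudaStream_t stream = nullptr) {
  // Mirrors the calls added above: gemm.run(stream, cuda_adapter, launch_with_pdl).
  return gemm.run(/* stream = */ stream,
                  /* cuda_adapter = */ nullptr,
                  /* launch_with_pdl = */ use_pdl);
}
```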
@ -219,14 +219,14 @@ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig<
|
||||
OutputSFVectorSize,
|
||||
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
|
||||
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
|
||||
>;
|
||||
using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom;
|
||||
using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF;
|
||||
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
|
||||
using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<StrideA> stride_A_host;
|
||||
@ -305,6 +305,7 @@ struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verification = true;
|
||||
bool use_pdl = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
@ -328,9 +329,12 @@ struct Options {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("no-verif")) {
|
||||
if (cmd.check_cmd_line_flag("no_verif")) {
|
||||
verification = false;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("use_pdl")) {
|
||||
use_pdl = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
@ -457,7 +461,8 @@ struct Options {
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --no-verif Do not run (host-side) verification kernels\n";
|
||||
<< " --no_verif Do not run (host-side) verification kernels\n"
|
||||
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
@ -554,9 +559,9 @@ void allocate(const Options &options) {
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
stride_A_host.push_back(stride_A);
|
||||
stride_B_host.push_back(stride_B);
|
||||
@ -775,9 +780,9 @@ bool verify(const Options &options) {
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
|
||||
@ -845,7 +850,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
@ -870,7 +875,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
|
||||
@ -21,3 +21,35 @@ To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easies
The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements.
It is well-suited for applying masks or activations.
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.

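As a rough illustration of the kind of hook described above, the sketch below applies a causal mask to an accumulator fragment given the logical coordinates of each element. The tensor and coordinate types, the helper name, and the exact signature are assumptions made for illustration only; the real customization point is `apply_mask` in `collective/fmha_fusion.hpp`.

```cpp
#include <cmath>
#include "cutlass/cutlass.h"   // CUTLASS_DEVICE
#include "cute/tensor.hpp"

// Hypothetical illustration only: a causal-mask hook in the spirit of apply_mask.
// AccFragment is assumed to be a cute::Tensor of accumulator values; IndexTensor
// is assumed to hold the logical (query, key) position of each accumulator element.
template <class AccFragment, class IndexTensor>
CUTLASS_DEVICE void apply_causal_mask(AccFragment &acc, IndexTensor const &index) {
  for (int i = 0; i < cute::size(acc); ++i) {
    auto coord = index(i);              // logical position of this accumulator element
    auto q = cute::get<0>(coord);
    auto k = cute::get<1>(coord);
    if (k > q) {
      acc(i) = -INFINITY;               // mask out future keys before the softmax
    }
  }
}
```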
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -0,0 +1,546 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture.
|
||||
This kernel is optimized for the GeForce RTX 50 series GPUs.
|
||||
|
||||
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale).
|
||||
NVFP4 MMA has 2x throughput compared to MXFP8 MMA and 4x throughput compared to Ada Tensor Core FP8 MMA.
|
||||
(See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
This kernel leverages:
|
||||
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
|
||||
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
3. Block Scaled Tensor Core MMA Instructions
|
||||
4. Epilogue Optimization
|
||||
|
||||
Note that GeForce RTX 50 series GPUs do not support:
|
||||
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
|
||||
2. Dynamic datatypes.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "79a_blackwell_geforce_nvfp4_bf16_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must run on a GPU with compute capability 120.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,593 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture.
|
||||
The kernel outputs quantized fp4 values with scale factors that will be the input of another GEMM.
|
||||
This kernel is optimized for the GeForce RTX 50 series GPUs.
|
||||
|
||||
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
|
||||
|
||||
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
|
||||
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
3. Block Scaled Tensor Core MMA Instructions
|
||||
4. Epilogue Optimization
|
||||
|
||||
Note that GeForce RTX 50 series GPUs do not support:
|
||||
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
|
||||
2. Dynamic datatypes.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
|
||||
using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFD matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
|
||||
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int InputSFVectorSize = 16;
|
||||
constexpr int OutputSFVectorSize = InputSFVectorSize;
|
||||
|
||||
// D = alpha * acc + beta * C
|
||||
// With BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, LayoutSFDTag,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong // Ping-pong kernel schedule policy.
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutSFD layout_SFD;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
|
||||
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
// Matrix-wide normalization constant
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "79b_blackwell_geforce_nvfp4_nvfp4_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
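// (Added example) For the default 1024 x 1024 x 1024 problem, flop = 2 * 1024^3 ~= 2.147e9,
// so an average runtime of 0.1 ms corresponds to roughly 21,475 GFLOP/s.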
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
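// (Added note) The scale-factor layouts carry broadcast modes with stride 0; filter_zeros()
// collapses those modes so each allocation below covers only the unique scale-factor elements.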
|
||||
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_Normconst.reset(cutlass::make_Coord(1));
|
||||
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_Normconst.at(cutlass::make_Coord(0)) = 2;
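// (Added note) This matrix-wide norm constant feeds the epilogue's scale-factor generation for D;
// the host reference in verify() is handed the same value so both paths quantize identically.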
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
block_SFD.sync_device();
|
||||
block_Normconst.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr (IsBlockScaleSupported) {
|
||||
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
auto tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D), // TensorD
|
||||
decltype(tensor_SFD), // TensorSfD
|
||||
cute::Int<OutputSFVectorSize>,
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
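// (Added note) The two TensorNorm checks guard against a vacuous pass in which both the reference
// and the kernel outputs are identically zero.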
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with a CUDA 12.8 or newer Toolkit to run this example,
// and the device must have compute capability 12.0 (SM120).
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,546 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM120 architecture.
|
||||
This kernel is optimized for the GeForce RTX 50 series GPUs.
|
||||
|
||||
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale).
|
||||
MXFP8 MMA has 2x throughput compared to Ada Tensor Core FP8 MMA.
|
||||
(See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
|
||||
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
|
||||
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
3. Block Scaled Tensor Core MMA Instructions
|
||||
4. Epilogue Optimization
|
||||
|
||||
Note that GeForce RTX 50 series GPUs do not support:
|
||||
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
|
||||
2. Dynamic datatypes.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::mx_float6_t<cutlass::float_e3m2_t>; // Element type for B matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
ThreadBlockShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
|
||||
>::CollectiveOp;
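// (Added note) StageCountAutoCarveout asks the builder to choose the mainloop pipeline stage count
// automatically after carving the epilogue's SharedStorage out of the available shared memory.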
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "79c_blackwell_geforce_mixed_mxfp8_bf16_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with a CUDA 12.8 or newer Toolkit to run this example,
// and the device must have compute capability 12.0 (SM120).
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
examples/79_blackwell_geforce_gemm/CMakeLists.txt (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 120a)

  cutlass_example_add_executable(
    79a_blackwell_geforce_nvfp4_bf16_gemm
    79a_blackwell_geforce_nvfp4_bf16_gemm.cu
  )

  cutlass_example_add_executable(
    79b_blackwell_geforce_nvfp4_nvfp4_gemm
    79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu
  )

  cutlass_example_add_executable(
    79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm
    79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu
  )

endif()
|
||||
@ -216,7 +216,7 @@ struct Options {
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "81_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
<< "$ " << "112_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -157,6 +157,7 @@ foreach(EXAMPLE
|
||||
76_blackwell_conv
|
||||
77_blackwell_fmha
|
||||
78_blackwell_emulated_bf16x9_gemm
|
||||
79_blackwell_geforce_gemm
|
||||
81_blackwell_gemm_blockwise
|
||||
)
|
||||
|
||||
|
||||
@ -282,6 +282,10 @@

Blackwell SM100 FastFP32 (using BF16 to emulate SGEMM) kernel

* [79_blackwell_geforce_gemm](79_blackwell_geforce_gemm/)

  Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores

# CuTe - Programming Examples

Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).

@ -291,3 +295,35 @@ Additionally, CuTe's core layout and layout algebra have their own test cases wi

# Python Interface Examples

Examples leveraging CUTLASS's [Python interface](../python/README.md) are located in [cutlass/examples/python](python/).

# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -58,7 +58,7 @@ struct IndexedGather
|
||||
operator()(I i) const { return indices_[i]; }
|
||||
|
||||
CUTE_HOST_DEVICE friend
|
||||
void
|
||||
void
|
||||
print(IndexedGather const &s) {
|
||||
cute::print("Indexed");
|
||||
}
|
||||
@ -80,7 +80,7 @@ struct StridedGather
|
||||
operator()(I i) const { return i * stride_; }
|
||||
|
||||
CUTE_HOST_DEVICE friend
|
||||
void
|
||||
void
|
||||
print(StridedGather const &s) {
|
||||
cute::print("Strided{");
|
||||
print(s.stride_);
|
||||
@ -153,7 +153,7 @@ make_custom_stride_layout(Stride const &stride, Func&& func)
|
||||
/// Helper function to optionally create a gather tensor
|
||||
template<class Iterator, class Shape, class Stride, class Func>
|
||||
CUTLASS_HOST_DEVICE
|
||||
auto
|
||||
auto
|
||||
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
|
||||
{
|
||||
if constexpr (not cutlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {
|
||||
@ -180,7 +180,7 @@ upcast(Shape const& shape, Stride const& stride)
|
||||
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N,I>(s,d); });
|
||||
} else if constexpr (is_scaled_basis<Stride>::value) {
|
||||
if constexpr (Stride::mode() == I) {
|
||||
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
|
||||
return make_layout(ceil_div(shape, Int<N>{}), ceil_div(stride, Int<N>{}));
|
||||
} else {
|
||||
return make_layout(shape, stride);
|
||||
}
|
||||
|
||||
@ -27,34 +27,31 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
add_subdirectory(hopper)
|
||||
add_subdirectory(blackwell)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
sgemm_1
|
||||
cute_tutorial_sgemm_1
|
||||
sgemm_1.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
sgemm_2
|
||||
cute_tutorial_sgemm_2
|
||||
sgemm_2.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
sgemm_sm70
|
||||
cute_tutorial_sgemm_sm70
|
||||
sgemm_sm70.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
sgemm_sm80
|
||||
cute_tutorial_sgemm_sm80
|
||||
sgemm_sm80.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
tiled_copy
|
||||
cute_tutorial_tiled_copy
|
||||
tiled_copy.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
wgmma_sm90
|
||||
wgmma_sm90.cu
|
||||
)
|
||||
|
||||
|
||||
examples/cute/tutorial/blackwell/01_mma_sm100.cu (new file, 592 lines)
@ -0,0 +1,592 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CuTe Tutorial for SM100 Programming
|
||||
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
|
||||
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
|
||||
//
|
||||
// The tutorial series is split into five stages:
|
||||
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
|
||||
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
|
||||
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
|
||||
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
|
||||
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
|
||||
// Use Thrust to handle host/device allocations
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
// Cutlass includes
|
||||
#include <cutlass/half.h> // F16 data type
|
||||
#include <cutlass/util/print_error.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/cluster_launch.hpp>
|
||||
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
|
||||
#include <cute/numeric/integral_constant.hpp>    // Compile-time integral constants such as _1, _256, etc.
|
||||
#include <cute/algorithm/cooperative_copy.hpp>
|
||||
|
||||
// Tutorial helpers
|
||||
#include "example_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Tutorial 01: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// The goal of this tutorial is to show the CuTe interface for tcgen05.mma and tcgen05.ld operations.
|
||||
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
|
||||
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
|
||||
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
|
||||
// - Matrices C and D are MxN, N-major (BLAS row-major)
|
||||
//
|
||||
// This GEMM kernel performs the following steps:
|
||||
// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) for one MmaTile
|
||||
// using auto-vectorizing copy operations.
|
||||
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
||||
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
|
||||
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
|
||||
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
|
||||
//
|
||||
// SM100 tcgen05.mma instructions operate as follows:
|
||||
// - Read matrix A from SMEM or TMEM
|
||||
// - Read matrix B from SMEM
|
||||
// - Write accumulator to TMEM
|
||||
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||||
//
|
||||
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
|
||||
// and the MMA's M and N dimensions.
|
||||
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
|
||||
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
|
||||
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
|
||||
//
|
||||
// The MMA details:
|
||||
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
|
||||
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
|
||||
// This example uses F16xF16 = F32 MMA where:
|
||||
// TypeA = cutlass::half_t; // MMA A Data Type
|
||||
// TypeB = cutlass::half_t; // MMA B Data Type
|
||||
// TypeC = float; // MMA C Data Type
|
||||
// TypeD = float; // MMA D Data Type
|
||||
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// The shared memory buffers for A and B matrices.
|
||||
template <class TypeA, // Tensor A data type
|
||||
class TypeB, // Tensor B data type
|
||||
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
|
||||
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
|
||||
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
|
||||
|
||||
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
|
||||
|
||||
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
|
||||
};
|
||||
|
||||
// The device kernel
|
||||
template <class SharedStorage,
|
||||
class ATensor, class BTensor, class CTensor, class DTensor,
|
||||
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
void
|
||||
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
BTensor mB, // (Gemm_N, Gemm_K)
|
||||
CTensor mC, // (Gemm_M, Gemm_N)
|
||||
DTensor mD, // (Gemm_M, Gemm_N)
|
||||
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
|
||||
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
|
||||
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Step 1: The Prologue.
|
||||
|
||||
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename TiledMMA::AtomThrID{}));
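// (Added note) AtomThrID is the number of CTAs that cooperate on a single MMA atom (1 for 1SM
// tcgen05.mma, 2 for the 2SM variant); tiled_divide folds those peer CTAs into the leading
// "V" mode of the cluster layout.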
|
||||
|
||||
// Construct the MMA grid coordinate from the CTA grid coordinate
|
||||
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
|
||||
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
|
||||
blockIdx.y, // MMA-N coordinate
|
||||
_); // MMA-K coordinate
|
||||
|
||||
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
|
||||
// by this mma tile.
|
||||
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
|
||||
// * Tensor to partition
|
||||
// * Tiler to use for partitioning
|
||||
// * Coordinate to use for slicing the partitioned tensor
|
||||
// * Projection to ignore unwanted modes of the Tiler and Coordinate
|
||||
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
|
||||
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
|
||||
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
|
||||
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
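// (Added note) The Step<...> projections pick which tiler modes apply to each tensor:
// A keeps the (M,K) modes, B the (N,K) modes, and C/D the (M,N) modes; X marks a mode to skip.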
|
||||
|
||||
if (thread0()) {
|
||||
print("mA:\t"); print(mA); print("\n"); // mA: gmem_ptr[16b](GMEM_ADDR_A) o (512,256):(256,_1)
|
||||
print("mB:\t"); print(mB); print("\n"); // mB: gmem_ptr[16b](GMEM_ADDR_B) o (1024,256):(256,_1)
|
||||
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
|
||||
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
|
||||
|
||||
print("gA:\t"); print(gA); print("\n"); // gA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile) o (_128,_64,4):(256,_1,_64)
|
||||
print("gB:\t"); print(gB); print("\n"); // gB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile) o (_256,_64,4):(_1,256,16384)
|
||||
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
} __syncthreads();
|
||||
|
||||
// The SMEM tensors
|
||||
|
||||
// Allocate SMEM
|
||||
extern __shared__ char shared_memory[];
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Represent the SMEM buffers for A and B
|
||||
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCsB = shared_storage.tensor_sB();                        // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
//
|
||||
// Mma partitioning for A and B
|
||||
//
|
||||
// Note: Partitioned tensors use tXgY naming convention:
|
||||
// tXgY -> The partitioning pattern tX applied to tensor gY
|
||||
|
||||
auto mma_v = get<0>(mma_coord_vmnk);
|
||||
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
|
||||
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
|
||||
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile + offset_for_mma) o ((_128,_16),_1,_4,4):((256,_1),_0,_16,_64)
|
||||
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile + offset_for_mma) o ((_256,_16),_1,_4,4):((_1,256),_0,4096,16384)
|
||||
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// MMA Fragment Allocation
|
||||
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
|
||||
// For tcgen05.mma operations:
|
||||
// - Matrices A and B are sourced from SMEM
|
||||
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
|
||||
// - The first mode of each descriptor represents the SMEM for a single MMA operation
|
||||
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
// TMEM Allocation
|
||||
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
|
||||
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
|
||||
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// Barrier Initialization
|
||||
uint32_t elect_one_thr = cute::elect_one_sync();
|
||||
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
|
||||
|
||||
// Barriers in SMEM initialized by a single thread.
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1);
|
||||
}
|
||||
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
__syncthreads(); // Make sure all threads observe barrier initialization.
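// Note on phase bits: wait_barrier(bar, phase) blocks until the barrier completes the phase with
// the given parity. Each completed cycle flips the parity, so the local phase bit is toggled
// (^= 1) after every wait in the mainloop below.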
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set the MMA accumulate option to zero so that the first MMA instruction clears the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
|
||||
{
|
||||
// Step 2a: Load A and B tiles
|
||||
|
||||
// Using auto-vectorized copy operation:
|
||||
// - Utilizes 128 threads for parallel data transfer
|
||||
// - Copy operations are distributed efficiently across all threads
|
||||
// - CuTe can automatically determine optimal vector width
|
||||
cooperative_copy<128>(threadIdx.x, tCgA(_,_,_,k_tile), tCsA); // Load MmaTile_M x MmaTile_K A tile
|
||||
cooperative_copy<128>(threadIdx.x, tCgB(_,_,_,k_tile), tCsB); // Load MmaTile_N x MmaTile_K B tile
|
||||
|
||||
// Step 2b: Execute the MMAs for this tile
|
||||
|
||||
// Wait for loads to SMEM to complete with __syncthreads()
|
||||
__syncthreads();
|
||||
|
||||
// tcgen05.mma instructions require single-thread execution:
|
||||
// - Only one warp performs the MMA-related loop operations
|
||||
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
|
||||
// - No explicit elect_one_sync region is needed from the user
|
||||
if (elect_one_warp) {
|
||||
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
// Ensure MMAs are completed, only then we can reuse the A and B SMEM.
|
||||
cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
|
||||
}
|
||||
// Wait for the MMAs to complete to avoid overwriting the A and B SMEM.
|
||||
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
|
||||
mma_barrier_phase_bit ^= 1;
|
||||
}
|
||||
|
||||
// Step 3: The Epilogue.
|
||||
|
||||
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
|
||||
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
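// The copy atom name encodes the TMEM access pattern (here, roughly: 32-bit loads across
// 32 TMEM datapaths, repeated 1x per instruction).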
|
||||
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load C tensor GMEM -> RMEM
|
||||
copy(tDgC, tDrC);
|
||||
|
||||
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
|
||||
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
using AccType = typename decltype(tCtAcc)::value_type;
|
||||
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load TMEM -> RMEM
|
||||
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
|
||||
|
||||
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
|
||||
axpby(alpha, tDrAcc, beta, tDrC);
|
||||
// Store RMEM -> GMEM
|
||||
copy(tDrC, tDgD);
|
||||
}
|
||||
|
||||
template <class TypeA, class LayoutA,
|
||||
class TypeB, class LayoutB,
|
||||
class TypeC, class LayoutC,
|
||||
class TypeD, class LayoutD,
|
||||
class Alpha, class Beta>
|
||||
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
TypeB const* device_ptr_B, LayoutB layout_B,
|
||||
TypeC const* device_ptr_C, LayoutC layout_C,
|
||||
TypeD * device_ptr_D, LayoutD layout_D,
|
||||
Alpha const alpha, Beta const beta)
|
||||
{
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
|
||||
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
|
||||
|
||||
// Represent the full tensors in global memory
|
||||
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
|
||||
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
|
||||
|
||||
// Get M, N, K dimensions of the GEMM we are running
|
||||
auto Gemm_M = shape<0>(layout_A);
|
||||
auto Gemm_N = shape<0>(layout_B);
|
||||
auto Gemm_K = shape<1>(layout_A);
|
||||
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Initialize the GEMM kernel parameters
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
|
||||
// larger TiledMma from the given mma instruction.
|
||||
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
|
||||
128, 256, // Mma M and N dimensions
|
||||
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
|
||||
|
||||
// We can also print and inspect the tiled_mma
|
||||
print(tiled_mma);
|
||||
// TiledMMA
|
||||
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
|
||||
// PermutationMNK: (_,_,_)
|
||||
// MMA_Atom
|
||||
// ThrID: _1:_0
|
||||
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
|
||||
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
|
||||
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
|
||||
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
|
||||
|
||||
// Define MMA tiler sizes (static)
|
||||
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
|
||||
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
|
||||
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
|
||||
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
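// With the 128x256x16 MMA instruction above, this gives mma_tiler = (128, 256, 64):
// one MMA per tile in M and N, and 4 K16 MMAs per 64-element K tile.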
|
||||
|
||||
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
|
||||
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Determine the SMEM layouts:
|
||||
//
|
||||
|
||||
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
|
||||
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
|
||||
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
|
||||
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of times the
|
||||
// MMA instruction is repeated in the M/N mode and the K mode of the MMA tile, respectively.
|
||||
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
|
||||
|
||||
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
|
||||
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
|
||||
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
|
||||
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
|
||||
|
||||
// Print and inspect mma_shape_A, and mma_shape_B for this example.
|
||||
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
|
||||
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
|
||||
|
||||
// A and B tensors are swizzled in SMEM to improve MMA performance.
|
||||
// * However, expressing swizzled layouts is very hard.
|
||||
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
|
||||
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
|
||||
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
|
||||
|
||||
// Print and inspect sA_layout and sB_layout for this example.
|
||||
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
|
||||
// Now we can find the SMEM allocation size
|
||||
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
|
||||
|
||||
// The cluster shape and layout
|
||||
auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Launch GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
dim3 dimBlock(128);
|
||||
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
|
||||
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
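// One CTA per 128x256 MMA tile: for the default 512x1024 problem this is a 4x4 grid,
// rounded up to a multiple of the (1,1,1) cluster shape.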
|
||||
int smemBytes = sizeof(SMEMStorage);
|
||||
|
||||
auto* kernel_ptr = &gemm_device<SMEMStorage,
|
||||
decltype(mA), decltype(mB), decltype(mC), decltype(mD),
|
||||
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
|
||||
Alpha, Beta>;
|
||||
|
||||
// Set kernel attributes (set SMEM)
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
|
||||
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
|
||||
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
mA, mB, mC, mD,
|
||||
mma_tiler, tiled_mma, cluster_shape,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
cudaGetDevice(&current_device_id);
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
|
||||
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major != 10 || props.minor > 1) {
|
||||
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
|
||||
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int Gemm_M = 512;
|
||||
if (argc >= 2)
|
||||
sscanf(argv[1], "%d", &Gemm_M);
|
||||
|
||||
int Gemm_N = 1024;
|
||||
if (argc >= 3)
|
||||
sscanf(argv[2], "%d", &Gemm_N);
|
||||
|
||||
int Gemm_K = 256;
|
||||
if (argc >= 4)
|
||||
sscanf(argv[3], "%d", &Gemm_K);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Create A, B, C, and D tensors
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
// Define the data types. The A and B types are the same for the MMA instruction.
|
||||
using TypeA = cutlass::half_t; // MMA A Data Type
|
||||
auto type_str_a = "half_t";
|
||||
using TypeB = cutlass::half_t; // MMA B Data Type
|
||||
auto type_str_b = "half_t";
|
||||
using TypeC = float; // MMA C Data Type
|
||||
[[maybe_unused]] auto type_str_c = "float";
|
||||
using TypeD = float; // MMA D Data Type
|
||||
auto type_str_d = "float";
|
||||
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
|
||||
|
||||
// A tensor MxK K-major (Layout T = Row-Major)
|
||||
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
|
||||
// B tensor NxK K-major (Layout N = Column-Major)
|
||||
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
|
||||
// C tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
// D tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
|
||||
// Host allocations and host CuTe tensors for A, B, and C tensors.
|
||||
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
|
||||
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
|
||||
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
|
||||
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
|
||||
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
|
||||
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
|
||||
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
|
||||
|
||||
// Note that we don't need a host_tensor for D yet.
|
||||
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
|
||||
|
||||
// Initialize A, B, and C tensors with random values.
|
||||
initialize_tensor(host_tensor_A);
|
||||
initialize_tensor(host_tensor_B);
|
||||
initialize_tensor(host_tensor_C);
|
||||
|
||||
// Copy A, B, and C tensors from host memory to device memory
|
||||
thrust::device_vector<TypeA> device_A = host_A;
|
||||
thrust::device_vector<TypeB> device_B = host_B;
|
||||
thrust::device_vector<TypeC> device_C = host_C;
|
||||
|
||||
using Alpha = float;
|
||||
using Beta = float;
|
||||
Alpha alpha = 1.0f;
|
||||
Beta beta = 0.0f;
|
||||
// Set up the input and output tensors and the kernel parameters, then execute the kernel on the device.
|
||||
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
|
||||
device_B.data().get(), layout_B,
|
||||
device_C.data().get(), layout_C,
|
||||
device_D.data().get(), layout_D,
|
||||
alpha, beta);
|
||||
// Host allocation for D tensor and transfer D tensor from device to host
|
||||
thrust::host_vector<TypeD> host_D = device_D;
|
||||
// Create a non-owning CuTe tensor for D tensor
|
||||
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Execute reference GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
|
||||
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
|
||||
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Compare results
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
|
||||
type_str_b, host_tensor_B,
|
||||
type_str_d, host_tensor_D, host_reference_tensor_D);
|
||||
bool success = relative_error <= 0.0;
|
||||
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu (new file, 671 lines)
@ -0,0 +1,671 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CuTe Tutorial for SM100 Programming
|
||||
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
|
||||
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
|
||||
//
|
||||
// The tutorial series is split into five stages:
|
||||
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
|
||||
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
|
||||
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
|
||||
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
|
||||
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
|
||||
// Use Thrust to handle host/device allocations
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
// Cutlass includes
|
||||
#include <cutlass/half.h> // F16 data type
|
||||
#include <cutlass/util/print_error.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/cluster_launch.hpp>
|
||||
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
|
||||
#include <cute/numeric/integral_constant.hpp> // Compile-time integer constants such as _1, _256, etc.
|
||||
#include <cute/algorithm/cooperative_copy.hpp>
|
||||
|
||||
// Tutorial helpers
|
||||
#include "example_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Tutorial 02: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
|
||||
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
|
||||
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
|
||||
// - Matrices C and D are MxN, N-major (BLAS row-major)
|
||||
//
|
||||
// This GEMM kernel extends 01_mma_sm100.cu by adding Tensor Memory Access (TMA) and performs the following steps:
|
||||
// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
||||
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
||||
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
|
||||
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
|
||||
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
|
||||
//
|
||||
// SM100 tcgen05.mma instructions operate as follows:
|
||||
// - Read matrix A from SMEM or TMEM
|
||||
// - Read matrix B from SMEM
|
||||
// - Write accumulator to TMEM
|
||||
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||||
//
|
||||
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
|
||||
// and the MMA's M and N dimensions.
|
||||
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
|
||||
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
|
||||
// CuTe constructs these descriptors transparently within the instruction and its fragments, as shown in this tutorial.
|
||||
//
|
||||
// The MMA details:
|
||||
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
|
||||
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
|
||||
// This example uses F16xF16 = F32 MMA where:
|
||||
// TypeA = cutlass::half_t; // MMA A Data Type
|
||||
// TypeB = cutlass::half_t; // MMA B Data Type
|
||||
// TypeC = float; // MMA C Data Type
|
||||
// TypeD = float; // MMA D Data Type
|
||||
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// The shared memory buffers for A and B matrices.
|
||||
template <class TypeA, // Tensor A data type
|
||||
class TypeB, // Tensor B data type
|
||||
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
|
||||
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
|
||||
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
|
||||
|
||||
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
|
||||
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
|
||||
|
||||
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
|
||||
};
|
||||
|
||||
// The device kernel
|
||||
template <class SharedStorage,
|
||||
class ATensor, class BTensor, class CTensor, class DTensor,
|
||||
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
|
||||
class TmaAtomA, class TmaAtomB,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
void
|
||||
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
BTensor mB, // (Gemm_N, Gemm_K)
|
||||
CTensor mC, // (Gemm_M, Gemm_N)
|
||||
DTensor mD, // (Gemm_M, Gemm_N)
|
||||
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
|
||||
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
|
||||
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
|
||||
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
|
||||
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Step 1: The Prologue.
|
||||
|
||||
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename TiledMMA::AtomThrID{}));
|
||||
|
||||
// Construct the MMA grid coordinate from the CTA grid coordinate
|
||||
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
|
||||
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
|
||||
blockIdx.y, // MMA-N coordinate
|
||||
_); // MMA-K coordinate
|
||||
|
||||
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
|
||||
// by this mma tile.
|
||||
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
|
||||
// * Tensor to partition
|
||||
// * Tiler to use for partitioning
|
||||
// * Coordinate to use for slicing the partitioned tensor
|
||||
// * Projection to ignore unwanted modes of the Tiler and Coordinate
|
||||
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
|
||||
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
|
||||
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
|
||||
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
|
||||
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
|
||||
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
|
||||
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
|
||||
|
||||
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
} __syncthreads();
|
||||
|
||||
// The SMEM tensors
|
||||
|
||||
// Allocate SMEM
|
||||
extern __shared__ char shared_memory[];
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Represent the SMEM buffers for A and B
|
||||
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
//
|
||||
// Mma partitioning for A and B
|
||||
//
|
||||
// Note: Partitioned tensors use tXgY naming convention:
|
||||
// tXgY -> The partitioning pattern tX applied to tensor gY
|
||||
|
||||
auto mma_v = get<0>(mma_coord_vmnk);
|
||||
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
|
||||
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
|
||||
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// MMA Fragment Allocation
|
||||
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
|
||||
// For tcgen05.mma operations:
|
||||
// - Matrices A and B are sourced from SMEM
|
||||
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
|
||||
// - The first mode of each descriptor represents the SMEM for a single MMA operation
|
||||
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
// TMEM Allocation
|
||||
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
|
||||
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
|
||||
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// TMA Setup
|
||||
//
|
||||
// These are TMA partitionings, which have a dedicated custom partitioner.
|
||||
// The Int<0>, Layout<_1> indicates that the TMAs are not multicasted.
|
||||
// Any multicasting must conform to the TMA atom constructed with make_tma_atom on the host.
|
||||
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
|
||||
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_N, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaB, NumMma_N, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
|
||||
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
|
||||
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
|
||||
|
||||
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
|
||||
Int<0>{}, Layout<_1>{},
|
||||
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
|
||||
|
||||
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
|
||||
Int<0>{}, Layout<_1>{},
|
||||
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
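// As the prints below show, mode-0 of tAgA/tAsA (and tBgB/tBsB) covers one full MMA tile
// (128x64 = 8192 A elements, 256x64 = 16384 B elements), while the remaining mode of the
// GMEM views indexes the 4 K-tiles. A single cute::copy therefore moves an entire tile.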
|
||||
|
||||
// Calculate the total number of bytes TMA transfers per tile, used to track completion
|
||||
int tma_transaction_bytes = sizeof(make_tensor_like(tAsA))
|
||||
+ sizeof(make_tensor_like(tBsB));
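// make_tensor_like produces an owning tensor with the same static shape and element type, so
// sizeof gives the tile's byte count. For this example: 8192 + 16384 f16 elements = 48 KiB.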
|
||||
|
||||
if (thread0()) {
|
||||
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
|
||||
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
|
||||
printf("TmaBytes: %d\n", tma_transaction_bytes);
|
||||
} __syncthreads();
|
||||
|
||||
// Barrier Initialization
|
||||
uint32_t elect_one_thr = cute::elect_one_sync();
|
||||
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
|
||||
|
||||
// Barriers in SMEM initialized by a single thread.
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1);
|
||||
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
|
||||
}
|
||||
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
__syncthreads(); // Make sure all threads observe barrier initialization.
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set the MMA accumulate option to zero so that the first MMA instruction clears the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
|
||||
{
|
||||
// Step 2a: Load A and B tiles
|
||||
|
||||
// TMA Load Operations:
|
||||
// - Execute asynchronous TMA loads with a single thread
|
||||
// - Set transaction bytes and execute with barrier
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes);
|
||||
copy(tma_atom_A.with(shared_storage.tma_barrier), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
|
||||
copy(tma_atom_B.with(shared_storage.tma_barrier), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
|
||||
}
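// The two copies above only enqueue asynchronous TMA transfers; completion is observed by
// waiting on tma_barrier, which is released once tma_transaction_bytes bytes have arrived.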
|
||||
|
||||
// Step 2b: Execute the MMAs for this tile
|
||||
|
||||
// Wait for TMA loads to SMEM to complete
|
||||
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
|
||||
tma_barrier_phase_bit ^= 1;
|
||||
|
||||
// tcgen05.mma instructions require single-thread execution:
|
||||
// - Only one warp performs the MMA-related loop operations
|
||||
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
|
||||
// - No explicit elect_one_sync region is needed from the user
|
||||
if (elect_one_warp) {
|
||||
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
// Ensure MMAs are completed, only then we can reuse the A and B SMEM.
|
||||
cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
|
||||
}
|
||||
// Wait for the MMAs to complete to avoid overwriting the A and B SMEM.
|
||||
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
|
||||
mma_barrier_phase_bit ^= 1;
|
||||
}
|
||||
|
||||
// Step 3: The Epilogue.
|
||||
|
||||
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
|
||||
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
|
||||
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load C tensor GMEM -> RMEM
|
||||
copy(tDgC, tDrC);
|
||||
|
||||
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
|
||||
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
using AccType = typename decltype(tCtAcc)::value_type;
|
||||
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load TMEM -> RMEM
|
||||
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
|
||||
|
||||
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
|
||||
axpby(alpha, tDrAcc, beta, tDrC);
|
||||
// Store RMEM -> GMEM
|
||||
copy(tDrC, tDgD);
|
||||
}
|
||||
|
||||
template <class TypeA, class LayoutA,
|
||||
class TypeB, class LayoutB,
|
||||
class TypeC, class LayoutC,
|
||||
class TypeD, class LayoutD,
|
||||
class Alpha, class Beta>
|
||||
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
TypeB const* device_ptr_B, LayoutB layout_B,
|
||||
TypeC const* device_ptr_C, LayoutC layout_C,
|
||||
TypeD * device_ptr_D, LayoutD layout_D,
|
||||
Alpha const alpha, Beta const beta)
|
||||
{
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
|
||||
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
|
||||
|
||||
// Represent the full tensors in global memory
|
||||
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
|
||||
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
|
||||
|
||||
// Get M, N, K dimensions of the GEMM we are running
|
||||
auto Gemm_M = shape<0>(layout_A);
|
||||
auto Gemm_N = shape<0>(layout_B);
|
||||
auto Gemm_K = shape<1>(layout_A);
|
||||
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Initialize the GEMM kernel parameters
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
|
||||
// larger TiledMma from the given mma instruction.
|
||||
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
|
||||
128, 256, // Mma M and N dimensions
|
||||
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
|
||||
|
||||
// We can also print and inspect the tiled_mma
|
||||
print(tiled_mma);
|
||||
// TiledMMA
|
||||
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
|
||||
// PermutationMNK: (_,_,_)
|
||||
// MMA_Atom
|
||||
// ThrID: _1:_0
|
||||
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
|
||||
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
|
||||
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
|
||||
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
|
||||
|
||||
// Define MMA tiler sizes (static)
|
||||
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
|
||||
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
|
||||
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
|
||||
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
|
||||
|
||||
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
|
||||
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Determine the SMEM layouts:
|
||||
//
|
||||
|
||||
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
|
||||
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
|
||||
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
|
||||
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of times the
|
||||
// MMA instruction is repeated in the M/N mode and the K mode of the MMA tile, respectively.
|
||||
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
|
||||
|
||||
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
|
||||
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
|
||||
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
|
||||
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
|
||||
|
||||
// Print and inspect mma_shape_A, and mma_shape_B for this example.
|
||||
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
|
||||
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
|
||||
|
||||
// A and B tensors are swizzled in SMEM to improve MMA performance.
|
||||
// * However, expressing swizzled layouts is very hard.
|
||||
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
|
||||
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
|
||||
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
|
||||
|
||||
// Print and inspect sA_layout and sB_layout for this example.
|
||||
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
|
||||
// Now we can find the SMEM allocation size
|
||||
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
|
||||
|
||||
//
|
||||
// TMA Descriptor Creation (Host Side)
|
||||
//
|
||||
|
||||
// The cluster shape and layout
|
||||
auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
|
||||
|
||||
// Create TMA descriptors for A and B matrices
|
||||
Copy_Atom tma_atom_A = make_tma_atom(
|
||||
SM90_TMA_LOAD{}, // TMA Load Op
|
||||
mA, // Source GMEM tensor
|
||||
sA_layout, // Destination SMEM layout
|
||||
select<0,2>(mma_tiler) // MK Tiler for TMA operation
|
||||
);
|
||||
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
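// get_tma_tensor returns a coordinate ("ArithTuple") tensor rather than a pointer-backed one:
// inside the kernel its elements are GMEM coordinates consumed by the TMA unit, which is why
// mA and gA print as ArithTuple tensors in gemm_device above.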
|
||||
|
||||
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
|
||||
// tma_atom_A: Copy_Atom
|
||||
// ThrID: _1:_0
|
||||
// ValLayoutSrc: (_1,_8192):(_0,_1)
|
||||
// ValLayoutDst: (_1,_8192):(_0,_1)
|
||||
// ValLayoutRef: (_1,_8192):(_0,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
Copy_Atom tma_atom_B = make_tma_atom(
|
||||
SM90_TMA_LOAD{}, // TMA Load Op
|
||||
mB, // Source GMEM tensor
|
||||
sB_layout, // Destination SMEM layout
|
||||
select<1,2>(mma_tiler) // NK Tiler for TMA operation
|
||||
);
|
||||
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
|
||||
|
||||
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
|
||||
// tma_atom_B: Copy_Atom
|
||||
// ThrID: _1:_0
|
||||
// ValLayoutSrc: (_1,_16384):(_0,_1)
|
||||
// ValLayoutDst: (_1,_16384):(_0,_1)
|
||||
// ValLayoutRef: (_1,_16384):(_0,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Launch GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
dim3 dimBlock(128);
|
||||
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
|
||||
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
|
||||
int smemBytes = sizeof(SMEMStorage);
|
||||
|
||||
auto* kernel_ptr = &gemm_device<SMEMStorage,
|
||||
decltype(mA_tma), decltype(mB_tma), decltype(mC), decltype(mD),
|
||||
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
|
||||
decltype(tma_atom_A), decltype(tma_atom_B), // Includes the TMA descriptor.
|
||||
Alpha, Beta>;
|
||||
|
||||
// Set kernel attributes (set SMEM)
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
|
||||
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
|
||||
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
mA_tma, mB_tma, mC, mD,
|
||||
mma_tiler, tiled_mma, cluster_shape,
|
||||
tma_atom_A, tma_atom_B,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
cudaGetDevice(&current_device_id);
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
|
||||
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major != 10 || props.minor > 1) {
|
||||
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
|
||||
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int Gemm_M = 512;
|
||||
if (argc >= 2)
|
||||
sscanf(argv[1], "%d", &Gemm_M);
|
||||
|
||||
int Gemm_N = 1024;
|
||||
if (argc >= 3)
|
||||
sscanf(argv[2], "%d", &Gemm_N);
|
||||
|
||||
int Gemm_K = 256;
|
||||
if (argc >= 4)
|
||||
sscanf(argv[3], "%d", &Gemm_K);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Create A, B, C, and D tensors
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
// Define the data types. The A and B types are the same for the MMA instruction.
|
||||
using TypeA = cutlass::half_t; // MMA A Data Type
|
||||
auto type_str_a = "half_t";
|
||||
using TypeB = cutlass::half_t; // MMA B Data Type
|
||||
auto type_str_b = "half_t";
|
||||
using TypeC = float; // MMA C Data Type
|
||||
[[maybe_unused]] auto type_str_c = "float";
|
||||
using TypeD = float; // MMA D Data Type
|
||||
auto type_str_d = "float";
|
||||
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
|
||||
|
||||
// A tensor MxK K-major (Layout T = Row-Major)
|
||||
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
|
||||
// B tensor NxK K-major (Layout N = Column-Major)
|
||||
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
|
||||
// C tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
// D tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
|
||||
// Host allocations and host CuTe tensors for A, B, and C tensors.
|
||||
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
|
||||
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
|
||||
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
|
||||
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
|
||||
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
|
||||
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
|
||||
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
|
||||
|
||||
// Note that we don't need a host_tensor for D yet.
|
||||
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
|
||||
|
||||
// Initialize A, B, and C tensors with random values.
|
||||
initialize_tensor(host_tensor_A);
|
||||
initialize_tensor(host_tensor_B);
|
||||
initialize_tensor(host_tensor_C);
|
||||
|
||||
// Copy A, B, and C tensors from host memory to device memory
|
||||
thrust::device_vector<TypeA> device_A = host_A;
|
||||
thrust::device_vector<TypeB> device_B = host_B;
|
||||
thrust::device_vector<TypeC> device_C = host_C;
|
||||
|
||||
using Alpha = float;
|
||||
using Beta = float;
|
||||
Alpha alpha = 1.0f;
|
||||
Beta beta = 0.0f;
|
||||
// Set up the input and output tensors and the kernel parameters, then execute the kernel on the device.
|
||||
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
|
||||
device_B.data().get(), layout_B,
|
||||
device_C.data().get(), layout_C,
|
||||
device_D.data().get(), layout_D,
|
||||
alpha, beta);
|
||||
// Host allocation for D tensor and transfer D tensor from device to host
|
||||
thrust::host_vector<TypeD> host_D = device_D;
|
||||
// Create a non-owning CuTe tensor for D tensor
|
||||
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Execute reference GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
|
||||
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
|
||||
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Compare results
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
|
||||
type_str_b, host_tensor_B,
|
||||
type_str_d, host_tensor_D, host_reference_tensor_D);
|
||||
bool success = relative_error <= 0.0;
|
||||
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu (new file, 711 lines)
@ -0,0 +1,711 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CuTe Tutorial for SM100 Programming
|
||||
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
|
||||
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
|
||||
//
|
||||
// The tutorial series is split into five stages:
|
||||
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
|
||||
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
|
||||
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
|
||||
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
|
||||
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
|
||||
// Use Thrust to handle host/device allocations
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
// Cutlass includes
|
||||
#include <cutlass/half.h> // F16 data type
|
||||
#include <cutlass/util/print_error.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/cluster_launch.hpp>
|
||||
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
|
||||
#include <cute/numeric/integral_constant.hpp>  // Compile-time integer constants such as _1, _256, etc.
|
||||
#include <cute/algorithm/cooperative_copy.hpp>
|
||||
|
||||
// Tutorial helpers
|
||||
#include "example_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Tutorial 03: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
|
||||
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
|
||||
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
|
||||
// - Matrices C and D are MxN, N-major (BLAS row-major)
|
||||
//
|
||||
// Key extensions from tutorial 02_mma_tma_sm100.cu:
|
||||
// 1. Introduce ClusterShape for coordinated execution across thread blocks
|
||||
// 2. Introduce TMA multicast
|
||||
// 3. Enhanced TMA <-> MMA synchronization for cluster-wide operations
|
||||
//
|
||||
// This GEMM kernel will perform the following steps:
|
||||
// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA load operations.
|
||||
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
||||
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
|
||||
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
|
||||
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
|
||||
//
|
||||
// SM100 tcgen05.mma instructions operate as follows:
|
||||
// - Read matrix A from SMEM or TMEM
|
||||
// - Read matrix B from SMEM
|
||||
// - Write accumulator to TMEM
|
||||
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||||
//
|
||||
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
|
||||
// and the MMA's M and N dimensions.
|
||||
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
|
||||
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
|
||||
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
|
||||
//
|
||||
// The MMA details:
|
||||
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
|
||||
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
|
||||
// This example uses F16xF16 = F32 MMA where:
|
||||
// TypeA = cutlass::half_t; // MMA A Data Type
|
||||
// TypeB = cutlass::half_t; // MMA B Data Type
|
||||
// TypeC = float; // MMA C Data Type
|
||||
// TypeD = float; // MMA D Data Type
|
||||
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// The shared memory buffers for A and B matrices.
|
||||
template <class TypeA, // Tensor A data type
|
||||
class TypeB, // Tensor B data type
|
||||
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
|
||||
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
|
||||
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
|
||||
|
||||
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
|
||||
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
|
||||
|
||||
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
|
||||
};
|
||||
|
||||
// The device kernel
|
||||
template <class SharedStorage,
|
||||
class ATensor, class BTensor, class CTensor, class DTensor,
|
||||
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
|
||||
class TmaAtomA, class TmaAtomB,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
void
|
||||
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
BTensor mB, // (Gemm_N, Gemm_K)
|
||||
CTensor mC, // (Gemm_M, Gemm_N)
|
||||
DTensor mD, // (Gemm_M, Gemm_N)
|
||||
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
|
||||
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
|
||||
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
|
||||
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
|
||||
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Step 1: The Prologue.
|
||||
|
||||
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename TiledMMA::AtomThrID{}));
|
||||
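// Illustrative aside: for the (4,4,1) cluster used by this example with a 1SM MMA atom (AtomThrID = _1),
// tiled_divide yields cluster_layout_vmnk with shape ((_1),_4,_4,_1): 16 CTAs addressed by a (V,M,N,K)
// coordinate in which the V (peer CTA) mode is trivial.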
|
||||
// Construct the MMA grid coordinate from the CTA grid coordinate
|
||||
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
|
||||
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
|
||||
blockIdx.y, // MMA-N coordinate
|
||||
_); // MMA-K coordinate
|
||||
|
||||
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
|
||||
// by this mma tile.
|
||||
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
|
||||
// * Tensor to partition
|
||||
// * Tiler to use for partitioning
|
||||
// * Coordinate to use for slicing the partitioned tensor
|
||||
// * Projection to ignore unwanted modes of the Tiler and Coordinate
|
||||
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
|
||||
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
|
||||
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
|
||||
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
|
||||
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
|
||||
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
|
||||
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
|
||||
|
||||
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
} __syncthreads();
|
||||
|
||||
// The SMEM tensors
|
||||
|
||||
// Allocate SMEM
|
||||
extern __shared__ char shared_memory[];
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Represent the SMEM buffers for A and B
|
||||
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
//
|
||||
// Mma partitioning for A and B
|
||||
//
|
||||
|
||||
auto mma_v = get<0>(mma_coord_vmnk);
|
||||
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
|
||||
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
|
||||
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// MMA Fragment Allocation
|
||||
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
|
||||
// For tcgen05.mma operations:
|
||||
// - Matrices A and B are sourced from SMEM
|
||||
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
|
||||
// - The first mode of each descriptor represents the SMEM for a single MMA operation
|
||||
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
// TMEM Allocation
|
||||
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
|
||||
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
|
||||
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
|
||||
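// Sizing aside (assuming TMEM capacity per SM of 128 lanes x 512 columns of 32-bit cells, i.e. 256 KiB):
// the 128x256 F32 accumulator used here occupies 128 lanes x 256 columns, about half of the available TMEM.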
|
||||
if (thread0()) {
|
||||
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// TMA Setup
|
||||
//
|
||||
// These are TMA partitionings, which have a dedicated custom partitioner.
|
||||
// In this example, the TMA multicasts the loads across multiple CTAs.
|
||||
// Loads of A are multicasted along the N dimension of the cluster_shape_MNK and
|
||||
// Loads of B are multicasted along the M dimension of the cluster_shape_MNK.
|
||||
// Any multicasting must conform to the TMA atoms (tma_atom_A and tma_atom_B) constructed with make_tma_atom on the host.
|
||||
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
|
||||
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
|
||||
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
|
||||
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
|
||||
|
||||
// Each CTA with the same m-coord will load a portion of A
|
||||
// Each CTA with the same n-coord will load a portion of B
|
||||
// Multicast behavior for CTA 1,2 in the cluster
|
||||
// A multicast B multicast
|
||||
// 0 1 2 3 0 1 2 3
|
||||
// 0 - - - - 0 - - X -
|
||||
// 1 X X X X 1 - - X -
|
||||
// 2 - - - - 2 - - X -
|
||||
// 3 - - - - 3 - - X -
|
||||
// tma_multicast_mask_A = 0x2222
|
||||
// tma_multicast_mask_B = 0x0F00
|
||||
// mma_multicast_mask_C = 0x2F22
|
||||
|
||||
// Construct the CTA-in-Cluster coordinate for multicasting
|
||||
auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster()));
|
||||
|
||||
// Project the cluster_layout for tma_A along the N-modes
|
||||
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
|
||||
get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster
|
||||
make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster
|
||||
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
|
||||
|
||||
// Project the cluster_layout for tma_B along the M-modes
|
||||
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
|
||||
get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster
|
||||
make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster
|
||||
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
|
||||
|
||||
// Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A
|
||||
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
// Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B
|
||||
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
// Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C
|
||||
uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) |
|
||||
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
|
||||
// Calculate the total bytes that TMA will transfer for each tile, used to track completion
|
||||
int tma_transaction_bytes = sizeof(make_tensor_like(tAsA))
|
||||
+ sizeof(make_tensor_like(tBsB));
|
||||
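// Worked numbers for this example (illustrative aside):
//   tAsA covers a 128x64 half_t tile =  8192 elements * 2 B = 16384 B
//   tBsB covers a 256x64 half_t tile = 16384 elements * 2 B = 32768 B
// so the barrier expects 49152 bytes per k-tile.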
|
||||
if (thread0()) {
|
||||
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
|
||||
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
|
||||
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
|
||||
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
|
||||
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
|
||||
printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c);
|
||||
} __syncthreads();
|
||||
|
||||
// Barrier Initialization
|
||||
|
||||
uint32_t elect_one_thr = cute::elect_one_sync();
|
||||
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
|
||||
|
||||
// Barriers in SMEM initialized by a single thread.
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
// The number of CTAs that participate in multicast operations with this CTA (for both A and B matrices)
|
||||
int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1;
|
||||
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants);
|
||||
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
|
||||
}
|
||||
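// Illustrative aside: with the (4,4,1) cluster and a 1SM MMA, size<1> = 4 and size<2> = 4, so
// num_mcast_participants = 4 + 4 - 1 = 7 and mma_barrier expects 7 arrivals before it flips.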
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
cute::cluster_sync(); // Make sure all threads across all CTAs in Cluster observe barrier initialization.
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set the MMA accumulate option to Zero so that the first MMA instruction clears the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
|
||||
{
|
||||
// Step 2a: Load A and B tiles
|
||||
|
||||
// TMA Load Operations:
|
||||
// - Execute asynchronous TMA loads with single thread
|
||||
// - Set transaction bytes and execute with barrier
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes);
|
||||
copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
|
||||
copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
|
||||
}
|
||||
|
||||
// Step 2b: Execute the MMAs for this tile
|
||||
|
||||
// Wait for TMA loads to SMEM to complete
|
||||
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
|
||||
tma_barrier_phase_bit ^= 1;
|
||||
|
||||
// tcgen05.mma instructions require single-thread execution:
|
||||
// - Only one warp performs the MMA-related loop operations
|
||||
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
|
||||
// - No explicit elect_one_sync region is needed from the user
|
||||
if (elect_one_warp) {
|
||||
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
// Arrive on the barrier when the MMAs complete; only then can the A and B SMEM be reused.
|
||||
cutlass::arch::umma_arrive_multicast(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask.
|
||||
}
|
||||
// Wait for the MMAs to complete to avoid overwriting the A and B SMEM.
|
||||
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
|
||||
mma_barrier_phase_bit ^= 1;
|
||||
}
|
||||
|
||||
// Step 3: The Epilogue.
|
||||
|
||||
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
|
||||
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
|
||||
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load C tensor GMEM -> RMEM
|
||||
copy(tDgC, tDrC);
|
||||
|
||||
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
|
||||
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
using AccType = typename decltype(tCtAcc)::value_type;
|
||||
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load TMEM -> RMEM
|
||||
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
|
||||
|
||||
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
|
||||
axpby(alpha, tDrAcc, beta, tDrC);
|
||||
// Store RMEM -> GMEM
|
||||
copy(tDrC, tDgD);
|
||||
}
|
||||
|
||||
template <class TypeA, class LayoutA,
|
||||
class TypeB, class LayoutB,
|
||||
class TypeC, class LayoutC,
|
||||
class TypeD, class LayoutD,
|
||||
class Alpha, class Beta>
|
||||
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
TypeB const* device_ptr_B, LayoutB layout_B,
|
||||
TypeC const* device_ptr_C, LayoutC layout_C,
|
||||
TypeD * device_ptr_D, LayoutD layout_D,
|
||||
Alpha const alpha, Beta const beta)
|
||||
{
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
|
||||
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
|
||||
|
||||
// Represent the full tensors in global memory
|
||||
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
|
||||
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
|
||||
|
||||
// Get M, N, K dimensions of the GEMM we are running
|
||||
auto Gemm_M = shape<0>(layout_A);
|
||||
auto Gemm_N = shape<0>(layout_B);
|
||||
auto Gemm_K = shape<1>(layout_A);
|
||||
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Initialize the GEMM kernel parameters
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
|
||||
// larger TiledMma from the given mma instruction.
|
||||
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
|
||||
128, 256, // Mma M and N dimensions
|
||||
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
|
||||
|
||||
// We can also print and inspect the tiled_mma
|
||||
print(tiled_mma);
|
||||
// TiledMMA
|
||||
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
|
||||
// PermutationMNK: (_,_,_)
|
||||
// MMA_Atom
|
||||
// ThrID: _1:_0
|
||||
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
|
||||
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
|
||||
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
|
||||
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
|
||||
|
||||
// Define MMA tiler sizes (static)
|
||||
auto bM = tile_size<0>(tiled_mma);              // MMA Tile M. We'll use 1 MMA per MMA Tile M.
|
||||
auto bN = tile_size<1>(tiled_mma);              // MMA Tile N. We'll use 1 MMA per MMA Tile N.
|
||||
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
|
||||
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
|
||||
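// Illustrative aside: with the 128x256x16 instruction chosen above, bM = 128, bN = 256, and bK = 16 * 4 = 64,
// so mma_tiler = (_128,_256,_64) and each k-tile advances the GEMM by 64 elements along K.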
|
||||
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
|
||||
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Determine the SMEM layouts:
|
||||
//
|
||||
|
||||
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
|
||||
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
|
||||
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
// times the MMA instruction is repeated along the M/N mode and the K mode of the MMA tile, respectively.
|
||||
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
|
||||
|
||||
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
|
||||
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
|
||||
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
|
||||
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
|
||||
|
||||
// Print and inspect mma_shape_A, and mma_shape_B for this example.
|
||||
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
|
||||
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
|
||||
|
||||
// A and B tensors are swizzled in SMEM to improve MMA performance.
|
||||
// * However, expressing swizzled layouts by hand is difficult and error-prone.
|
||||
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
|
||||
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
|
||||
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
|
||||
|
||||
// Print and inspect sA_layout and sB_layout for this example.
|
||||
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
|
||||
// Now we can find the SMEM allocation size
|
||||
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
|
||||
|
||||
//
|
||||
// TMA Descriptor Creation (Host Side)
|
||||
//
|
||||
|
||||
// The cluster shape and layout
|
||||
auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{});
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
|
||||
|
||||
Copy_Atom tma_atom_A = make_tma_atom(
|
||||
SM90_TMA_LOAD_MULTICAST{}, // TMA load operation with multicast
|
||||
mA, // Source GMEM tensor
|
||||
sA_layout, // Destination SMEM layout
|
||||
select<0,2>(mma_tiler), // MK Tiler for TMA operation
|
||||
size<2>(cluster_layout_vmnk) // The number of CTAs in the N-mode for multicasting
|
||||
);
|
||||
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
|
||||
|
||||
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
|
||||
// tma_atom_A: Copy_Atom
|
||||
// ThrID: _1:_0
|
||||
// ValLayoutSrc: (_1,_8192):(_0,_1)
|
||||
// ValLayoutDst: (_1,_8192):(_0,_1)
|
||||
// ValLayoutRef: (_1,_8192):(_0,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
Copy_Atom tma_atom_B = make_tma_atom(
|
||||
SM90_TMA_LOAD_MULTICAST{}, // TMA load operation with multicast
|
||||
mB, // Source GMEM tensor
|
||||
sB_layout, // Destination SMEM layout
|
||||
select<1,2>(mma_tiler), // NK Tiler for TMA operation
|
||||
size<1>(cluster_layout_vmnk) // The number of CTAs in the M-mode for multicasting
|
||||
);
|
||||
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
|
||||
|
||||
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
|
||||
// tma_atom_B: Copy_Atom
|
||||
// ThrID: _1:_0
|
||||
// ValLayoutSrc: (_1,_16384):(_0,_1)
|
||||
// ValLayoutDst: (_1,_16384):(_0,_1)
|
||||
// ValLayoutRef: (_1,_16384):(_0,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Launch GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
dim3 dimBlock(128);
|
||||
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
|
||||
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
|
||||
int smemBytes = sizeof(SMEMStorage);
|
||||
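// Worked numbers for the default 512x1024x256 problem (illustrative aside): ceil_div(512,128) = 4 M-tiles and
// ceil_div(1024,256) = 4 N-tiles, each already a multiple of the 4x4 cluster, so dimGrid = (4,4,1) -- a single
// cluster of 16 CTAs. smemBytes is dominated by the 16 KiB A and 32 KiB B staging buffers.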
|
||||
auto* kernel_ptr = &gemm_device<SMEMStorage,
|
||||
decltype(mA_tma), decltype(mB_tma), decltype(mC), decltype(mD),
|
||||
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
|
||||
decltype(tma_atom_A), decltype(tma_atom_B), // Includes the TMA descriptor.
|
||||
Alpha, Beta>;
|
||||
|
||||
// Set kernel attributes (set SMEM)
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
|
||||
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
|
||||
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
mA_tma, mB_tma, mC, mD,
|
||||
mma_tiler, tiled_mma, cluster_shape,
|
||||
tma_atom_A, tma_atom_B,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
cudaGetDevice(¤t_device_id);
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((props.major != 10) || (props.major == 10 && props.minor > 1)) {
|
||||
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
|
||||
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int Gemm_M = 512;
|
||||
if (argc >= 2)
|
||||
sscanf(argv[1], "%d", &Gemm_M);
|
||||
|
||||
int Gemm_N = 1024;
|
||||
if (argc >= 3)
|
||||
sscanf(argv[2], "%d", &Gemm_N);
|
||||
|
||||
int Gemm_K = 256;
|
||||
if (argc >= 4)
|
||||
sscanf(argv[3], "%d", &Gemm_K);
|
||||
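// Example invocation (illustrative; assumes the executable keeps the source file's name):
//   ./03_mma_tma_multicast_sm100 512 1024 256
// Omitted arguments fall back to the defaults above; sizes must be divisible by the MMA tiler (128, 256, 64).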
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Create A, B, C, and D tensors
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
// Define the data types. A and B types are the same for the MMA instruction.
|
||||
using TypeA = cutlass::half_t; // MMA A Data Type
|
||||
auto type_str_a = "half_t";
|
||||
using TypeB = cutlass::half_t; // MMA B Data Type
|
||||
auto type_str_b = "half_t";
|
||||
using TypeC = float; // MMA C Data Type
|
||||
[[maybe_unused]] auto type_str_c = "float";
|
||||
using TypeD = float; // MMA D Data Type
|
||||
auto type_str_d = "float";
|
||||
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
|
||||
|
||||
// A tensor MxK K-major (Layout T = Row-Major)
|
||||
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
|
||||
// B tensor NxK K-major (Layout N = Column-Major)
|
||||
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
|
||||
// C tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
// D tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
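// Illustrative aside: layout_A maps coordinate (m,k) to offset m*Gemm_K + k, so with Gemm_K = 256 the element
// A(2,3) lives at linear offset 2*256 + 3 = 515 -- the usual row-major indexing for a K-major A.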
|
||||
// Host allocations and host CuTe tensors for A, B, and C tensors.
|
||||
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
|
||||
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
|
||||
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
|
||||
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
|
||||
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
|
||||
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
|
||||
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
|
||||
|
||||
// Note that we don't need a host_tensor for D yet.
|
||||
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
|
||||
|
||||
// Initialize A, B, and C tensors with random values.
|
||||
initialize_tensor(host_tensor_A);
|
||||
initialize_tensor(host_tensor_B);
|
||||
initialize_tensor(host_tensor_C);
|
||||
|
||||
// Copy A, B, and C tensors from host memory to device memory
|
||||
thrust::device_vector<TypeA> device_A = host_A;
|
||||
thrust::device_vector<TypeB> device_B = host_B;
|
||||
thrust::device_vector<TypeC> device_C = host_C;
|
||||
|
||||
using Alpha = float;
|
||||
using Beta = float;
|
||||
Alpha alpha = 1.0f;
|
||||
Beta beta = 0.0f;
|
||||
// Setup input and output tensors, and the kernel parameters; and execute the kernel on device
|
||||
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
|
||||
device_B.data().get(), layout_B,
|
||||
device_C.data().get(), layout_C,
|
||||
device_D.data().get(), layout_D,
|
||||
alpha, beta);
|
||||
// Host allocation for D tensor and transfer D tensor from device to host
|
||||
thrust::host_vector<TypeD> host_D = device_D;
|
||||
// Create a non-owning CuTe tensor for D tensor
|
||||
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Execute reference GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
|
||||
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
|
||||
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Compare results
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
|
||||
type_str_b, host_tensor_B,
|
||||
type_str_d, host_tensor_D, host_reference_tensor_D);
|
||||
bool success = relative_error <= 0.0;
|
||||
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu (new file, 716 lines)
@ -0,0 +1,716 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CuTe Tutorial for SM100 Programming
|
||||
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
|
||||
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
|
||||
//
|
||||
// The tutorial series is split into five stages:
|
||||
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
|
||||
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
|
||||
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
|
||||
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
|
||||
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
|
||||
// Use Thrust to handle host/device allocations
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
// Cutlass includes
|
||||
#include <cutlass/half.h> // F16 data type
|
||||
#include <cutlass/util/print_error.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/cluster_launch.hpp>
|
||||
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
|
||||
#include <cute/numeric/integral_constant.hpp>  // Compile-time integer constants such as _1, _256, etc.
|
||||
#include <cute/algorithm/cooperative_copy.hpp>
|
||||
|
||||
// Tutorial helpers
|
||||
#include "example_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Tutorial 04: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
|
||||
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
|
||||
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
|
||||
// - Matrices C and D are MxN, N-major (BLAS row-major)
|
||||
//
|
||||
// Key extensions to tutorial 03_mma_tma_multicast_sm100.cu:
|
||||
// 1. Introduce 2SM tcgen05.mma instructions
|
||||
// 2. Introduce 2SM TMA instructions
|
||||
// 3. Demonstrate TMA multicast pattern specialized for 2SM instructions for loading A and B matrices
|
||||
//
|
||||
// This GEMM kernel will perform the following steps:
|
||||
// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA.2SM load operations.
|
||||
// 2. Perform matrix multiply-accumulate (MMA) operations using 2SM tcgen05.mma instruction.
|
||||
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
|
||||
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
|
||||
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
|
||||
//
|
||||
// SM100 2SM tcgen05.mma instructions operate as follows:
|
||||
// - The MMA is launched by only one SM.
//   With 2SM MMA instructions, only 1 of the 2 CTAs collaborating on the MMA executes the instruction.
//   We call the collaborating CTAs "peer CTAs", and the CTA that executes the MMA instruction is the "leader CTA".
|
||||
// - Read matrix A from SMEM or TMEM
|
||||
// - Read matrix B from SMEM
|
||||
// - Write accumulator to TMEM
|
||||
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||||
//
|
||||
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
|
||||
// and the MMA's M and N dimensions.
|
||||
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
|
||||
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
|
||||
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
|
||||
//
|
||||
// The MMA details:
|
||||
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 256x256x16 MMA
|
||||
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
|
||||
// This example uses F16xF16 = F32 MMA where:
|
||||
// TypeA = cutlass::half_t; // MMA A Data Type
|
||||
// TypeB = cutlass::half_t; // MMA B Data Type
|
||||
// TypeC = float; // MMA C Data Type
|
||||
// TypeD = float; // MMA D Data Type
|
||||
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// The shared memory buffers for A and B matrices.
|
||||
template <class TypeA, // Tensor A data type
|
||||
class TypeB, // Tensor B data type
|
||||
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
|
||||
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
|
||||
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
|
||||
|
||||
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
|
||||
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
|
||||
|
||||
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
|
||||
};
|
||||
|
||||
// The device kernel
|
||||
template <class SharedStorage,
|
||||
class ATensor, class BTensor, class CTensor, class DTensor,
|
||||
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
|
||||
class TmaAtomA, class TmaAtomB,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
void
|
||||
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
BTensor mB, // (Gemm_N, Gemm_K)
|
||||
CTensor mC, // (Gemm_M, Gemm_N)
|
||||
DTensor mD, // (Gemm_M, Gemm_N)
|
||||
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
|
||||
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
|
||||
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
|
||||
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
|
||||
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Step 1: The Prologue.
|
||||
|
||||
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename TiledMMA::AtomThrID{}));
|
||||
|
||||
// Construct the MMA grid coordinate from the CTA grid coordinate
|
||||
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
|
||||
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
|
||||
blockIdx.y, // MMA-N coordinate
|
||||
_); // MMA-K coordinate
|
||||
|
||||
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
|
||||
// by this mma tile.
|
||||
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
|
||||
// * Tensor to partition
|
||||
// * Tiler to use for partitioning
|
||||
// * Coordinate to use for slicing the partitioned tensor
|
||||
// * Projection to ignore unwanted modes of the Tiler and Coordinate
|
||||
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
|
||||
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
|
||||
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
|
||||
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
|
||||
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
|
||||
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
|
||||
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
|
||||
|
||||
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
} __syncthreads();
|
||||
|
||||
// The SMEM tensors
|
||||
|
||||
// Allocate SMEM
|
||||
extern __shared__ char shared_memory[];
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Represent the SMEM buffers for A and B
|
||||
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
//
|
||||
// Mma partitioning for A and B
|
||||
//
|
||||
|
||||
auto mma_v = get<0>(mma_coord_vmnk);
|
||||
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
|
||||
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
|
||||
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// MMA Fragment Allocation
|
||||
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
|
||||
// For tcgen05.mma operations:
|
||||
// - Matrices A and B are sourced from SMEM
|
||||
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
|
||||
// - The first mode of each descriptor represents the SMEM for a single MMA operation
|
||||
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
// TMEM Allocation
|
||||
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
|
||||
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
|
||||
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// TMA Setup
|
||||
//
|
||||
// These are TMA partitionings, which have a dedicated custom partitioner.
|
||||
// In this example, the TMA multicasts the loads across multiple CTAs.
|
||||
// Loads of A are multicasted along the N dimension of the cluster_shape_VMNK and
|
||||
// Loads of B are multicasted along the M dimension of the cluster_shape_VMNK.
|
||||
// Any multicasting must conform to the TMA atoms (tma_atom_A and tma_atom_B) constructed with make_tma_atom on the host.
|
||||
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
|
||||
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
|
||||
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
|
||||
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
|
||||
|
||||
// Each CTA with the same m-coord will load a portion of A
|
||||
// Each CTA with the same n-coord will load a portion of B
|
||||
// Computation of the multicast masks must take into account the Peer CTA for TMA.2SM
|
||||
|
||||
// Construct the CTA-in-Cluster coordinate for multicasting
|
||||
auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster()));
|
||||
|
||||
// Project the cluster_layout for tma_A along the N-modes
|
||||
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
|
||||
get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster
|
||||
make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster
|
||||
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
|
||||
|
||||
// Project the cluster_layout for tma_B along the M-modes
|
||||
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
|
||||
get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster
|
||||
make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster
|
||||
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
|
||||
|
||||
// Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A
|
||||
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
// Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B
|
||||
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
// Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C
|
||||
uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) |
|
||||
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
|
||||
// Calculate the total bytes that TMA will transfer for each tile, used to track completion, accounting for TMA.2SM
|
||||
int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA))
|
||||
+ size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB));
|
||||
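// Illustrative aside: size<0>(cluster_layout_vmnk) is the number of peer CTAs per MMA (2 for the 2SM MMA); a
// single TMA.2SM load fills both peers' SMEM while only the leader's barrier tracks the transaction.
// Assuming the per-CTA staging sizes shown in the prints below (16 KiB of A, 32 KiB of B), that is
// 2 * (16384 + 32768) = 98304 bytes per k-tile.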
|
||||
if (thread0()) {
|
||||
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
|
||||
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
|
||||
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
|
||||
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
|
||||
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
|
||||
printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c);
|
||||
} __syncthreads();
|
||||
|
||||
// Barrier Initialization
|
||||
auto elect_one_thr = cute::elect_one_sync();
|
||||
auto elect_one_warp = (threadIdx.x / 32 == 0);
|
||||
auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{};
|
||||
|
||||
// Barriers in SMEM should be initialized by a single thread.
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
// The number of CTAs that participate in multicast operations with this CTA (for both A and B matrices)
|
||||
int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1;
|
||||
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants);
|
||||
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
|
||||
}
|
||||
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
cute::cluster_sync(); // Make sure all CTAs in Cluster observe barrier init and TMEM alloc.
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set the mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
|
||||
{
|
||||
// Step 2a: Load A and B tiles
|
||||
|
||||
// TMA Load Operations:
|
||||
// - Execute asynchronous TMA loads with single thread
|
||||
// - Both peer and leader CTAs initiate TMA loads
|
||||
// - Set expected transaction bytes. For 2SM TMA instructions, the transaction bytes counts both CTAs.
|
||||
// - Although TMAs are initiated by both peer and leader CTAs, the barrier is only set and waited by the leader CTA.
|
||||
// - Initiate asynchronous transfers with a multicast mask that includes all CTAs that participate in multicast.
|
||||
if (elect_one_warp && elect_one_thr) { // TMA loads are executed by one thread
|
||||
if (elect_one_cta) { // Only the leader CTA waits for TMA transactions
|
||||
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); // Set the expected transaction bytes for the TMA loads
|
||||
}
|
||||
copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
|
||||
copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
|
||||
}
|
||||
|
||||
// Step 2b: Execute the MMAs for this tile
|
||||
|
||||
if (elect_one_cta) {
|
||||
// Wait for TMA loads to complete on leader CTAs
|
||||
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
|
||||
tma_barrier_phase_bit ^= 1;
|
||||
|
||||
// tcgen05.mma instructions require single-thread execution:
|
||||
// - Only one warp performs the MMA-related loop operations
|
||||
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
|
||||
// - No explicit elect_one_sync region is needed from the user
|
||||
if (elect_one_warp) {
|
||||
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
// Ensure the MMAs are completed; only then can we reuse the A and B SMEM.
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask.
|
||||
}
|
||||
}
|
||||
// Wait for the MMAs to complete to avoid overwriting the A and B SMEM.
|
||||
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
|
||||
mma_barrier_phase_bit ^= 1;
|
||||
}
|
||||
|
||||
// Step 3: The Epilogue.
|
||||
|
||||
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
|
||||
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
|
||||
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load C tensor GMEM -> RMEM
|
||||
copy(tDgC, tDrC);
|
||||
|
||||
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
|
||||
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
using AccType = typename decltype(tCtAcc)::value_type;
|
||||
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
|
||||
// Load TMEM -> RMEM
|
||||
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
|
||||
|
||||
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
|
||||
axpby(alpha, tDrAcc, beta, tDrC);
|
||||
// Store RMEM -> GMEM
|
||||
copy(tDrC, tDgD);
|
||||
}
|
||||
|
||||
template <class TypeA, class LayoutA,
|
||||
class TypeB, class LayoutB,
|
||||
class TypeC, class LayoutC,
|
||||
class TypeD, class LayoutD,
|
||||
class Alpha, class Beta>
|
||||
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
TypeB const* device_ptr_B, LayoutB layout_B,
|
||||
TypeC const* device_ptr_C, LayoutC layout_C,
|
||||
TypeD * device_ptr_D, LayoutD layout_D,
|
||||
Alpha const alpha, Beta const beta)
|
||||
{
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
|
||||
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
|
||||
|
||||
// Represent the full tensors in global memory
|
||||
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
|
||||
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
|
||||
|
||||
// Get M, N, K dimensions of the GEMM we are running
|
||||
auto Gemm_M = shape<0>(layout_A);
|
||||
auto Gemm_N = shape<0>(layout_B);
|
||||
auto Gemm_K = shape<1>(layout_A);
|
||||
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Initialize the GEMM kernel parameters
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
|
||||
// larger TiledMma from the given mma instruction.
|
||||
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_2x1SM_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
|
||||
256, 256, // Mma M and N dimensions
|
||||
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
|
||||
|
||||
// We can also print and inspect the tiled_mma
|
||||
print(tiled_mma);
|
||||
// TiledMMA
|
||||
// ThrLayoutVMNK: (_2,_1,_1,_1):(_1,_0,_0,_0)
|
||||
// PermutationMNK: (_,_,_)
|
||||
// MMA_Atom
|
||||
// ThrID: _2:_1
|
||||
// Shape_MNK: (_256,_256,_16) // MmaM, MmaN, MmaK (MmaK is constant for each instr.)
|
||||
// LayoutA_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for A matrix
|
||||
// LayoutB_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
|
||||
// LayoutC_TV: (_2,(_128,_256)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for C matrix
|
||||
|
||||
// Define MMA tiler sizes (static)
|
||||
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
|
||||
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
|
||||
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K=16.
|
||||
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
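// (Added note for orientation.) With the 256x256 2SM MMA selected above, bM and bN are expected to be
// _256, and bK = _16 * _4 = _64, so mma_tiler evaluates to (_256,_256,_64).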
|
||||
|
||||
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
|
||||
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Determine the SMEM layouts:
|
||||
//
|
||||
|
||||
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
|
||||
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
|
||||
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
|
||||
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
|
||||
// times the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
|
||||
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
|
||||
|
||||
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
|
||||
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
|
||||
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
|
||||
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
|
||||
|
||||
// Print and inspect mma_shape_A, and mma_shape_B for this example.
|
||||
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
|
||||
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
|
||||
|
||||
// A and B tensors are swizzled in SMEM to improve MMA performance.
|
||||
// * However, swizzled layouts are tedious and error-prone to express by hand.
|
||||
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
|
||||
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
|
||||
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
|
||||
|
||||
// Print and inspect sA_layout and sB_layout for this example.
|
||||
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
|
||||
// Now we can find the SMEM allocation size
|
||||
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
|
||||
|
||||
//
|
||||
// TMA Descriptor Creation (Host Side)
|
||||
//
|
||||
|
||||
// The cluster shape and layout
|
||||
auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{});
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
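// (Added note for orientation.) With cluster_shape = (_4,_4,_1) and the 2SM atom's AtomThrID of size _2,
// tiled_divide is expected to produce a VMNK layout of shape ((_2),_2,_4,_1):
//   2 peer CTAs per MMA (V), 2 MMAs along cluster M, 4 along cluster N, and 1 along cluster K.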
|
||||
|
||||
// SM100 interface for creating TMA loads.
|
||||
Copy_Atom tma_atom_A = make_tma_atom_A_sm100(
|
||||
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
|
||||
mA, // Source GMEM tensor
|
||||
sA_layout, // Destination SMEM layout
|
||||
mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes.
|
||||
tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning.
|
||||
cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed.
|
||||
// The make_tma_atom_[A|B]_sm100 variant used determines the multicast mode.
|
||||
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
|
||||
|
||||
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
|
||||
// tma_atom_A: Copy_Atom
|
||||
// ThrID: _2:_1
|
||||
// ValLayoutSrc: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutDst: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutRef: (_2,_8192):(_8192,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
// SM100 interface for creating TMA loads.
|
||||
Copy_Atom tma_atom_B = make_tma_atom_B_sm100(
|
||||
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
|
||||
mB, // Source GMEM tensor
|
||||
sB_layout, // Destination SMEM layout
|
||||
mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes.
|
||||
tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning.
|
||||
cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed.
|
||||
// The make_tma_atom_[A|B]_sm100 variant used determines the multicast mode.
|
||||
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
|
||||
|
||||
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
|
||||
// tma_atom_B: Copy_Atom
|
||||
// ThrID: _2:_1
|
||||
// ValLayoutSrc: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutDst: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutRef: (_2,_8192):(_8192,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Launch GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
dim3 dimBlock(128);
|
||||
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
|
||||
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
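// (Added worked example for the default 512x1024x256 problem, for orientation only.)
//   ceil_div(512, 256)  = 2 MMA tiles in M -> rounded up to the cluster extent 4 -> dimGrid.x = 4
//   ceil_div(1024, 256) = 4 MMA tiles in N -> already a multiple of 4            -> dimGrid.y = 4
// dimGrid.z defaults to 1, so the launch below should report a 4x4x1 grid of CTAs.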
|
||||
int smemBytes = sizeof(SMEMStorage);
|
||||
|
||||
auto* kernel_ptr = &gemm_device<SMEMStorage,
|
||||
decltype(mA_tma), decltype(mB_tma), decltype(mC), decltype(mD),
|
||||
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
|
||||
decltype(tma_atom_A), decltype(tma_atom_B), // Includes the TMA descriptor.
|
||||
Alpha, Beta>;
|
||||
|
||||
// Set kernel attributes (set SMEM)
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
|
||||
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
|
||||
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
mA_tma, mB_tma, mC, mD,
|
||||
mma_tiler, tiled_mma, cluster_shape,
|
||||
tma_atom_A, tma_atom_B,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
cudaGetDevice(&current_device_id);
|
||||
cudaGetDeviceProperties(&props, current_device_id);
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((props.major != 10) || (props.major == 10 && props.minor > 1)) {
|
||||
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
|
||||
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int Gemm_M = 512;
|
||||
if (argc >= 2)
|
||||
sscanf(argv[1], "%d", &Gemm_M);
|
||||
|
||||
int Gemm_N = 1024;
|
||||
if (argc >= 3)
|
||||
sscanf(argv[2], "%d", &Gemm_N);
|
||||
|
||||
int Gemm_K = 256;
|
||||
if (argc >= 4)
|
||||
sscanf(argv[3], "%d", &Gemm_K);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Create A, B, C, and D tensors
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
// Define the data types. A and B types are the same for the MMA instruction.
|
||||
using TypeA = cutlass::half_t; // MMA A Data Type
|
||||
auto type_str_a = "half_t";
|
||||
using TypeB = cutlass::half_t; // MMA B Data Type
|
||||
auto type_str_b = "half_t";
|
||||
using TypeC = float; // MMA C Data Type
|
||||
[[maybe_unused]] auto type_str_c = "float";
|
||||
using TypeD = float; // MMA D Data Type
|
||||
auto type_str_d = "float";
|
||||
using TypeAccumulator = float; // Both TypeC and TypeD are float, so we use a float accumulator type.
|
||||
|
||||
// A tensor MxK K-major (Layout T = Row-Major)
|
||||
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
|
||||
// B tensor NxK K-major (Layout N = Column-Major)
|
||||
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
|
||||
// C tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
// D tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
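// (Added illustration for the default 512x1024x256 problem size; the host_tensor_* printouts below are
//  the authoritative values.) The layouts above are expected to evaluate to:
//   layout_A: (512,256):(256,_1)     layout_B: (1024,256):(256,_1)
//   layout_C: (512,1024):(1024,_1)   layout_D: (512,1024):(1024,_1)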
|
||||
|
||||
// Host allocations and host CuTe tensors for A, B, and C tensors.
|
||||
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
|
||||
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
|
||||
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
|
||||
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
|
||||
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
|
||||
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
|
||||
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
|
||||
|
||||
// Note that we don't need a host_tensor for D yet.
|
||||
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
|
||||
|
||||
// Initialize A, B, and C tensors with random values.
|
||||
initialize_tensor(host_tensor_A);
|
||||
initialize_tensor(host_tensor_B);
|
||||
initialize_tensor(host_tensor_C);
|
||||
|
||||
// Copy A, B, and C tensors from host memory to device memory
|
||||
thrust::device_vector<TypeA> device_A = host_A;
|
||||
thrust::device_vector<TypeB> device_B = host_B;
|
||||
thrust::device_vector<TypeC> device_C = host_C;
|
||||
|
||||
using Alpha = float;
|
||||
using Beta = float;
|
||||
Alpha alpha = 1.0f;
|
||||
Beta beta = 0.0f;
|
||||
// Setup input and output tensors, and the kernel parameters; and execute the kernel on device
|
||||
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
|
||||
device_B.data().get(), layout_B,
|
||||
device_C.data().get(), layout_C,
|
||||
device_D.data().get(), layout_D,
|
||||
alpha, beta);
|
||||
// Host allocation for D tensor and transfer D tensor from device to host
|
||||
thrust::host_vector<TypeD> host_D = device_D;
|
||||
// Create a non-owning CuTe tensor for D tensor
|
||||
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Execute reference GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
|
||||
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
|
||||
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Compare results
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
|
||||
type_str_b, host_tensor_B,
|
||||
type_str_d, host_tensor_D, host_reference_tensor_D);
|
||||
bool success = relative_error <= 0.0;
|
||||
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
825
examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu
Normal file
@ -0,0 +1,825 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CuTe Tutorial for SM100 Programming
|
||||
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
|
||||
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
|
||||
//
|
||||
// The tutorial series is split into five stages:
|
||||
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
|
||||
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
|
||||
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
|
||||
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
|
||||
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
|
||||
// Use Thrust to handle host/device allocations
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
// Cutlass includes
|
||||
#include <cutlass/half.h> // F16 data type
|
||||
#include <cutlass/util/print_error.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/cluster_launch.hpp>
|
||||
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
|
||||
#include <cute/numeric/integral_constant.hpp> // Compile time in constants such as _1, _256 etc.
|
||||
#include <cute/algorithm/cooperative_copy.hpp>
|
||||
|
||||
// Tutorial helpers
|
||||
#include "example_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Tutorial 05: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
|
||||
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
|
||||
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
|
||||
// - Matrices C and D are MxN, N-major (BLAS row-major)
|
||||
//
|
||||
// Key extensions to tutorial 04_mma_tma_2sm_sm100.cu:
|
||||
// 1. Demonstrate using TMA instructions in the epilogue
|
||||
//
|
||||
// This GEMM kernel will perform the following steps:
|
||||
// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA.2SM load operations.
|
||||
// 2. Perform matrix multiply-accumulate (MMA) operations using 2SM tcgen05.mma instruction.
|
||||
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||||
// 4. Read C matrix from global memory (GMEM) to shared memory (SMEM) with TMA.
|
||||
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
|
||||
// 6. Store D matrix from shared memory (SMEM) to global memory (GMEM) with TMA.
|
||||
//
|
||||
// SM100 2SM tcgen05.mma instructions operate as follows:
|
||||
// - Mma is launched by only one SM
|
||||
// With 2SM MMA instructions, only 1 of the 2 CTAs collaborating on MMA executes the instruction.
|
||||
// We call the collaborating CTAs peer CTAs, and the CTA executing the MMA instruction is called the leader CTA.
|
||||
// - Read matrix A from SMEM or TMEM
|
||||
// - Read matrix B from SMEM
|
||||
// - Write accumulator to TMEM
|
||||
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||||
//
|
||||
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
|
||||
// and the MMA's M and N dimensions.
|
||||
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
|
||||
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
|
||||
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
|
||||
//
|
||||
// The MMA details:
|
||||
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 256x256x16 MMA
|
||||
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
|
||||
// This example uses F16xF16 = F32 MMA where:
|
||||
// TypeA = cutlass::half_t; // MMA A Data Type
|
||||
// TypeB = cutlass::half_t; // MMA B Data Type
|
||||
// TypeC = float; // MMA C Data Type
|
||||
// TypeD = float; // MMA D Data Type
|
||||
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// The shared memory buffers for A, B, C, and D matrices.
|
||||
template <class TypeA, // Tensor A data type
|
||||
class TypeB, // Tensor B data type
|
||||
class TypeC, // Tensor C data type
|
||||
class TypeD, // Tensor D data type
|
||||
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
|
||||
class BSmemLayout, // (MmaB, NumMma_N, NumMma_K, ...)
|
||||
class CSmemLayout, // EpiTile_M, EpiTile_N
|
||||
class DSmemLayout> // EpiTile_M, EpiTile_N
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(128) union {
|
||||
alignas(128) struct {
|
||||
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
|
||||
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
|
||||
} mainloop;
|
||||
alignas(128) cute::ArrayEngine<TypeC, cute::cosize_v<CSmemLayout>> C;
|
||||
alignas(128) cute::ArrayEngine<TypeD, cute::cosize_v<DSmemLayout>> D;
|
||||
} tensors;
|
||||
|
||||
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
|
||||
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
|
||||
|
||||
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(tensors.mainloop.A.begin()), ASmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(tensors.mainloop.B.begin()), BSmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sC() { return make_tensor(make_smem_ptr(tensors.C.begin()), CSmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sD() { return make_tensor(make_smem_ptr(tensors.D.begin()), DSmemLayout{}); }
|
||||
};
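// (Added note on the union above; the exact figure always comes from sizeof(SharedStorage).)
// Because the mainloop A/B buffers and the epilogue C/D buffers are never live at the same time,
// the tensors member costs approximately
//   max( cosize(ASmemLayout)*sizeof(TypeA) + cosize(BSmemLayout)*sizeof(TypeB),
//        cosize(CSmemLayout)*sizeof(TypeC),
//        cosize(DSmemLayout)*sizeof(TypeD) )
// bytes of SMEM, plus the two 8-byte barriers and 128-byte alignment padding.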
|
||||
|
||||
// The device kernel
|
||||
template <class SharedStorage,
|
||||
class ATensor, class BTensor, class CTensor, class DTensor,
|
||||
class MmaTiler_MNK, class EpiTiler_MN, class TiledMMA, class ClusterShape_MNK,
|
||||
class TmaAtomA, class TmaAtomB, class TmaAtomC, class TmaAtomD,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
void
|
||||
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
BTensor mB, // (Gemm_N, Gemm_K)
|
||||
CTensor mC, // (Gemm_M, Gemm_N)
|
||||
DTensor mD, // (Gemm_M, Gemm_N)
|
||||
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
|
||||
EpiTiler_MN epi_tiler_mn, // <EpiTile_M, EpiTile_N>
|
||||
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
|
||||
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
|
||||
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
|
||||
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
|
||||
CUTE_GRID_CONSTANT TmaAtomC const tma_atom_C,
|
||||
CUTE_GRID_CONSTANT TmaAtomD const tma_atom_D,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Step 1: The Prologue.
|
||||
|
||||
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename TiledMMA::AtomThrID{}));
|
||||
|
||||
// Construct the MMA grid coordinate from the CTA grid coordinate
|
||||
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
|
||||
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
|
||||
blockIdx.y, // MMA-N coordinate
|
||||
_); // MMA-K coordinate
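// (Added note, assuming the 4x4x1 cluster launched by the host.) size<0>(cluster_layout_vmnk) == 2 for
// the 2SM atom, so blockIdx.x % 2 selects the peer CTA within a 2SM MMA and blockIdx.x / 2 selects the
// MMA tile along M, while blockIdx.y indexes the MMA tile along N. The K coordinate is left as '_' so
// that all K tiles of this (M,N) tile remain available for the mainloop to iterate over.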
|
||||
|
||||
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
|
||||
// by this mma tile.
|
||||
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
|
||||
// * Tensor to partition
|
||||
// * Tiler to use for partitioning
|
||||
// * Coordinate to use for slicing the partitioned tensor
|
||||
// * Projection to ignore unwanted modes of the Tiler and Coordinate
|
||||
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
|
||||
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
|
||||
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
|
||||
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
|
||||
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
|
||||
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
|
||||
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
|
||||
|
||||
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
|
||||
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
|
||||
} __syncthreads();
|
||||
|
||||
// The SMEM tensors
|
||||
|
||||
// Allocate SMEM
|
||||
extern __shared__ char shared_memory[];
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// Represent the SMEM buffers for A and B
|
||||
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
//
|
||||
// Mma partitioning for A and B
|
||||
//
|
||||
|
||||
auto mma_v = get<0>(mma_coord_vmnk);
|
||||
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
|
||||
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
|
||||
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
|
||||
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// MMA Fragment Allocation
|
||||
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
|
||||
// For tcgen05.mma operations:
|
||||
// - Matrices A and B are sourced from SMEM
|
||||
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
|
||||
// - The first mode of each descriptor represents the SMEM for a single MMA operation
|
||||
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
|
||||
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
|
||||
|
||||
// TMEM Allocation
|
||||
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
|
||||
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
|
||||
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
|
||||
|
||||
if (thread0()) {
|
||||
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
|
||||
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
|
||||
} __syncthreads();
|
||||
|
||||
// TMA Setup
|
||||
//
|
||||
// These are TMA partitionings, which have a dedicated custom partitioner.
|
||||
// In this example, the TMA multicasts the loads across multiple CTAs.
|
||||
// Loads of A are multicasted along the N dimension of the cluster_shape_VMNK and
|
||||
// Loads of B are multicasted along the M dimension of the cluster_shape_VMNK.
|
||||
// Any multicasting must conform to the tma_x atom constructed with make_tma_atom on the host.
|
||||
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
|
||||
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
|
||||
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
|
||||
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
|
||||
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
|
||||
|
||||
// Each CTA with the same m-coord will load a portion of A
|
||||
// Each CTA with the same n-coord will load a portion of B
|
||||
// Computation of the multicast masks must take into account the Peer CTA for TMA.2SM
|
||||
|
||||
// Construct the CTA-in-Cluster coordinate for multicasting
|
||||
auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster()));
|
||||
|
||||
// Project the cluster_layout for tma_A along the N-modes
|
||||
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
|
||||
get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster
|
||||
make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster
|
||||
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
|
||||
|
||||
// Project the cluster_layout for tma_B along the M-modes
|
||||
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
|
||||
get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster
|
||||
make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster
|
||||
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
|
||||
|
||||
// Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A
|
||||
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
// Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B
|
||||
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
// Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C
|
||||
uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) |
|
||||
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
|
||||
|
||||
// Calculate the total bytes that TMA will transfer for each tile so completion can be tracked, accounting for TMA.2SM
|
||||
int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA))
|
||||
+ size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB));
|
||||
|
||||
if (thread0()) {
|
||||
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
|
||||
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
|
||||
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
|
||||
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
|
||||
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
|
||||
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
|
||||
printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c);
|
||||
} __syncthreads();
|
||||
|
||||
// Barrier Initialization
|
||||
auto elect_one_thr = cute::elect_one_sync();
|
||||
auto elect_one_warp = (threadIdx.x / 32 == 0);
|
||||
auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{};
|
||||
|
||||
// Barriers in SMEM should be initialized by a single thread.
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
// The number of CTAs that participate in the multicast operation with this CTA (for both A and B matrices)
|
||||
int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1;
|
||||
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants);
|
||||
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
|
||||
}
|
||||
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
|
||||
cute::cluster_sync(); // Make sure all CTAs in Cluster observe barrier init and TMEM alloc.
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set the mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
|
||||
{
|
||||
// Step 2a: Load A and B tiles
|
||||
|
||||
// TMA Load Operations:
|
||||
// - Execute asynchronous TMA loads with single thread
|
||||
// - Both peer and leader CTAs initiate TMA loads
|
||||
// - Set expected transaction bytes. For 2SM TMA instructions, the transaction bytes counts both CTAs.
|
||||
// - Although TMAs are initiated by both peer and leader CTAs, the barrier is only set and waited by the leader CTA.
|
||||
// - Initiate asynchronous transfers with a multicast mask that includes all CTAs that participate in multicast.
|
||||
if (elect_one_warp && elect_one_thr) { // TMA loads are executed by one thread
|
||||
if (elect_one_cta) { // Only the leader CTA waits for TMA transactions
|
||||
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); // Set the expected transaction bytes for the TMA loads
|
||||
}
|
||||
copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
|
||||
copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
|
||||
}
|
||||
|
||||
// Step 2b: Execute the MMAs for this tile
|
||||
|
||||
if (elect_one_cta) {
|
||||
// Wait for TMA loads to complete on leader CTAs
|
||||
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
|
||||
tma_barrier_phase_bit ^= 1;
|
||||
|
||||
// tcgen05.mma instructions require single-thread execution:
|
||||
// - Only one warp performs the MMA-related loop operations
|
||||
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
|
||||
// - No explicit elect_one_sync region is needed from the user
|
||||
if (elect_one_warp) {
|
||||
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
// Ensure the MMAs are completed; only then can we reuse the A and B SMEM.
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask.
|
||||
}
|
||||
}
|
||||
// Wait for the MMAs to complete to avoid overwriting the A and B SMEM.
|
||||
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
|
||||
mma_barrier_phase_bit ^= 1;
|
||||
}
|
||||
|
||||
// Step 3: The Epilogue.
|
||||
|
||||
// Apply rank-2 epilogue tiler to rank-2 MMA_V mode
|
||||
auto epi_tiler_v = make_tile(epi_tiler_mn); // (EpiTile)
|
||||
Tensor tAcc_epi = zipped_divide(tCtAcc, epi_tiler_v); // (EpiTile,NumTiles)
|
||||
Tensor gC_epi = zipped_divide(tCgC, epi_tiler_v); // (EpiTile,NumTiles)
|
||||
Tensor gD_epi = zipped_divide(tCgD, epi_tiler_v); // (EpiTile,NumTiles)
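// (Added worked example for this configuration, for orientation only.) The per-CTA accumulator tile is
// 128x256 and the epilogue tile chosen on the host is expected to be 128x64, so NumTiles == 4 and the
// epilogue loop below performs 4 C-load / D-store TMA round trips per MMA tile.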
|
||||
|
||||
// Construct corresponding SMEM tensors
|
||||
Tensor sC_epi = shared_storage.tensor_sC(); // (EpiTile)
|
||||
Tensor sD_epi = shared_storage.tensor_sD(); // (EpiTile)
|
||||
|
||||
// Partition for TMA
|
||||
auto [tGS_gC, tGS_sC] = tma_partition(tma_atom_C, sC_epi, gC_epi); // (GMEM -> SMEM)
|
||||
auto [tSG_gD, tSG_sD] = tma_partition(tma_atom_D, sD_epi, gD_epi); // (SMEM -> GMEM)
|
||||
|
||||
// Reset transaction bytes for C load
|
||||
tma_transaction_bytes = sizeof(make_tensor_like(tGS_sC));
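// (Added worked example.) Each epilogue C tile is expected to hold 128 * 64 = 8192 float elements,
// so tma_transaction_bytes becomes 8192 * 4 = 32768 bytes for this single-CTA (non-multicast) C load.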
|
||||
|
||||
// Partition for TMEM accumulators load (TMEM -> RMEM)
|
||||
TiledCopy t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tAcc_epi(_,_0{}));
|
||||
ThrCopy thr_t2r = t2r_copy.get_slice(threadIdx.x);
|
||||
Tensor tTR_tAcc = thr_t2r.partition_S(tAcc_epi); // (TmemCpy,NumTmemCpy,NumTiles)
|
||||
Tensor tTR_sC = thr_t2r.partition_D(sC_epi); // (TmemCpy,NumTmemCpy)
|
||||
Tensor tTR_sD = thr_t2r.partition_D(sD_epi); // (TmemCpy,NumTmemCpy)
|
||||
// Allocate register tensors
|
||||
Tensor tTR_rC = make_tensor_like(tTR_sC); // (TmemCpy,NumTmemCpy)
|
||||
Tensor tTR_rD = make_fragment_like(tTR_sD); // (TmemCpy,NumTmemCpy)
|
||||
|
||||
// Loop over the epilogue tiles
|
||||
CUTE_UNROLL
|
||||
for (int epi_tile_idx = 0; epi_tile_idx < size<2>(tTR_tAcc); ++epi_tile_idx) {
|
||||
// TMA Load C: GMEM -> SMEM
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes);
|
||||
copy(tma_atom_C.with(shared_storage.tma_barrier, 0 /*no multicast*/), tGS_gC(_,epi_tile_idx), tGS_sC);
|
||||
}
|
||||
// All threads wait for C TMA load to complete
|
||||
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
|
||||
tma_barrier_phase_bit ^= 1;
|
||||
|
||||
// Load C: SMEM -> RMEM
|
||||
copy_aligned(tTR_sC, tTR_rC);
|
||||
|
||||
// Load Acc: TMEM -> RMEM
|
||||
copy(t2r_copy, tTR_tAcc(_,_,epi_tile_idx), tTR_rD);
|
||||
|
||||
// Compute D = beta * C + alpha * (A*B)
|
||||
axpby(beta, tTR_rC, alpha, tTR_rD);
|
||||
|
||||
// Store D: RMEM -> SMEM
|
||||
__syncthreads(); // Ensure C loads are finished before reusing smem (unnecessary if smem layouts match)
|
||||
copy_aligned(tTR_rD, tTR_sD);
|
||||
|
||||
// TMA Store D: SMEM -> GMEM
|
||||
tma_store_fence(); // Ensure D smem stores are visible to TMA
|
||||
__syncthreads(); // Ensure all threads have issued fence
|
||||
if (elect_one_warp && elect_one_thr) {
|
||||
copy(tma_atom_D, tSG_sD, tSG_gD(_,epi_tile_idx));
|
||||
tma_store_arrive(); // issuing thread commits D TMA store
|
||||
tma_store_wait<0>(); // issuing thread waits for D TMA store to complete
|
||||
}
|
||||
__syncthreads(); // All threads sync with issuing thread
|
||||
}
|
||||
}
|
||||
|
||||
template <class TypeA, class LayoutA,
|
||||
class TypeB, class LayoutB,
|
||||
class TypeC, class LayoutC,
|
||||
class TypeD, class LayoutD,
|
||||
class Alpha, class Beta>
|
||||
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
TypeB const* device_ptr_B, LayoutB layout_B,
|
||||
TypeC const* device_ptr_C, LayoutC layout_C,
|
||||
TypeD * device_ptr_D, LayoutD layout_D,
|
||||
Alpha const alpha, Beta const beta)
|
||||
{
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
|
||||
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
|
||||
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
|
||||
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
|
||||
|
||||
// Represent the full tensors in global memory
|
||||
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
|
||||
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
|
||||
|
||||
// Get M, N, K dimensions of the GEMM we are running
|
||||
auto Gemm_M = shape<0>(layout_A);
|
||||
auto Gemm_N = shape<0>(layout_B);
|
||||
auto Gemm_K = shape<1>(layout_A);
|
||||
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Initialize the GEMM kernel parameters
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
|
||||
// larger TiledMma from the given mma instruction.
|
||||
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_2x1SM_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
|
||||
256, 256, // Mma M and N dimensions
|
||||
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
|
||||
|
||||
// We can also print and inspect the tiled_mma
|
||||
print(tiled_mma);
|
||||
// TiledMMA
|
||||
// ThrLayoutVMNK: (_2,_1,_1,_1):(_1,_0,_0,_0)
|
||||
// PermutationMNK: (_,_,_)
|
||||
// MMA_Atom
|
||||
// ThrID: _2:_1
|
||||
// Shape_MNK: (_256,_256,_16) // MmaM, MmaN, MmaK (MmaK is constant for each instr.)
|
||||
// LayoutA_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for A matrix
|
||||
// LayoutB_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
|
||||
// LayoutC_TV: (_2,(_128,_256)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for C matrix
|
||||
|
||||
// Define MMA tiler sizes (static)
|
||||
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
|
||||
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
|
||||
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K=16.
|
||||
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
|
||||
|
||||
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
|
||||
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// Determine the SMEM layouts:
|
||||
//
|
||||
|
||||
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
|
||||
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
|
||||
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
|
||||
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
|
||||
// times the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
|
||||
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
|
||||
|
||||
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
|
||||
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
|
||||
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
|
||||
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
|
||||
|
||||
// Print and inspect mma_shape_A, and mma_shape_B for this example.
|
||||
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
|
||||
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
|
||||
|
||||
// A and B tensors are swizzled in SMEM to improve MMA performance.
|
||||
// * However, swizzled layouts are tedious and error-prone to express by hand.
|
||||
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
|
||||
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
|
||||
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
|
||||
|
||||
// Print and inspect sA_layout and sB_layout for this example.
|
||||
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
|
||||
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
|
||||
|
||||
//
|
||||
// Epilogue parameters
|
||||
//
|
||||
|
||||
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_N) to post-partitioned ((MmaM,MmaN), NumMma_M, NumMma_N)
|
||||
auto mma_shape_C = partition_shape_C(tiled_mma, make_shape(size<0>(mma_tiler), size<1>(mma_tiler)));
|
||||
|
||||
// For TMA epilogue performance it may be beneficial to iterate over the output in smaller tiles than the MMA tile
|
||||
auto epi_tiler = make_tile(size<0,0>(mma_shape_C), size<0,1>(mma_shape_C) / Int<4>{}); // 4 TMA copies per CTA per MMA tile
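// (Added note for orientation.) For this 2SM MMA, mma_shape_C is expected to be ((_128,_256),_1,_1),
// so epi_tiler evaluates to (_128,_64): four 128x64 sub-tiles cover each 128x256 per-CTA accumulator.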
|
||||
|
||||
// SMEM layouts for C and D should match the epilogue tile
|
||||
auto sC_layout_mn = tile_to_shape(UMMA::Layout_K_SW128_Atom<TypeC>{}, // MMA K-major is equivalent to epilogue N-major
|
||||
make_shape(size<0>(epi_tiler), size<1>(epi_tiler)));
|
||||
auto sC_layout = group<0,2>(sC_layout_mn); // Group modes for tma_partition
|
||||
|
||||
auto sD_layout_mn = tile_to_shape(UMMA::Layout_K_SW128_Atom<TypeD>{}, // MMA K-major is equivalent to epilogue N-major
|
||||
make_shape(size<0>(epi_tiler), size<1>(epi_tiler)));
|
||||
auto sD_layout = group<0,2>(sD_layout_mn); // Group modes for tma_partition
|
||||
|
||||
print("sC_layout:\t"); print(sC_layout); print("\n"); // sC_layout: Sw<3,4,3> o smem_ptr[32b](unset) o ((_8,_16),(_32,_2)):((_32,_256),(_1,_4096))
|
||||
print("sD_layout:\t"); print(sD_layout); print("\n"); // sD_layout: Sw<3,4,3> o smem_ptr[32b](unset) o ((_8,_16),(_32,_2)):((_32,_256),(_1,_4096))
|
||||
|
||||
// Now we can find the SMEM allocation size
|
||||
using SMEMStorage = SharedStorage<TypeA, TypeB, TypeC, TypeD,
|
||||
decltype(sA_layout), decltype(sB_layout),
|
||||
decltype(sC_layout), decltype(sD_layout)>;
|
||||
|
||||
//
|
||||
// TMA Descriptor Creation (Host Side)
|
||||
//
|
||||
|
||||
// The cluster shape and layout
|
||||
auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{});
|
||||
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
|
||||
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
|
||||
|
||||
// SM100 interface for creating TMA loads.
|
||||
Copy_Atom tma_atom_A = make_tma_atom_A_sm100(
|
||||
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
|
||||
mA, // Source GMEM tensor
|
||||
sA_layout, // Destination SMEM layout
|
||||
mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes.
|
||||
tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning.
|
||||
cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed.
|
||||
// The make_tma_atom_[A|B]_sm100 variant used determines the multicast mode.
|
||||
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
|
||||
|
||||
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
|
||||
// tma_atom_A: Copy_Atom
|
||||
// ThrID: _2:_1
|
||||
// ValLayoutSrc: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutDst: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutRef: (_2,_8192):(_8192,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
// SM100 interface for creating TMA loads.
|
||||
Copy_Atom tma_atom_B = make_tma_atom_B_sm100(
|
||||
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
|
||||
mB, // Source GMEM tensor
|
||||
sB_layout, // Destination SMEM layout
|
||||
    mma_tiler,                        // MmaTiler_MNK. Unlike the Sm90 interface, where the tiler only included the N and K modes.
|
||||
tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning.
|
||||
    cluster_layout_vmnk);             // ClusterLayout_VMNK. Unlike the Sm90 interface, where only the multicasting mode is passed.
|
||||
  // The choice between make_tma_atom_A_sm100 and make_tma_atom_B_sm100 determines the multicast mode.
|
||||
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
|
||||
|
||||
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
|
||||
// tma_atom_B: Copy_Atom
|
||||
// ThrID: _2:_1
|
||||
// ValLayoutSrc: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutDst: (_2,_8192):(_8192,_1)
|
||||
// ValLayoutRef: (_2,_8192):(_8192,_1)
|
||||
// ValueType: 16b
|
||||
|
||||
Copy_Atom tma_atom_C = make_tma_atom(
|
||||
SM90_TMA_LOAD{}, // TMA load operation
|
||||
mC, // Source GMEM tensor
|
||||
sC_layout, // Destination SMEM layout
|
||||
epi_tiler); // MN Tiler for epilogue
|
||||
Tensor mC_tma = tma_atom_C.get_tma_tensor(shape(mC)); // (Gemm_M, Gemm_N)
|
||||
|
||||
print("tma_atom_C:\t"); print(tma_atom_C); print("\n");
|
||||
// tma_atom_C: Copy_Atom
|
||||
// ThrID: _1:_0
|
||||
// ValLayoutSrc: (_1,_4096):(_0,_1)
|
||||
// ValLayoutDst: (_1,_4096):(_0,_1)
|
||||
// ValLayoutRef: (_1,_4096):(_0,_1)
|
||||
// ValueType: 32b
|
||||
|
||||
Copy_Atom tma_atom_D = make_tma_atom(
|
||||
SM90_TMA_STORE{}, // TMA store operation
|
||||
mD, // Destination GMEM tensor
|
||||
sD_layout, // Source SMEM layout
|
||||
epi_tiler); // MN Tiler for epilogue
|
||||
Tensor mD_tma = tma_atom_D.get_tma_tensor(shape(mD)); // (Gemm_M, Gemm_N)
|
||||
|
||||
print("tma_atom_D:\t"); print(tma_atom_D); print("\n");
|
||||
// tma_atom_D: Copy_Atom
|
||||
// ThrID: _1:_0
|
||||
// ValLayoutSrc: (_1,_4096):(_0,_1)
|
||||
// ValLayoutDst: (_1,_4096):(_0,_1)
|
||||
// ValLayoutRef: (_1,_4096):(_0,_1)
|
||||
// ValueType: 32b
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Launch GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
dim3 dimBlock(128);
|
||||
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
|
||||
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
|
||||
int smemBytes = sizeof(SMEMStorage);
|
||||
|
||||
auto* kernel_ptr = &gemm_device<SMEMStorage,
|
||||
decltype(mA_tma), decltype(mB_tma), decltype(mC_tma), decltype(mD_tma),
|
||||
decltype(mma_tiler), decltype(epi_tiler), decltype(tiled_mma), decltype(cluster_shape),
|
||||
decltype(tma_atom_A), decltype(tma_atom_B), decltype(tma_atom_C), decltype(tma_atom_D), // Includes the TMA descriptor.
|
||||
Alpha, Beta>;
|
||||
|
||||
// Set kernel attributes (set SMEM)
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
|
||||
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
|
||||
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
mA_tma, mB_tma, mC_tma, mD_tma,
|
||||
mma_tiler, epi_tiler, tiled_mma, cluster_shape,
|
||||
tma_atom_A, tma_atom_B, tma_atom_C, tma_atom_D,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
  cudaGetDevice(&current_device_id);
|
||||
  cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
  if (props.major != 10 || props.minor > 1) {
|
||||
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
|
||||
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
int Gemm_M = 512;
|
||||
if (argc >= 2)
|
||||
sscanf(argv[1], "%d", &Gemm_M);
|
||||
|
||||
int Gemm_N = 1024;
|
||||
if (argc >= 3)
|
||||
sscanf(argv[2], "%d", &Gemm_N);
|
||||
|
||||
int Gemm_K = 256;
|
||||
if (argc >= 4)
|
||||
sscanf(argv[3], "%d", &Gemm_K);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Create A, B, C, and D tensors
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
  // Define the data types. The A and B types are the same for the MMA instruction.
|
||||
using TypeA = cutlass::half_t; // MMA A Data Type
|
||||
auto type_str_a = "half_t";
|
||||
using TypeB = cutlass::half_t; // MMA B Data Type
|
||||
auto type_str_b = "half_t";
|
||||
using TypeC = float; // MMA C Data Type
|
||||
[[maybe_unused]] auto type_str_c = "float";
|
||||
using TypeD = float; // MMA D Data Type
|
||||
auto type_str_d = "float";
|
||||
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
|
||||
|
||||
// A tensor MxK K-major (Layout T = Row-Major)
|
||||
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
|
||||
// B tensor NxK K-major (Layout N = Column-Major)
|
||||
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
|
||||
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
|
||||
// C tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
// D tensor MxN N-major (Layout T = Row-Major)
|
||||
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
|
||||
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
|
||||
|
||||
// Host allocations and host CuTe tensors for A, B, and C tensors.
|
||||
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
|
||||
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
|
||||
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
|
||||
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
|
||||
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
|
||||
|
||||
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
|
||||
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
|
||||
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
|
||||
|
||||
// Note that we don't need a host_tensor for D yet.
|
||||
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
|
||||
|
||||
// Initialize A, B, and C tensors with random values.
|
||||
initialize_tensor(host_tensor_A);
|
||||
initialize_tensor(host_tensor_B);
|
||||
initialize_tensor(host_tensor_C);
|
||||
|
||||
// Copy A, B, and C tensors from host memory to device memory
|
||||
thrust::device_vector<TypeA> device_A = host_A;
|
||||
thrust::device_vector<TypeB> device_B = host_B;
|
||||
thrust::device_vector<TypeC> device_C = host_C;
|
||||
|
||||
using Alpha = float;
|
||||
using Beta = float;
|
||||
Alpha alpha = 1.0f;
|
||||
Beta beta = 0.0f;
|
||||
  // Set up the input and output tensors and kernel parameters, then execute the kernel on the device.
|
||||
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
|
||||
device_B.data().get(), layout_B,
|
||||
device_C.data().get(), layout_C,
|
||||
device_D.data().get(), layout_D,
|
||||
alpha, beta);
|
||||
// Host allocation for D tensor and transfer D tensor from device to host
|
||||
thrust::host_vector<TypeD> host_D = device_D;
|
||||
// Create a non-owning CuTe tensor for D tensor
|
||||
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Execute reference GEMM kernel
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
|
||||
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
|
||||
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Compare results
|
||||
//
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
|
||||
type_str_b, host_tensor_B,
|
||||
type_str_d, host_tensor_D, host_reference_tensor_D);
|
||||
bool success = relative_error <= 0.0;
|
||||
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
54
examples/cute/tutorial/blackwell/CMakeLists.txt
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_01_mma_sm100
|
||||
01_mma_sm100.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_02_mma_tma_sm100
|
||||
02_mma_tma_sm100.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_03_mma_tma_multicast_sm100
|
||||
03_mma_tma_multicast_sm100.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_04_mma_tma_2sm_sm100
|
||||
04_mma_tma_2sm_sm100.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_05_mma_tma_epi_sm100
|
||||
05_mma_tma_epi_sm100.cu
|
||||
)
|
||||
endif()
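# Note: these tutorial targets are generated only when SM100a is among the requested
# architectures. A typical configure/build sketch (directory names are assumptions):
#   cmake .. -DCUTLASS_NVCC_ARCHS=100a
#   cmake --build . --target cute_tutorial_01_mma_sm100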
|
||||
105
examples/cute/tutorial/blackwell/example_utils.hpp
Normal file
@ -0,0 +1,105 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
|
||||
template <class AccType,
|
||||
class TensorA, class TensorB,
|
||||
class TensorC, class TensorD,
|
||||
class Alpha, class Beta>
|
||||
void
|
||||
reference_gemm(TensorA const& tensor_A, TensorB const& tensor_B,
|
||||
TensorC const& tensor_C, TensorD & tensor_D,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
using namespace cute;
|
||||
for (int m = 0; m < size<0>(tensor_D); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor_D); ++n) {
|
||||
AccType c = AccType(0.f);
|
||||
for (int k = 0; k < size<1>(tensor_A); ++k) {
|
||||
c += tensor_A(m,k) * tensor_B(n,k);
|
||||
}
|
||||
tensor_D(m,n) = alpha * c + beta * tensor_C(m,n);
|
||||
}
|
||||
}
|
||||
}
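// In formula form, the loop above computes, for every (m, n):
//   D(m,n) = alpha * sum_k A(m,k) * B(n,k) + beta * C(m,n)
// i.e. both A and B are indexed K-major, matching the (M,K) / (N,K) tensors used by the examples.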
|
||||
|
||||
template <class TensorA, class TensorB,
|
||||
class TensorC, class TensorD,
|
||||
class RefTensorD>
|
||||
bool
|
||||
compare_results(TensorA const& tensor_A, TensorB const& tensor_B,
|
||||
TensorC const& tensor_C, TensorD const& tensor_D,
|
||||
RefTensorD const& ref_tensor_D,
|
||||
bool print_diff = false)
|
||||
{
|
||||
using namespace cute;
|
||||
auto norm_A = matrix_inf_norm(tensor_A);
|
||||
auto norm_B = matrix_inf_norm(tensor_B);
|
||||
auto norm_C = matrix_inf_norm(tensor_C);
|
||||
auto norm_D = matrix_inf_norm(tensor_D);
|
||||
auto norm_ref_D = matrix_inf_norm(ref_tensor_D);
|
||||
auto norm_diff = matrix_diff_inf_norm(tensor_D, ref_tensor_D);
|
||||
|
||||
if (print_diff) {
|
||||
for (int m = 0; m < size<0>(tensor_D); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor_D); ++n) {
|
||||
std::cout << m << "," << n << " : " << tensor_D(m,n) << " vs. " << ref_tensor_D(m,n) << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "norm (A) : " << norm_A.inf_norm << std::endl;
|
||||
std::cout << "norm (B) : " << norm_B.inf_norm << std::endl;
|
||||
std::cout << "norm (C) : " << norm_C.inf_norm << std::endl;
|
||||
std::cout << "norm (D) : " << norm_D.inf_norm << std::endl;
|
||||
std::cout << "norm (ref_D) : " << norm_ref_D.inf_norm << std::endl;
|
||||
std::cout << "norm (D-ref_D) : " << norm_diff.inf_norm << std::endl;
|
||||
|
||||
return (!norm_A.found_nan) && (!norm_B.found_nan) &&
|
||||
(!norm_C.found_nan) && (!norm_D.found_nan) && (!norm_ref_D.found_nan) && // There are no NaNs
|
||||
(norm_A.inf_norm > 0.0) && (norm_B.inf_norm > 0.0) &&
|
||||
(norm_C.inf_norm > 0.0) && (norm_D.inf_norm > 0.0) && (norm_ref_D.inf_norm > 0.0) && // Values in tensors aren't zeros
|
||||
(norm_diff.inf_norm <= 0.0); // Diff (ref_D-D) == 0
|
||||
}
|
||||
|
||||
template <class Tensor>
|
||||
void
|
||||
initialize_tensor(Tensor& tensor, cute::tuple<int, int> value_range = {-2, 2})
|
||||
{
|
||||
using DataType = typename Tensor::element_type;
|
||||
auto [min, max] = value_range;
|
||||
for (int i = 0; i < cute::size(tensor); i++) {
|
||||
tensor(i) = DataType(int((max-min)*(rand() / double(RAND_MAX)) + min));
|
||||
}
|
||||
}
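// Usage note: with the default value_range of {-2, 2}, every element is a small integer in
// [-2, 2]; for the fp16/fp32 problem sizes in these examples that keeps the device result and
// the host reference exactly comparable, which the zero-tolerance checks above rely on.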
|
||||
38
examples/cute/tutorial/hopper/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_wgmma_sm90
|
||||
wgmma_sm90.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_wgmma_tma_sm90
|
||||
wgmma_tma_sm90.cu
|
||||
)
|
||||
611
examples/cute/tutorial/hopper/wgmma_sm90.cu
Normal file
@ -0,0 +1,611 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <cassert>
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
|
||||
#include "cutlass/util/print_error.hpp"
|
||||
#include "cutlass/util/GPU_Clock.hpp"
|
||||
#include "cutlass/util/helper_cuda.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <class ElementA,
|
||||
class ElementB,
|
||||
class SmemLayoutA, // (M,K,P)
|
||||
class SmemLayoutB> // (N,K,P)
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(128) cute::ArrayEngine<ElementA, cosize_v<SmemLayoutA>> A;
|
||||
alignas(128) cute::ArrayEngine<ElementB, cosize_v<SmemLayoutB>> B;
|
||||
};
|
||||
|
||||
template <class ProblemShape, class CtaTiler,
|
||||
class TA, class AStride, class ASmemLayout, class TiledCopyA,
|
||||
class TB, class BStride, class BSmemLayout, class TiledCopyB,
|
||||
class TC, class CStride, class TiledMma,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
__launch_bounds__(decltype(size(TiledMma{}))::value)
|
||||
void
|
||||
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
|
||||
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
|
||||
TC * C, CStride dC, TiledMma mma,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Preconditions
|
||||
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
|
||||
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
|
||||
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
|
||||
|
||||
static_assert(is_static<ASmemLayout>::value);
|
||||
static_assert(is_static<BSmemLayout>::value);
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
|
||||
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
|
||||
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
|
||||
|
||||
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
|
||||
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
|
||||
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
|
||||
|
||||
//
|
||||
// Full and Tiled Tensors
|
||||
//
|
||||
|
||||
// Represent the full tensors
|
||||
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
|
||||
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
|
||||
|
||||
// Get the appropriate blocks for this thread block
|
||||
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
|
||||
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
|
||||
|
||||
// Shared memory tensors
|
||||
extern __shared__ char shared_memory[];
|
||||
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
|
||||
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), ASmemLayout{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), BSmemLayout{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Partition the copying of A and B tiles across the threads
|
||||
//
|
||||
|
||||
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
|
||||
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
|
||||
Tensor sA_ = as_position_independent_swizzle_tensor(sA);
|
||||
Tensor tAsA = thr_copy_a.partition_D(sA_); // (CPY,CPY_M,CPY_K,PIPE)
|
||||
|
||||
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
|
||||
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
|
||||
Tensor sB_ = as_position_independent_swizzle_tensor(sB);
|
||||
Tensor tBsB = thr_copy_b.partition_D(sB_); // (CPY,CPY_N,CPY_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
|
||||
|
||||
//
|
||||
// PREFETCH
|
||||
//
|
||||
|
||||
// auto K_PIPE_MAX = size<3>(tAsA);
|
||||
|
||||
// // Total count of tiles
|
||||
// int k_tile_count = size<3>(tAgA);
|
||||
// // Current tile index in gmem to read from
|
||||
// int k_tile_next = 0;
|
||||
|
||||
// // Start async loads for all pipes but the last
|
||||
// CUTE_UNROLL
|
||||
// for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) {
|
||||
// copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe));
|
||||
// copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe));
|
||||
// cp_async_fence();
|
||||
// --k_tile_count;
|
||||
// if (k_tile_count > 0) { ++k_tile_next; }
|
||||
// }
|
||||
|
||||
//
|
||||
// Define A/B partitioning and C accumulators
|
||||
//
|
||||
|
||||
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
|
||||
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Allocate registers for pipelining
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
// Allocate the accumulators -- same size as the projected data
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K
|
||||
|
||||
// Clear the accumulators
|
||||
clear(tCrC);
|
||||
|
||||
#if 0
|
||||
if(thread0()) {
|
||||
print(" mA : "); print( mA); print("\n");
|
||||
print(" gA : "); print( gA); print("\n");
|
||||
print(" sA : "); print( sA); print("\n");
|
||||
print("tAgA : "); print(tAgA); print("\n");
|
||||
print("tAsA : "); print(tAsA); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
if(thread0()) {
|
||||
print(" mB : "); print( mB); print("\n");
|
||||
print(" gB : "); print( gB); print("\n");
|
||||
print(" sB : "); print( sB); print("\n");
|
||||
print("tBgB : "); print(tBgB); print("\n");
|
||||
print("tBsB : "); print(tBsB); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
if(thread0()) {
|
||||
print(" mC : "); print( mC); print("\n");
|
||||
print(" gC : "); print( gC); print("\n");
|
||||
print("tCsA : "); print(tCsA); print("\n");
|
||||
print("tCsB : "); print(tCsB); print("\n");
|
||||
print("tCgC : "); print(tCgC); print("\n");
|
||||
print("tCrA : "); print(tCrA); print("\n");
|
||||
print("tCrB : "); print(tCrB); print("\n");
|
||||
print("tCrC : "); print(tCrC); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
|
||||
// Total number of k-tiles
|
||||
auto K_TILE_MAX = size<3>(tAgA);
|
||||
// Number of pipelined k-tiles in smem
|
||||
auto K_PIPE_MAX = size<3>(tAsA);
|
||||
|
||||
//
|
||||
// PREFETCH
|
||||
//
|
||||
|
||||
// Prefetch all but the last
|
||||
CUTE_UNROLL
|
||||
for (int k = 0; k < K_PIPE_MAX-1; ++k)
|
||||
{
|
||||
copy(copy_a, tAgA(_,_,_,k), tAsA(_,_,_,k));
|
||||
copy(copy_b, tBgB(_,_,_,k), tBsB(_,_,_,k));
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
// Clear the accumulators
|
||||
clear(tCrC);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
|
||||
// Current pipe to read from
|
||||
int k_pipe_read = 0;
|
||||
// Current pipe to write to
|
||||
int k_pipe_write = K_PIPE_MAX-1;
|
||||
|
||||
CUTE_NO_UNROLL
|
||||
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
|
||||
{
|
||||
int k_tile_next = k_tile + (K_PIPE_MAX-1);
|
||||
k_tile_next = (k_tile_next >= K_TILE_MAX) ? K_TILE_MAX-1 : k_tile_next;
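    // Worked example of the index schedule (illustrative values, not taken from the code):
    // with K_PIPE_MAX = 3 and K_TILE_MAX = 5, the prefetch already filled pipes 0 and 1, and the
    // loop then loads k_tile_next = 2,3,4,4,4 while computing from k_pipe_read = 0,1,2,0,1.
    // The clamp above simply re-loads the last k-tile once the prefetch window runs past the end,
    // instead of reading out of bounds.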
|
||||
|
||||
//
|
||||
// Copy gmem to smem for k_tile_write
|
||||
//
|
||||
|
||||
copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe_write));
|
||||
copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe_write));
|
||||
cp_async_fence();
|
||||
|
||||
// Advance k_pipe_write
|
||||
++k_pipe_write;
|
||||
k_pipe_write = (k_pipe_write == K_PIPE_MAX) ? 0 : k_pipe_write;
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
// Wait on all cp.async -- optimize by pipelining to overlap GMEM reads
|
||||
cp_async_wait<0>();
|
||||
|
||||
warpgroup_fence_operand(tCrC);
|
||||
warpgroup_arrive();
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(mma, tCrA(_,_,_,k_pipe_read), tCrB(_,_,_,k_pipe_read), tCrC);
|
||||
warpgroup_commit_batch();
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(tCrC);
|
||||
|
||||
// Advance k_pipe_read
|
||||
++k_pipe_read;
|
||||
k_pipe_read = (k_pipe_read == K_PIPE_MAX) ? 0 : k_pipe_read;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
axpby(alpha, tCrC, beta, tCgC);
|
||||
}
|
||||
|
||||
// Setup params for a NT GEMM
|
||||
template <class TA, class TB, class TC,
|
||||
class Alpha, class Beta>
|
||||
void
|
||||
gemm_nt(int m, int n, int k,
|
||||
Alpha alpha,
|
||||
TA const* A, int ldA,
|
||||
TB const* B, int ldB,
|
||||
Beta beta,
|
||||
TC * C, int ldC,
|
||||
cudaStream_t stream = 0)
|
||||
{
|
||||
// Define shapes (dynamic)
|
||||
auto M = int(m);
|
||||
auto N = int(n);
|
||||
auto K = int(k);
|
||||
auto prob_shape = make_shape(M, N, K); // (M, N, K)
|
||||
|
||||
// Define NT strides (mixed)
|
||||
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
|
||||
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
|
||||
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
|
||||
|
||||
// Define CTA tile sizes (static)
|
||||
auto bM = Int<128>{};
|
||||
auto bN = Int<128>{};
|
||||
auto bK = Int< 64>{};
|
||||
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
|
||||
auto bP = Int<3>{}; // Pipeline
|
||||
|
||||
// Define the smem layouts (static)
|
||||
auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
|
||||
auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TB>{}, make_shape(bN,bK,bP));
|
||||
|
||||
// Define the thread layouts (static)
|
||||
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
|
||||
    Layout<Shape<_16,_8>>{},                 // Thr layout 16x8 m-major
|
||||
Layout<Shape< _8,_1>>{});// Val layout 8x1 m-major
|
||||
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
|
||||
    Layout<Shape<_16,_8>>{},                 // Thr layout 16x8 n-major
|
||||
Layout<Shape< _8,_1>>{});// Val layout 8x1 n-major
|
||||
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});
|
||||
|
||||
#if 0
|
||||
print(copyA);
|
||||
print(copyB);
|
||||
print(mmaC);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
print_latex(copyA);
|
||||
print_latex(copyB);
|
||||
print_latex(mmaC);
|
||||
#endif
|
||||
|
||||
//
|
||||
// Setup and Launch
|
||||
//
|
||||
|
||||
// Launch parameter setup
|
||||
dim3 dimBlock(size(tiled_mma));
|
||||
dim3 dimCluster(1, 1, 1);
|
||||
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(n, bN)), dimCluster.y));
|
||||
int smemBytes = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
|
||||
|
||||
auto* kernel_ptr = &gemm_device<decltype(prob_shape), decltype(cta_tiler),
|
||||
TA, decltype(dA), decltype(sA), decltype(copyA),
|
||||
TB, decltype(dB), decltype(sB), decltype(copyB),
|
||||
TC, decltype(dC), decltype(tiled_mma),
|
||||
decltype(alpha), decltype(beta)>;
|
||||
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
// Kernel Launch
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
prob_shape, cta_tiler,
|
||||
A, dA, sA, copyA,
|
||||
B, dB, sB, copyB,
|
||||
C, dC, tiled_mma,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup params for a TN GEMM
|
||||
template <class TA, class TB, class TC,
|
||||
class Alpha, class Beta>
|
||||
void
|
||||
gemm_tn(int m, int n, int k,
|
||||
Alpha alpha,
|
||||
TA const* A, int ldA,
|
||||
TB const* B, int ldB,
|
||||
Beta beta,
|
||||
TC * C, int ldC,
|
||||
cudaStream_t stream = 0)
|
||||
{
|
||||
// Define shapes (dynamic)
|
||||
auto M = int(m);
|
||||
auto N = int(n);
|
||||
auto K = int(k);
|
||||
auto prob_shape = make_shape(M, N, K); // (M, N, K)
|
||||
|
||||
// Define TN strides (mixed)
|
||||
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
|
||||
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
|
||||
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
|
||||
|
||||
// Define CTA tile sizes (static)
|
||||
auto bM = Int<128>{};
|
||||
auto bN = Int<128>{};
|
||||
auto bK = Int< 64>{};
|
||||
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
|
||||
auto bP = Int<3>{}; // Pipeline
|
||||
|
||||
// Define the smem layouts (static)
|
||||
auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
|
||||
auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom<TB>{}, make_shape(bN,bK,bP));
|
||||
|
||||
// Define the thread layouts (static)
|
||||
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
|
||||
Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
|
||||
Layout<Shape< _1,_8>>{}); // Val layout 1x8
|
||||
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
|
||||
Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
|
||||
Layout<Shape< _1,_8>>{}); // Val layout 1x8
|
||||
|
||||
TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::K,GMMA::Major::K>{});
|
||||
|
||||
#if 0
|
||||
print(copyA);
|
||||
print(copyB);
|
||||
print(mmaC);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
print_latex(copyA);
|
||||
print_latex(copyB);
|
||||
print_latex(mmaC);
|
||||
#endif
|
||||
|
||||
//
|
||||
// Setup and Launch
|
||||
//
|
||||
|
||||
// Launch parameter setup
|
||||
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
|
||||
dim3 dimBlock(size(tiled_mma));
|
||||
dim3 dimCluster(1, 1, 1);
|
||||
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(n, bN)), dimCluster.y));
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};
|
||||
|
||||
void const* kernel_ptr = reinterpret_cast<void const*>(
|
||||
&gemm_device<decltype(prob_shape), decltype(cta_tiler),
|
||||
TA, decltype(dA), decltype(sA), decltype(copyA),
|
||||
TB, decltype(dB), decltype(sB), decltype(copyB),
|
||||
TC, decltype(dC), decltype(tiled_mma),
|
||||
decltype(alpha), decltype(beta)>);
|
||||
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(
|
||||
kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size));
|
||||
|
||||
// Kernel Launch
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
|
||||
prob_shape, cta_tiler,
|
||||
A, dA, sA, copyA,
|
||||
B, dB, sB, copyB,
|
||||
C, dC, tiled_mma,
|
||||
alpha, beta);
|
||||
CUTE_CHECK_LAST();
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Error: Failed at kernel Launch" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <class TA, class TB, class TC,
|
||||
class Alpha, class Beta>
|
||||
void
|
||||
gemm(char transA, char transB, int m, int n, int k,
|
||||
Alpha alpha,
|
||||
TA const* A, int ldA,
|
||||
TB const* B, int ldB,
|
||||
Beta beta,
|
||||
TC * C, int ldC,
|
||||
cudaStream_t stream = 0)
|
||||
{
|
||||
if (transA == 'N' && transB == 'T') {
|
||||
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
|
||||
} else
|
||||
if (transA == 'T' && transB == 'N') {
|
||||
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
|
||||
}
|
||||
assert(false && "Not implemented");
|
||||
}
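// Illustrative call (not from this file): with K-major A and B, the TN path is selected as
//   gemm('T', 'N', m, n, k, alpha, d_A, /*ldA=*/k, d_B, /*ldB=*/k, beta, d_C, /*ldC=*/m);
// which matches the leading-dimension choices made in main() below.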
|
||||
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
  cudaGetDevice(&current_device_id);
|
||||
  cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major < 8) {
|
||||
std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
|
||||
// Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED)
|
||||
|
||||
int m = 5120;
|
||||
if (argc >= 2)
|
||||
sscanf(argv[1], "%d", &m);
|
||||
|
||||
int n = 5120;
|
||||
if (argc >= 3)
|
||||
sscanf(argv[2], "%d", &n);
|
||||
|
||||
int k = 4096;
|
||||
if (argc >= 4)
|
||||
sscanf(argv[3], "%d", &k);
|
||||
|
||||
char transA = 'N';
|
||||
if (argc >= 5)
|
||||
sscanf(argv[4], "%c", &transA);
|
||||
|
||||
char transB = 'T';
|
||||
if (argc >= 6)
|
||||
sscanf(argv[5], "%c", &transB);
|
||||
|
||||
using TA = cute::half_t;
|
||||
using TB = cute::half_t;
|
||||
using TC = cute::half_t;
|
||||
using TI = cute::half_t;
|
||||
|
||||
TI alpha = TI(1.0f);
|
||||
TI beta = TI(0.0f);
|
||||
|
||||
thrust::host_vector<TA> h_A(m*k);
|
||||
thrust::host_vector<TB> h_B(n*k);
|
||||
thrust::host_vector<TC> h_C(m*n);
|
||||
|
||||
// Initialize the tensors
|
||||
for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1));
|
||||
for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1));
|
||||
for (int j = 0; j < m*n; ++j) h_C[j] = TC(0);
|
||||
|
||||
thrust::device_vector<TA> d_A = h_A;
|
||||
thrust::device_vector<TB> d_B = h_B;
|
||||
thrust::device_vector<TC> d_C = h_C;
|
||||
|
||||
double gflops = (2.0*m*n*k) * 1e-9;
|
||||
|
||||
const int timing_iterations = 100;
|
||||
GPU_Clock timer;
|
||||
|
||||
int ldA = 0, ldB = 0, ldC = m;
|
||||
|
||||
if (transA == 'N') {
|
||||
ldA = m;
|
||||
} else if (transA == 'T') {
|
||||
ldA = k;
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
|
||||
if (transB == 'N') {
|
||||
ldB = k;
|
||||
} else if (transB == 'T') {
|
||||
ldB = n;
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
|
||||
// Run once
|
||||
d_C = h_C;
|
||||
gemm(transA, transB, m, n, k,
|
||||
alpha,
|
||||
d_A.data().get(), ldA,
|
||||
d_B.data().get(), ldB,
|
||||
beta,
|
||||
d_C.data().get(), ldC);
|
||||
CUTE_CHECK_LAST();
|
||||
thrust::host_vector<TC> cute_result = d_C;
|
||||
|
||||
// Timing iterations
|
||||
timer.start();
|
||||
for (int i = 0; i < timing_iterations; ++i) {
|
||||
gemm(transA, transB, m, n, k,
|
||||
alpha,
|
||||
d_A.data().get(), ldA,
|
||||
d_B.data().get(), ldB,
|
||||
beta,
|
||||
d_C.data().get(), ldC);
|
||||
}
|
||||
double cute_time = timer.seconds() / timing_iterations;
|
||||
CUTE_CHECK_LAST();
|
||||
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
|
||||
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -55,8 +55,8 @@ template <class ElementA,
|
||||
class SmemLayoutB> // (N,K,P)
|
||||
struct SharedStorage
|
||||
{
|
||||
array_aligned<ElementA, cosize_v<SmemLayoutA>> smem_A;
|
||||
array_aligned<ElementB, cosize_v<SmemLayoutB>> smem_B;
|
||||
alignas(128) cute::ArrayEngine<ElementA, cosize_v<SmemLayoutA>> A;
|
||||
alignas(128) cute::ArrayEngine<ElementB, cosize_v<SmemLayoutB>> B;
|
||||
|
||||
uint64_t tma_barrier[size<2>(SmemLayoutA{})];
|
||||
uint64_t mma_barrier[size<2>(SmemLayoutA{})];
|
||||
@ -110,8 +110,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
extern __shared__ char shared_memory[];
|
||||
using SharedStorage = SharedStorage<TA, TB, SmemLayoutA, SmemLayoutB>;
|
||||
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Partition the copying of A and B tiles
|
||||
@ -132,8 +132,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
group_modes<0,2>(sB), group_modes<0,2>(gB)); // (TMA,k) and (TMA,PIPE)
|
||||
|
||||
// The TMA is responsible for copying everything in mode-0 of tAsA and tBsB
|
||||
constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) +
|
||||
CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB);
|
||||
constexpr int tma_transaction_bytes = sizeof(make_tensor_like(tensor<0>(tAsA)))
|
||||
+ sizeof(make_tensor_like(tensor<0>(tBsB)));
|
||||
|
||||
//
|
||||
// PREFETCH
|
||||
@ -171,7 +171,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
if ((warp_idx == 0) && lane_predicate)
|
||||
{
|
||||
// Set expected Tx Bytes after each reset / init
|
||||
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
|
||||
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], tma_transaction_bytes);
|
||||
copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
|
||||
copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
|
||||
}
|
||||
@ -242,7 +242,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
// Wait for Consumer to complete consumption
|
||||
ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase());
|
||||
// Set expected Tx Bytes after each reset / init
|
||||
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
|
||||
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], tma_transaction_bytes);
|
||||
copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
|
||||
copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
|
||||
++write_state;
|
||||
@ -393,27 +393,25 @@ gemm_tn(int m, int n, int k,
|
||||
//
|
||||
|
||||
// Launch parameter setup
|
||||
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
|
||||
dim3 dimBlock(size(tiled_mma));
|
||||
dim3 dimCluster(2, 1, 1);
|
||||
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
|
||||
round_up(size(ceil_div(n, bN)), dimCluster.y));
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};
|
||||
int smemBytes = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
|
||||
|
||||
void const* kernel_ptr = reinterpret_cast<void const*>(
|
||||
&gemm_device<decltype(prob_shape), decltype(cta_tiler),
|
||||
TA, decltype(sA), decltype(tmaA),
|
||||
TB, decltype(sB), decltype(tmaB),
|
||||
TC, decltype(dC), decltype(tiled_mma),
|
||||
decltype(alpha), decltype(beta)>);
|
||||
auto* kernel_ptr = &gemm_device<decltype(prob_shape), decltype(cta_tiler),
|
||||
TA, decltype(sA), decltype(tmaA),
|
||||
TB, decltype(sB), decltype(tmaB),
|
||||
TC, decltype(dC), decltype(tiled_mma),
|
||||
decltype(alpha), decltype(beta)>;
|
||||
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(
|
||||
kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size));
|
||||
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smemBytes));
|
||||
|
||||
// Kernel Launch
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
|
||||
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
|
||||
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
|
||||
prob_shape, cta_tiler,
|
||||
A, tmaA,
|
||||
B, tmaB,
|
||||
@ -448,8 +446,10 @@ gemm(char transA, char transB, int m, int n, int k,
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
  cudaGetDevice(&current_device_id);
|
||||
cudaGetDeviceProperties(&props, current_device_id);
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
@ -461,7 +461,7 @@ int main(int argc, char** argv)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED)
|
||||
|
||||
int m = 512;
|
||||
if (argc >= 2)
|
||||
@ -553,10 +553,8 @@ int main(int argc, char** argv)
|
||||
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
|
||||
|
||||
#else
|
||||
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
}
|
||||
@ -41,17 +41,27 @@
|
||||
#include "cutlass/util/GPU_Clock.hpp"
|
||||
#include "cutlass/util/helper_cuda.hpp"
|
||||
|
||||
template <class ElementA,
|
||||
class ElementB,
|
||||
class SmemLayoutA,
|
||||
class SmemLayoutB>
|
||||
struct SharedStorage
|
||||
{
|
||||
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
|
||||
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
|
||||
};
|
||||
|
||||
template <class ProblemShape, class CtaTiler,
|
||||
class TA, class AStride, class ASmemLayout, class TiledCopyA,
|
||||
class TB, class BStride, class BSmemLayout, class TiledCopyB,
|
||||
class TA, class AStride, class ASmemLayout, class TiledCopyA, class S2RAtomA,
|
||||
class TB, class BStride, class BSmemLayout, class TiledCopyB, class S2RAtomB,
|
||||
class TC, class CStride, class CSmemLayout, class TiledMma,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
__launch_bounds__(decltype(size(TiledMma{}))::value)
|
||||
void
|
||||
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
|
||||
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
|
||||
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, S2RAtomA s2r_atom_a,
|
||||
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, S2RAtomB s2r_atom_b,
|
||||
TC * C, CStride dC, CSmemLayout , TiledMma mma,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
@ -95,10 +105,11 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
|
||||
|
||||
// Shared memory buffers
|
||||
__shared__ TA smemA[cosize_v<ASmemLayout>];
|
||||
__shared__ TB smemB[cosize_v<BSmemLayout>];
|
||||
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K,PIPE)
|
||||
extern __shared__ char shared_memory[];
|
||||
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
|
||||
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Partition the copying of A and B tiles across the threads
|
||||
@ -143,26 +154,35 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
//
|
||||
|
||||
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
|
||||
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Allocate registers for pipelining
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0)); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0)); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K)
|
||||
// Allocate the accumulators -- same size as the projected data
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K)
|
||||
CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K)
|
||||
CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
|
||||
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
|
||||
|
||||
// Clear the accumulators
|
||||
clear(tCrC);
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
|
||||
ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
|
||||
Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
|
||||
Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
|
||||
|
||||
TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
|
||||
ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
|
||||
Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
|
||||
Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
|
||||
|
||||
#if 0
|
||||
if(thread0()) {
|
||||
print(" mA : "); print( mA); print("\n");
|
||||
@ -187,12 +207,15 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
if(thread0()) {
|
||||
print(" mC : "); print( mC); print("\n");
|
||||
print(" gC : "); print( gC); print("\n");
|
||||
print("tCsA : "); print(tCsA); print("\n");
|
||||
print("tCsB : "); print(tCsB); print("\n");
|
||||
print("tCgC : "); print(tCgC); print("\n");
|
||||
print("tCrA : "); print(tCrA); print("\n");
|
||||
print("tCrB : "); print(tCrB); print("\n");
|
||||
print("tCrC : "); print(tCrC); print("\n");
|
||||
|
||||
print("tXsA : "); print(tXsA); print("\n");
|
||||
print("tXrA : "); print(tXrA); print("\n");
|
||||
print("tXsB : "); print(tXsB); print("\n");
|
||||
print("tXrB : "); print(tXrB); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -204,8 +227,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
int smem_pipe_write = K_PIPE_MAX-1;
|
||||
|
||||
// Pipe slice
|
||||
Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
|
||||
Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);
|
||||
Tensor tXsA_p = tXsA(_,_,_,smem_pipe_read);
|
||||
Tensor tXsB_p = tXsB(_,_,_,smem_pipe_read);
|
||||
|
||||
// Size of the register pipeline
|
||||
auto K_BLOCK_MAX = size<2>(tCrA);
|
||||
@ -217,8 +240,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first rmem from the first k-tile
|
||||
copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{}));
|
||||
copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{}));
|
||||
copy(s2r_atom_a, tXsA_p(_,_,Int<0>{}), tXrA(_,_,Int<0>{}));
|
||||
copy(s2r_atom_b, tXsB_p(_,_,Int<0>{}), tXrB(_,_,Int<0>{}));
|
||||
}
|
||||
|
||||
//
|
||||
@ -243,8 +266,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
// Slice the smem_pipe_read smem
|
||||
tCsA_p = tCsA(_,_,_,smem_pipe_read);
|
||||
tCsB_p = tCsB(_,_,_,smem_pipe_read);
|
||||
tXsA_p = tXsA(_,_,_,smem_pipe_read);
|
||||
tXsB_p = tXsB(_,_,_,smem_pipe_read);
|
||||
|
||||
// Commit the smem for smem_pipe_read
|
||||
cp_async_wait<K_PIPE_MAX-2>();
|
||||
@ -253,8 +276,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
||||
copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next));
|
||||
copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next));
|
||||
copy(s2r_atom_a, tXsA_p(_,_,k_block_next), tXrA(_,_,k_block_next));
|
||||
copy(s2r_atom_b, tXsB_p(_,_,k_block_next), tXrB(_,_,k_block_next));
|
||||
// Copy gmem to smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
@ -268,8 +291,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
|
||||
// Advance the smem pipe
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == K_PIPE_MAX-1) ? 0 : smem_pipe_read+1;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
@ -286,6 +308,126 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
axpby(alpha, tCrC, beta, tCgC);
|
||||
}

template <class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        cute::half_t const* A, int ldA,
        cute::half_t const* B, int ldB,
        Beta beta,
        cute::half_t      * C, int ldC,
        cudaStream_t stream = 0)
{
  assert(false && "Not implemented");
}

// Setup params for a TN HGEMM
template <class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        cute::half_t const* A, int ldA,
        cute::half_t const* B, int ldB,
        Beta beta,
        cute::half_t      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int< 64>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};  // Pipeline

  // Define the smem layouts (static)
  // Swizzles for LDSM and 128b k-major loads
  auto swizzle_atom = composition(Swizzle<3,3,3>{},
                                  Layout<Shape <_8,Shape <_8, _8>>,
                                         Stride<_8,Stride<_1,_64>>>{});

  auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
  auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
  auto sC = make_layout(make_shape(bM, bN));

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
                                    Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
                                    Layout<Shape< _1,_8>>{});              // Val layout  1x8 k-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
                                    Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
                                    Layout<Shape< _1,_8>>{});              // Val layout  1x8 n-major

  TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                 Layout<Shape<_2,_2>>{},    // 2x2x1 MMA Atoms
                                 Tile<_32,_32,_16>{});      // 32x32x16 Tiled MMA for LDSM

  //Copy_Atom<DefaultCopy, half_t> s2r_atom_A;
  //Copy_Atom<UniversalCopy<half_t>, half_t> s2r_atom_A;
  //Copy_Atom<SM75_U32x1_LDSM_N, half_t> s2r_atom_A;
  //Copy_Atom<SM75_U32x2_LDSM_N, half_t> s2r_atom_A;
  Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_A;

  //Copy_Atom<DefaultCopy, half_t> s2r_atom_B;
  //Copy_Atom<UniversalCopy<half_t>, half_t> s2r_atom_B;
  //Copy_Atom<SM75_U32x1_LDSM_N, half_t> s2r_atom_B;
  //Copy_Atom<SM75_U32x2_LDSM_N, half_t> s2r_atom_B;
  Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_B;
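The ladder of commented-out atoms runs from a plain per-element copy up to the four-register `ldmatrix` variant (`SM75_U32x4_LDSM_N`). The usual CuTe pattern, consistent with the `tXsA`/`tXrA` partitions printed in the kernel above, is to retile such an atom against the tiled MMA so its thread/value arrangement matches the MMA operands. A minimal host-side sketch of that retiling (illustrative only, not a claim about this kernel's exact internals; built with nvcc against the CUTLASS headers):

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // Same tiled MMA as above: 2x2 SM80 f16 MMA atoms permuted to a 32x32x16 tile.
  TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
                                 Layout<Shape<_2,_2>>{},
                                 Tile<_32,_32,_16>{});

  // Retile an LDSM atom over the MMA's A-operand thread/value partitioning.
  Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_A;
  TiledCopy s2r_copy_A = make_tiled_copy_A(s2r_atom_A, mmaC);

  print(mmaC);       print("\n");
  print(s2r_copy_A); print("\n");
  return 0;
}
```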

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  int smem_size = int(sizeof(SharedStorage<cute::half_t, cute::half_t, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));

  auto kernel_fptr = gemm_device<
    decltype(prob_shape), decltype(cta_tiler),
    cute::half_t, decltype(dA), decltype(sA), decltype(copyA), decltype(s2r_atom_A),
    cute::half_t, decltype(dB), decltype(sB), decltype(copyB), decltype(s2r_atom_B),
    cute::half_t, decltype(dC), decltype(sC), decltype(mmaC),
    decltype(alpha), decltype(beta)>;

  // Set L1 to be SMEM only
  cudaFuncSetAttribute(
    kernel_fptr,
    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

  cudaFuncSetAttribute(
    kernel_fptr,
    cudaFuncAttributePreferredSharedMemoryCarveout, 100);

  kernel_fptr<<<dimGrid, dimBlock, smem_size, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA, s2r_atom_A,
       B, dB, sB, copyB, s2r_atom_B,
       C, dC, sC, mmaC,
       alpha, beta);
}
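With `bM = bN = 128`, `bK = 64`, and the three-stage pipeline `bP = 3`, the A and B staging buffers hold `2 * 128 * 64 * 3` half elements, i.e. 98,304 bytes (96 KiB) of dynamic shared memory. That exceeds the 48 KiB a kernel may use by default, which is why the launch path opts in via `cudaFuncAttributeMaxDynamicSharedMemorySize` before setting the carveout and launching. A quick back-of-the-envelope check (assuming `SharedStorage` contains only the two swizzled arrays described by `sA` and `sB`, with no extra members or padding):

```cpp
#include <cstdio>

int main() {
  constexpr long long bM = 128, bN = 128, bK = 64, bP = 3;
  constexpr long long half_bytes = 2;  // sizeof(cute::half_t)
  constexpr long long smem_bytes = (bM * bK * bP + bN * bK * bP) * half_bytes;
  std::printf("dynamic smem: %lld bytes (%.0f KiB)\n",
              smem_bytes, smem_bytes / 1024.0);   // 98304 bytes (96 KiB)
  return 0;
}
```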

// Setup params for a NT GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
@ -347,13 +489,14 @@ gemm_nt(int m, int n, int k,
  print_latex(mmaC);
#endif

  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       A, dA, sA, copyA, AutoVectorizingCopy{},
       B, dB, sB, copyB, AutoVectorizingCopy{},
       C, dC, sC, mmaC,
       alpha, beta);
}
@ -423,13 +566,14 @@ gemm_tn(int m, int n, int k,
  print_latex(mmaC);
#endif

  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       A, dA, sA, copyA, AutoVectorizingCopy{},
       B, dB, sB, copyB, AutoVectorizingCopy{},
       C, dC, sC, mmaC,
       alpha, beta);
}
@ -470,6 +614,11 @@ int main(int argc, char** argv)
    return 0;
  }

  std::cout << "Using device 0: " << props.name
            << " (SM" << props.major * 10 + props.minor
            << ", " << props.multiProcessorCount
            << ")" << std::endl;

  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);
@ -490,13 +639,13 @@ int main(int argc, char** argv)
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;
  using TA = cute::half_t;
  using TB = cute::half_t;
  using TC = cute::half_t;
  using TI = cute::half_t;

  TI alpha = 1.0;
  TI beta  = 0.0;
  TI alpha = static_cast<TI>(1.0f);
  TI beta  = static_cast<TI>(0.0f);

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;

@ -20,3 +20,35 @@
* [04_epilogue_visitor](/examples/python/04_epilogue_visitor.ipynb)

    Shows how to fuse elementwise activation functions to GEMMs via the Python Epilogue Visitor interface