* v3.8 update

* fix blackwell gg

* doc change

* doc change

* doc change

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Author: Yujia Zhai
Date: 2025-03-20 22:52:23 -07:00
Committed by: GitHub
Parent: 8c4d1dc47d
Commit: 62750a2b75
334 changed files with 91517 additions and 2656 deletions


@ -115,4 +115,3 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -2,3 +2,35 @@
This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -165,3 +165,35 @@ Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -41,3 +41,35 @@ We are currently optimizing the following cases:
* Optimizations for memory-bound cases.
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -207,3 +207,35 @@ With this in mind, this example kernel has the following limitations:
- This example kernel only supports a dynamic image count; all other conv problem shape modes must be defined as `cute::Constant<>`s
- Problem shapes (including the dynamic image count `N`) must be evenly divisible by the tile shape (see the sketch after this list)
- It does not perform fp32->tf32 numeric conversion; gmem inputs must already be rounded to tf32
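
To make the first two limitations concrete, here is a minimal, hypothetical sketch (not taken from the example itself) of a conv problem shape in which only the image count `N` is a runtime value while every other mode is a compile-time constant, followed by the tile-divisibility check the kernel assumes. The specific extents and the `TileN` value are illustrative assumptions.

```cpp
// Hypothetical sketch: only N (the image count) is dynamic; the remaining modes are
// compile-time cute constants, and every extent must divide evenly by the tile shape.
#include <cstdlib>
#include <cute/tensor.hpp>

int main(int argc, char** argv) {
  using namespace cute;

  int n_images = (argc > 1) ? std::atoi(argv[1]) : 16;      // dynamic image count N
  auto problem_shape = make_shape(n_images,                  // N: runtime value
                                  Int<56>{}, Int<56>{},      // H, W: static (illustrative)
                                  Int<64>{});                // C:    static (illustrative)

  constexpr int TileN = 8;                                   // assumed CTA tile extent along N
  // The kernel requires even divisibility, including for the dynamic mode N.
  if (get<0>(problem_shape) % TileN != 0) {
    return 1;  // shape not supported by this example kernel
  }
  return 0;
}
```

Fixing every mode except `N` at compile time lets CuTe resolve most of the layout algebra statically, which is presumably why the example imposes this restriction.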
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -26,11 +26,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include_directories(
.
)
set(TEST_PREFETCH_CASE --m=8192 --n=64 --k=8192 --iterations=0)
cutlass_example_add_executable(
63_hopper_gemm_with_weight_prefetch
63_hopper_gemm_with_weight_prefetch.cu
)
TEST_COMMAND_OPTIONS
TEST_PREFETCH_CASE
)
target_include_directories(63_hopper_gemm_with_weight_prefetch PUBLIC .)


@ -74,9 +74,40 @@ echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
However, note that the example still runs a single GEMM, and most of the performance improvement
is expected in end-to-end applications.
## Limitations
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
memory barrier before issuing every single TMA load, which in many cases slows prefetching
down to the point of being almost ineffective (see the toy model below).
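
As a rough illustration of the behavior described above, the following self-contained toy model contrasts the default `-1.0` setting, where every load is gated on a non-blocking barrier check, with an explicit `prefetch_ratio` that simply bounds how many k-tiles are prefetched. `Barrier`, `issue_load`, and all the counts are assumptions for illustration, not the kernel's actual API.

```cpp
// Toy model only: the real kernel gates TMA loads on cluster barriers; the types,
// names, and tile counts below are illustrative assumptions.
#include <cstdio>

struct Barrier {
  int free_slots = 12;                                 // pretend buffer space for 12 k-tiles
  bool try_wait() { return free_slots-- > 0; }         // non-blocking check, as in the default path
};

static void issue_load(int k_tile) { std::printf("  prefetch k-tile %d\n", k_tile); }

static int prefetch(float prefetch_ratio, int k_tiles) {
  Barrier barrier;
  int issued = 0;
  for (int k = 0; k < k_tiles; ++k) {
    if (prefetch_ratio < 0.f) {
      if (!barrier.try_wait()) break;                  // default (-1.0): poll before every load
    } else if (k >= static_cast<int>(prefetch_ratio * k_tiles)) {
      break;                                           // explicit ratio: stop after that fraction
    }
    issue_load(k);
    ++issued;
  }
  return issued;
}

int main() {
  std::printf("default  : issued %d of 32 loads\n", prefetch(-1.0f, 32));
  std::printf("ratio 0.7: issued %d of 32 loads\n", prefetch(0.7f, 32));
  return 0;
}
```

In the real kernel the penalty comes from repeatedly polling the barrier rather than from a hard cap, but the takeaway is the same: supplying an explicit ratio avoids gating every single load.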
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -362,11 +362,11 @@ public:
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z,
/*reserved_named_barriers_*/ 14);
/*id*/ 0);
// Prefetcher warp doesn't arrive on this barrier.
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
/*reserved_named_barriers_*/ 15);
/*id*/ 1);
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
__syncwarp();


@ -62,3 +62,36 @@ procedure is the same, simply modify the following line in the example:
```cpp
using TP = _8;
```
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -84,3 +84,35 @@ GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
```
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
@ -251,93 +251,93 @@ struct Result
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, bits_input);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options<RasterOrderOptions> &options) {
@ -438,14 +438,18 @@ void initialize(const Options<RasterOrderOptions> &options) {
if (IsDFp8 && options.save_amax) {
abs_max_D.resize(cutlass::make_Coord(1));
initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_D.sync_device();
reference_abs_max_D.resize(cutlass::make_Coord(1));
initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
}
if (IsAuxFp8 && options.save_aux && options.save_amax) {
abs_max_aux.resize(cutlass::make_Coord(1));
initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_aux.sync_device();
reference_abs_max_aux.resize(cutlass::make_Coord(1));
initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
}
}
@ -517,10 +521,9 @@ bool verify(const Options<RasterOrderOptions> &options) {
// Block scaling tensors shapes based CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
auto blockscale_m = ceil_div(options.m, get<0>(TileShape{}));
auto blockscale_n = ceil_div(options.n, get<1>(TileShape{}));
auto blockscale_k = ceil_div(options.k, get<2>(TileShape{}));
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
@ -608,29 +611,40 @@ bool verify(const Options<RasterOrderOptions> &options) {
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
bool passed = true;
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (false) {
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
}
#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
#endif
if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
}
@ -671,10 +685,9 @@ int run(Options<RasterOrderOptions> &options)
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
// if (!result.passed) {
// exit(-1);
// }
else {
result.passed = true;
}
// Run profiling loop
if (options.iterations > 0)
@ -707,7 +720,7 @@ int run(Options<RasterOrderOptions> &options)
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
return result.passed;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@ -753,7 +766,9 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
bool passed = run<Gemm>(options);
if (!passed)
return -1;
#endif
return 0;


@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
@ -303,93 +303,93 @@ struct Result
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, bits_input);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
template <typename GroupScaleConfig>
@ -403,11 +403,9 @@ void initialize(const Options<RasterOrderOptions> &options) {
assert(options.n % ScaleGranularityN == 0);
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM;
auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN;
auto blockscale_k = cute::get<2>(blockscale_shape);
auto groupscale_m = ceil_div(options.m, ScaleGranularityM);
auto groupscale_n = ceil_div(options.n, ScaleGranularityN);
auto blockscale_k = ceil_div(options.k, cute::get<2>(TileShape{}));
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
@ -582,13 +580,11 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile;
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM;
auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN;
auto blockscale_m = ceil_div(options.m, get<0>(TileShape_{}));
auto blockscale_n = ceil_div(options.n, get<1>(TileShape_{}));
auto blockscale_k = ceil_div(options.k, get<2>(TileShape_{}));
auto groupscale_m = ceil_div(options.m, ScaleGranularityM);
auto groupscale_n = ceil_div(options.n, ScaleGranularityN);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
@ -676,8 +672,13 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
bool passed = true;
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
@ -690,15 +691,21 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
}
@ -716,29 +723,29 @@ int run(Options<RasterOrderOptions> &options)
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
bool skip = false;
if (options.m % ScaleGranularityM != 0) {
std::cout << "Skippig (m size: " << options.m << " less then ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}
if (options.n % ScaleGranularityN != 0) {
std::cout << "Skippig (n size: " << options.m << " less then ScaleGranularityN: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}
if (options.k % size<2>(TileShape{}) != 0) {
std::cout << "Skippig (k size: " << options.k << " less then TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
skip = true;
}
if (!skip) std::cout << "Running: " << std::endl;
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
if (skip) return -1;
if (options.m < ScaleGranularityM) {
std::cout << " Skippig (m size: " << options.m << " less than ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}
if (options.n < ScaleGranularityN) {
std::cout << " Skippig (n size: " << options.n << " less than ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
skip = true;
}
if (options.k < size<2>(TileShape{})) {
std::cout << " Skippig (k size: " << options.k << " less than TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
skip = true;
}
if (!skip) std::cout << " Running... " << std::endl;
else return -1;
initialize<GroupScaleConfig>(options);
@ -770,17 +777,17 @@ int run(Options<RasterOrderOptions> &options)
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
if (!result.passed) {
exit(-1);
else {
result.passed = true;
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
if (iter == options.warmup)
timer.start();
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
@ -806,7 +813,7 @@ int run(Options<RasterOrderOptions> &options)
fflush(stdout);
}
return 0;
return result.passed;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@ -852,27 +859,31 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
bool passed = true;
std::cout << "Basic split-K GEMM kernel" << std::endl;
run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
passed &= run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
passed &= run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
passed &= run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
passed &= run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
std::cout << std::endl;
std::cout << std::endl;
std::cout << "StreamK GEMM kernel" << std::endl;
run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
passed &= run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
passed &= run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
passed &= run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
passed &= run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
std::cout << std::endl;
if (!passed)
return -1;
#endif
return 0;


@ -46,6 +46,8 @@ struct Options {
int m = 1024, n = 512, k = 1024, l = 1;
RasterOrderOptions raster;
int swizzle;
float epsilon = 0.02f;
float non_zero_floor = 1.f;
// Parses the command line
void parse(int argc, char const **args) {
@ -73,6 +75,8 @@ struct Options {
cmd.get_cmd_line_argument("warmup", warmup);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("verify", verify);
cmd.get_cmd_line_argument("epsilon", epsilon);
cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
@ -113,7 +117,10 @@ struct Options {
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --verify=<bool> Verify the results.\n\n"
<< " --epsilon=<float> The epsilon value for comparing the results.\n\n"
<< " --non-zero-floor=<float> The none zero floor for comparing the results.\n\n";
out
<< "\n\nExamples:\n\n"


@ -221,9 +221,9 @@ void gett_mainloop(
const int N = cute::size<0>(mainloop_params.B.layout());
const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA);
const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB);
assert(ScaleGranularityM && M % ScaleGranularityM == 0
assert(ScaleGranularityM && M % ScaleGranularityM == 0
&& "ScaleGranularityM must divide M");
assert(ScaleGranularityN && N % ScaleGranularityN == 0
assert(ScaleGranularityN && N % ScaleGranularityN == 0
&& "ScaleGranularityN must divide N");
cute::Tensor blockscale_A = domain_offset(


@ -12,3 +12,35 @@ Note that in Example 55, the argument `--g` is used to determine the block scale
## Upcoming features
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.
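
For readers unfamiliar with the terminology, the sketch below shows one plausible reading of row-wise scaling for a mixed-input GEMM: the quantized `B` operand carries a single float scale per row (that is, per output column `n`), applied after accumulation. The operand types, layouts, and placement of the scale are assumptions for illustration only, not the example's actual implementation.

```cpp
// Illustrative host-side reference only; "row-wise" is assumed to mean one scale
// per row of the quantized (N x K) B operand. This is not the example's real code.
#include <cstdint>
#include <vector>

void mixed_input_gemm_rowwise(int M, int N, int K,
                              const std::vector<float>&  A,        // (M x K), high-precision operand
                              const std::vector<int8_t>& B_q,      // (N x K), quantized operand
                              const std::vector<float>&  scale_B,  // (N), one scale per row of B
                              std::vector<float>&        D) {      // (M x N), output
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * static_cast<float>(B_q[n * K + k]);
      }
      D[m * N + n] = acc * scale_B[n];   // row-wise dequantization scale
    }
  }
}

int main() {
  const int M = 2, N = 2, K = 4;
  std::vector<float>  A(M * K, 1.f);
  std::vector<int8_t> B_q(N * K, 2);
  std::vector<float>  scale_B = {0.5f, 0.25f};
  std::vector<float>  D(M * N);
  mixed_input_gemm_rowwise(M, N, K, A, B_q, scale_B, D);
  return 0;
}
```

Block-wise scaling, as selected by `--g` in Example 55, would instead attach one scale to each group of `g` consecutive `k` values, which is the variant noted above as not yet supported here.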
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -194,12 +194,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle;
Options():
help(false),
m(8192), n(8192), k(8192),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -217,6 +219,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -231,6 +234,7 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
@ -331,6 +335,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}


@ -231,6 +231,7 @@ struct Options {
bool save_amax = true;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
int swizzle = 0;
// Parses the command line
void parse(int argc, char const **args) {
@ -256,6 +257,7 @@ struct Options {
cmd.get_cmd_line_argument("save_aux", save_aux, true);
cmd.get_cmd_line_argument("save_amax", save_amax, true);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -271,6 +273,7 @@ struct Options {
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --scale_a=<f32> Scaling factor for A\n"
<< " --scale_b=<f32> Scaling factor for B\n"
<< " --scale_c=<f32> Scaling factor for C\n"
@ -476,6 +479,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
fusion_args.amax_D_ptr = abs_max_D.device_data();
}
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}


@ -28,14 +28,29 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
set(TEST_SWIZZLE_1 --swizzle=1)
set(TEST_SWIZZLE_2 --swizzle=2)
set(TEST_SWIZZLE_5 --swizzle=5)
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
cutlass_example_add_executable(
70_blackwell_fp16_gemm
70_blackwell_fp16_gemm.cu
)
TEST_COMMAND_OPTIONS
TEST_SWIZZLE_1
TEST_SWIZZLE_2
TEST_SWIZZLE_5
TEST_SWIZZLE_5_UNEVEN
)
cutlass_example_add_executable(
70_blackwell_fp8_gemm
70_blackwell_fp8_gemm.cu
TEST_COMMAND_OPTIONS
TEST_SWIZZLE_1
TEST_SWIZZLE_2
TEST_SWIZZLE_5
TEST_SWIZZLE_5_UNEVEN
)
endif()


@ -74,12 +74,14 @@ struct Options {
int m, n, k, l;
float alpha, beta;
int swizzle;
Options():
help(false),
error(false),
m(2048), n(2048), k(2048), l(1),
alpha(1.f), beta(0.f)
alpha(1.f), beta(0.f),
swizzle(0)
{ }
// Parses the command line
@ -97,6 +99,7 @@ struct Options {
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -112,7 +115,8 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n";
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n\n";
return out;
}
@ -352,6 +356,8 @@ struct ExampleRunner {
hw_info
};
arguments.scheduler.max_swizzle_size = options.swizzle;
// See example 48 for details on custom EVT construction
if constexpr (UseCustomEVT) {
arguments.epilogue.thread =


@ -211,12 +211,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -234,6 +236,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -247,7 +250,8 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
@ -333,7 +337,7 @@ bool initialize_block(
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
@ -344,8 +348,8 @@ void initialize(const Options &options) {
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
@ -387,6 +391,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}


@ -177,7 +177,7 @@ using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using FusionOp = typename Gemm::EpilogueOutputOp;
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig<OutputSFVectorSize>;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
@ -240,12 +240,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -263,6 +265,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -276,7 +279,8 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
@ -362,9 +366,9 @@ bool initialize_block(
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// For SFD tensor layout
using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlockScaledOutputConfig= typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
@ -375,8 +379,8 @@ void initialize(const Options &options) {
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
@ -432,6 +436,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
}
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}


@ -212,12 +212,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -235,6 +237,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -248,7 +251,8 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
@ -334,7 +338,7 @@ bool initialize_block(
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
@ -345,8 +349,8 @@ void initialize(const Options &options) {
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
@ -388,6 +392,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}


@ -214,7 +214,8 @@ struct Options {
int iterations;
int m, n, k;
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
int swizzle = 0;
Options():
help(false),
m(4096), n(4096), k(4096),
@ -223,7 +224,8 @@ struct Options {
preferred_cluster_m(4),
preferred_cluster_n(4),
fallback_cluster_m(2),
fallback_cluster_n(1)
fallback_cluster_n(1),
swizzle(0)
{ }
// Parses the command line
@ -245,6 +247,7 @@ struct Options {
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
cmd.get_cmd_line_argument("swizzle", swizzle);
if (!validate_cluster_shape()){
std::cout << "--Invalid cluster shapes" << std::endl;
@ -265,6 +268,7 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
@ -384,7 +388,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}

View File

@ -242,6 +242,7 @@ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTil
struct Options {
bool help = false;
bool use_pdl = false;
float alpha = FLT_MAX;
float beta = FLT_MAX;
@ -264,6 +265,9 @@ struct Options {
help = true;
return;
}
if (cmd.check_cmd_line_flag("use_pdl")) {
use_pdl = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
@ -387,7 +391,8 @@ struct Options {
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size\n"
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n";
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
out
<< "\n\nExamples:\n\n"
@ -711,7 +716,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
@ -730,7 +735,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
}
timer.stop();

View File

@ -219,14 +219,14 @@ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig<
OutputSFVectorSize,
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
>;
using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom;
using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF;
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
// Host-side allocations
std::vector<StrideA> stride_A_host;
@ -305,6 +305,7 @@ struct Options {
bool help = false;
bool verification = true;
bool use_pdl = false;
float alpha = FLT_MAX;
float beta = FLT_MAX;
@ -328,9 +329,12 @@ struct Options {
help = true;
return;
}
if (cmd.check_cmd_line_flag("no-verif")) {
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
if (cmd.check_cmd_line_flag("use_pdl")) {
use_pdl = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
@ -457,7 +461,8 @@ struct Options {
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size\n"
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
<< " --no-verif Do not run (host-side) verification kernels\n";
<< " --no_verif Do not run (host-side) verification kernels\n"
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
out
<< "\n\nExamples:\n\n"
@ -554,9 +559,9 @@ void allocate(const Options &options) {
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
stride_A_host.push_back(stride_A);
stride_B_host.push_back(stride_B);
@ -775,9 +780,9 @@ bool verify(const Options &options) {
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
@ -845,7 +850,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
cudaDeviceSynchronize();
@ -870,7 +875,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
}
timer.stop();

View File

@ -21,3 +21,35 @@ To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easies
The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements.
It is well-suited for applying masks or activations.
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.
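To make the shape of such a fusion concrete, below is a minimal sketch of a causal mask applied to the accumulator of the first GEMM. The struct name, the `apply_mask` signature, and the index accessors are illustrative assumptions rather than the exact types defined in `collective/fmha_fusion.hpp`; consult that header for the real interface.

```c++
// Minimal illustrative sketch of a causal-mask fusion. The struct name and the
// apply_mask signature are assumptions for illustration only.
#include <cmath>                 // for INFINITY
#include "cutlass/cutlass.h"     // for CUTLASS_DEVICE
#include "cute/tensor.hpp"       // for cute::size / cute::get

struct CausalMaskSketch {
  template <class AccTensor, class IndexTensor>
  CUTLASS_DEVICE void apply_mask(AccTensor& acc, IndexTensor const& index) const {
    // acc(i) holds one S = Q*K^T accumulator element; index(i) gives its
    // logical (query, key) position in the problem.
    for (int i = 0; i < cute::size(acc); ++i) {
      auto q_idx = cute::get<0>(index(i));
      auto k_idx = cute::get<1>(index(i));
      // Mask out future keys so that softmax assigns them zero probability.
      if (k_idx > q_idx) {
        acc(i) = -INFINITY;
      }
    }
  }
};
```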
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -0,0 +1,546 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale).
NVFP4 MMA has 2x throughput compared to MXFP8 MMA and 4x throughput compared to Ada Tensor Core FP8 MMA.
(See https://docs.nvidia.com/cuda/parallel-thread-execution).
This kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Tensor Core MMA Instructions
4. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79a_blackwell_geforce_nvfp4_bf16_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.8 or higher to run this example,
// and the GPU must have compute capability 12.0.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,593 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture.
The kernel outputs quantized fp4 values with scale factors that will be the input of another GEMM.
This kernel is optimized for the GeForce RTX 50 series GPUs.
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Tensor Core MMA Instructions
4. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFD matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
constexpr int InputSFVectorSize = 16;
constexpr int OutputSFVectorSize = InputSFVectorSize;
// D = alpha * acc + beta * C
// With BlockScaleFactor generation.
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
OutputSFVectorSize,
ElementD,
ElementCompute,
ElementSFD, LayoutSFDTag,
ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
FusionOperation
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong // Ping-pong kernel schedule policy.
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using FusionOp = typename Gemm::EpilogueOutputOp;
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutSFD layout_SFD;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
// Matrix-wide normalization constant
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79b_blackwell_geforce_nvfp4_nvfp4_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// For SFD tensor layout
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_Normconst.reset(cutlass::make_Coord(1));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_Normconst.at(cutlass::make_Coord(0)) = 2;
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
block_SFD.sync_device();
block_Normconst.sync_device();
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
if constexpr (IsBlockScaleSupported) {
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
}
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
auto tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D), // TensorD
decltype(tensor_SFD), // TensorSfD
cute::Int<OutputSFVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.8 or higher to run this example,
// and the GPU must have compute capability 12.0.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,546 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale).
MXFP8 MMA has 2x throughput compared to Ada Tensor Core FP8 MMA.
(See https://docs.nvidia.com/cuda/parallel-thread-execution).
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Tensor Core MMA Instructions
4. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float6_t<cutlass::float_e3m2_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout, so we use a Layout instead of a Stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout, so we use a Layout instead of a Stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating through the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79c_blackwell_geforce_mixed_mxfp8_bf16_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
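// A quick worked example of the FLOP accounting above (an illustrative aside, not part of the
// original example): the default 1024x1024x1024 problem performs
//   2 * 1024 * 1024 * 1024 = 2,147,483,648 FLOP ~= 2.147 GFLOP,
// so an average runtime of 1 ms corresponds to roughly 2147 GFLOP/s.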
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
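// For the element types used in this example, the branches above resolve to the following fill
// ranges (listed here only for orientation; assumes the MX scale-factor type is float_ue8m0_t):
//   ElementA::DataType (float_e4m3_t, 8-bit)        -> [-1, 1]
//   ElementB::DataType (float_e3m2_t, 6-bit)        -> [-2, 2]
//   ElementA/B::ScaleFactorType (float_ue8m0_t)     -> [ 1, 4]
//   ElementC (bfloat16_t, 16-bit)                   -> [-4, 4]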
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
}
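// Note on scale-factor sizing (an illustrative aside, assuming the standard MX block size of
// 32 elements along K): A of shape MxK carries on the order of M * K/32 scale factors and B
// carries on the order of N * K/32. The exact counts and their interleaving are dictated by the
// Sm1xxBlkScaledConfig layouts constructed above, which is why the allocations use
// size(filter_zeros(layout_SFA)) and size(filter_zeros(layout_SFB)) rather than a hand-derived formula.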
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
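  // The two norm checks above guard against a trivially "passing" comparison in which both the
  // kernel output and the reference output are all zeros.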
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
  // and must run on a GPU with compute capability 120.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
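// Typical build-and-run flow for this example (a sketch; target and binary names follow the
// CMake file below and assume a build tree configured for SM120a):
//   $ cmake .. -DCUTLASS_NVCC_ARCHS=120a
//   $ make 79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm
//   $ ./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm --m=2048 --n=2048 --k=2048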

View File

@ -0,0 +1,47 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
cutlass_example_add_executable(
79a_blackwell_geforce_nvfp4_bf16_gemm
79a_blackwell_geforce_nvfp4_bf16_gemm.cu
)
cutlass_example_add_executable(
79b_blackwell_geforce_nvfp4_nvfp4_gemm
79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu
)
cutlass_example_add_executable(
79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm
79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu
)
endif()

View File

@ -216,7 +216,7 @@ struct Options {
out
<< "\n\nExamples:\n\n"
<< "$ " << "81_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
<< "$ " << "112_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}

View File

@ -157,6 +157,7 @@ foreach(EXAMPLE
76_blackwell_conv
77_blackwell_fmha
78_blackwell_emulated_bf16x9_gemm
79_blackwell_geforce_gemm
81_blackwell_gemm_blockwise
)

View File

@ -282,6 +282,10 @@
Blackwell SM100 FastFP32 (using BF16 to emulate SGEMM) kernel
* [79_blackwell_geforce_gemm](79_blackwell_geforce_gemm/)
Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores
# CuTe - Programming Examples
Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).
@ -291,3 +295,35 @@ Additionally, CuTe's core layout and layout algebra have their own test cases wi
# Python Interface Examples
Examples leveraging CUTLASS's [Python interface](../python/README.md) are located in [cutlass/examples/python](python/).
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -58,7 +58,7 @@ struct IndexedGather
operator()(I i) const { return indices_[i]; }
CUTE_HOST_DEVICE friend
void
void
print(IndexedGather const &s) {
cute::print("Indexed");
}
@ -80,7 +80,7 @@ struct StridedGather
operator()(I i) const { return i * stride_; }
CUTE_HOST_DEVICE friend
void
void
print(StridedGather const &s) {
cute::print("Strided{");
print(s.stride_);
@ -153,7 +153,7 @@ make_custom_stride_layout(Stride const &stride, Func&& func)
/// Helper function to optionally create a gather tensor
template<class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE
auto
auto
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
{
if constexpr (not cutlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {
@ -180,7 +180,7 @@ upcast(Shape const& shape, Stride const& stride)
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N,I>(s,d); });
} else if constexpr (is_scaled_basis<Stride>::value) {
if constexpr (Stride::mode() == I) {
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
return make_layout(ceil_div(shape, Int<N>{}), ceil_div(stride, Int<N>{}));
} else {
return make_layout(shape, stride);
}

View File

@ -27,34 +27,31 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
add_subdirectory(hopper)
add_subdirectory(blackwell)
cutlass_example_add_executable(
sgemm_1
cute_tutorial_sgemm_1
sgemm_1.cu
)
cutlass_example_add_executable(
sgemm_2
cute_tutorial_sgemm_2
sgemm_2.cu
)
cutlass_example_add_executable(
sgemm_sm70
cute_tutorial_sgemm_sm70
sgemm_sm70.cu
)
cutlass_example_add_executable(
sgemm_sm80
cute_tutorial_sgemm_sm80
sgemm_sm80.cu
)
cutlass_example_add_executable(
tiled_copy
cute_tutorial_tiled_copy
tiled_copy.cu
)
cutlass_example_add_executable(
wgmma_sm90
wgmma_sm90.cu
)

View File

@ -0,0 +1,592 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// CuTe Tutorial for SM100 Programming
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
//
// The tutorial series is split into five stages:
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <cstdio>
// Use Thrust to handle host/device allocations
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// Cutlass includes
#include <cutlass/half.h> // F16 data type
#include <cutlass/util/print_error.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
#include <cute/numeric/integral_constant.hpp> // Compile-time integral constants such as _1, _256, etc.
#include <cute/algorithm/cooperative_copy.hpp>
// Tutorial helpers
#include "example_utils.hpp"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tutorial 01: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction
//
///////////////////////////////////////////////////////////////////////////////////////////////////
// The goal of this tutorial is to show the CuTe interface for tcgen05.mma and tcgen05.ld operations.
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
// - Matrices C and D are MxN, N-major (BLAS row-major)
//
// This GEMM kernel performs the following steps:
// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) for one MmaTile
// using auto-vectorizing copy operations.
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
//
// SM100 tcgen05.mma instructions operate as follows:
// - Read matrix A from SMEM or TMEM
// - Read matrix B from SMEM
// - Write accumulator to TMEM
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
//
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
// and the MMA's M and N dimensions.
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
//
// The MMA details:
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
// This example uses F16xF16 = F32 MMA where:
// TypeA = cutlass::half_t; // MMA A Data Type
// TypeB = cutlass::half_t; // MMA B Data Type
// TypeC = float; // MMA C Data Type
// TypeD = float; // MMA D Data Type
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// The shared memory buffers for A and B matrices.
template <class TypeA, // Tensor A data type
class TypeB, // Tensor B data type
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
struct SharedStorage
{
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
};
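// A rough shared-memory budget for the layouts used in this tutorial (illustrative arithmetic
// only): sA is ((_128,_16),_1,_4) halfs = 128*16*4*2 B = 16 KiB, sB is ((_256,_16),_1,_4) halfs
// = 256*16*4*2 B = 32 KiB, plus the 8-byte barrier and alignment padding, i.e. roughly 48 KiB
// in total, well under the dynamic SMEM capacity of an SM100 SM.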
// The device kernel
template <class SharedStorage,
class ATensor, class BTensor, class CTensor, class DTensor,
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
class Alpha, class Beta>
__global__ static
void
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
BTensor mB, // (Gemm_N, Gemm_K)
CTensor mC, // (Gemm_M, Gemm_N)
DTensor mD, // (Gemm_M, Gemm_N)
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
Alpha alpha, Beta beta)
{
// Step 1: The Prologue.
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename TiledMMA::AtomThrID{}));
// Construct the MMA grid coordinate from the CTA grid coordinate
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
blockIdx.y, // MMA-N coordinate
_); // MMA-K coordinate
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
// by this mma tile.
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
// * Tensor to partition
// * Tiler to use for partitioning
// * Coordinate to use for slicing the partitioned tensor
// * Projection to ignore unwanted modes of the Tiler and Coordinate
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
if (thread0()) {
print("mA:\t"); print(mA); print("\n"); // mA: gmem_ptr[16b](GMEM_ADDR_A) o (512,256):(256,_1)
print("mB:\t"); print(mB); print("\n"); // mB: gmem_ptr[16b](GMEM_ADDR_B) o (1024,256):(256,_1)
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
print("gA:\t"); print(gA); print("\n"); // gA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile) o (_128,_64,4):(256,_1,_64)
print("gB:\t"); print(gB); print("\n"); // gB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile) o (_256,_64,4):(_1,256,16384)
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
} __syncthreads();
// The SMEM tensors
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
  Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
//
// Mma partitioning for A and B
//
// Note: Partitioned tensors use tXgY naming convention:
// tXgY -> The partitioning pattern tX applied to tensor gY
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile + offset_for_mma) o ((_128,_16),_1,_4,4):((256,_1),_0,_16,_64)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile + offset_for_mma) o ((_256,_16),_1,_4,4):((_1,256),_0,4096,16384)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
// MMA Fragment Allocation
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
// For tcgen05.mma operations:
// - Matrices A and B are sourced from SMEM
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
  Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
} __syncthreads();
// Barrier Initialization
uint32_t elect_one_thr = cute::elect_one_sync();
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
// Barriers in SMEM initialized by a single thread.
if (elect_one_warp && elect_one_thr) {
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1);
}
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
__syncthreads(); // Make sure all threads observe barrier initialization.
// Step 2: The Mainloop.
  // Set the MMA accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
{
// Step 2a: Load A and B tiles
// Using auto-vectorized copy operation:
// - Utilizes 128 threads for parallel data transfer
// - Copy operations are distributed efficiently across all threads
// - CuTe can automatically determine optimal vector width
cooperative_copy<128>(threadIdx.x, tCgA(_,_,_,k_tile), tCsA); // Load MmaTile_M x MmaTile_K A tile
cooperative_copy<128>(threadIdx.x, tCgB(_,_,_,k_tile), tCsB); // Load MmaTile_N x MmaTile_K B tile
// Step 2b: Execute the MMAs for this tile
// Wait for loads to SMEM to complete with __syncthreads()
__syncthreads();
// tcgen05.mma instructions require single-thread execution:
// - Only one warp performs the MMA-related loop operations
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
// - No explicit elect_one_sync region is needed from the user
if (elect_one_warp) {
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
      // Ensure MMAs are completed; only then can we reuse the A and B SMEM.
cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
}
    // Wait for MMAs to complete to avoid overwriting the A and B SMEM.
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue.
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
// Load C tensor GMEM -> RMEM
copy(tDgC, tDrC);
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
using AccType = typename decltype(tCtAcc)::value_type;
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
// Load TMEM -> RMEM
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
axpby(alpha, tDrAcc, beta, tDrC);
// Store RMEM -> GMEM
copy(tDrC, tDgD);
}
template <class TypeA, class LayoutA,
class TypeB, class LayoutB,
class TypeC, class LayoutC,
class TypeD, class LayoutD,
class Alpha, class Beta>
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
TypeB const* device_ptr_B, LayoutB layout_B,
TypeC const* device_ptr_C, LayoutC layout_C,
TypeD * device_ptr_D, LayoutD layout_D,
Alpha const alpha, Beta const beta)
{
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
// Represent the full tensors in global memory
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
// Get M, N, K dimensions of the GEMM we are running
auto Gemm_M = shape<0>(layout_A);
auto Gemm_N = shape<0>(layout_B);
auto Gemm_K = shape<1>(layout_A);
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
////////////////////////////////////////////////////////////
//
// Initialize the GEMM kernel parameters
//
////////////////////////////////////////////////////////////
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
// larger TiledMma from the given mma instruction.
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
128, 256, // Mma M and N dimensions
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
// We can also print and inspect the tiled_mma
print(tiled_mma);
// TiledMMA
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
// PermutationMNK: (_,_,_)
// MMA_Atom
// ThrID: _1:_0
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
// Define MMA tiler sizes (static)
  auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
  auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
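  // For the 128x256x16 instruction chosen above, this evaluates to mma_tiler = (_128,_256,_64),
  // which matches the (_128,_64) A tiles and (_256,_64) B tiles seen in the device-side prints.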
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
  // The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
return;
}
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
return;
}
//
// Determine the SMEM layouts:
//
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
  //   These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
  //   where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
  //   times the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts is very hard.
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
// The cluster shape and layout
auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
////////////////////////////////////////////////////////////
//
// Launch GEMM kernel
//
////////////////////////////////////////////////////////////
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
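  // For the default 512x1024x256 problem this yields dimGrid = (4, 4, 1):
  // ceil_div(512, 128) = 4 and ceil_div(1024, 256) = 4, each already a multiple of the 1x1x1 cluster.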
int smemBytes = sizeof(SMEMStorage);
auto* kernel_ptr = &gemm_device<SMEMStorage,
decltype(mA), decltype(mB), decltype(mC), decltype(mD),
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
Alpha, Beta>;
// Set kernel attributes (set SMEM)
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
mA, mB, mC, mD,
mma_tiler, tiled_mma, cluster_shape,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
  cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if ((props.major != 10) || (props.major == 10 && props.minor > 1)) {
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int Gemm_M = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &Gemm_M);
int Gemm_N = 1024;
if (argc >= 3)
sscanf(argv[2], "%d", &Gemm_N);
int Gemm_K = 256;
if (argc >= 4)
sscanf(argv[3], "%d", &Gemm_K);
////////////////////////////////////////////////////////////
//
// Create A, B, C, and D tensors
//
////////////////////////////////////////////////////////////
// Define the data types. A and B types are same for MMA instruction.
using TypeA = cutlass::half_t; // MMA A Data Type
auto type_str_a = "half_t";
using TypeB = cutlass::half_t; // MMA B Data Type
auto type_str_b = "half_t";
using TypeC = float; // MMA C Data Type
[[maybe_unused]] auto type_str_c = "float";
using TypeD = float; // MMA D Data Type
auto type_str_d = "float";
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
// A tensor MxK K-major (Layout T = Row-Major)
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
// B tensor NxK K-major (Layout N = Column-Major)
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
// C tensor MxN N-major (Layout T = Row-Major)
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// D tensor MxN N-major (Layout T = Row-Major)
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// Host allocations and host CuTe tensors for A, B, and C tensors.
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
// Note that we don't need a host_tensor for D yet.
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
// Initialize A, B, and C tensors with random values.
initialize_tensor(host_tensor_A);
initialize_tensor(host_tensor_B);
initialize_tensor(host_tensor_C);
// Copy A, B, and C tensors from host memory to device memory
thrust::device_vector<TypeA> device_A = host_A;
thrust::device_vector<TypeB> device_B = host_B;
thrust::device_vector<TypeC> device_C = host_C;
using Alpha = float;
using Beta = float;
Alpha alpha = 1.0f;
Beta beta = 0.0f;
// Setup input and output tensors, and the kernel parameters; and execute the kernel on device
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
device_B.data().get(), layout_B,
device_C.data().get(), layout_C,
device_D.data().get(), layout_D,
alpha, beta);
// Host allocation for D tensor and transfer D tensor from device to host
thrust::host_vector<TypeD> host_D = device_D;
// Create a non-owning CuTe tensor for D tensor
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
////////////////////////////////////////////////////////////
//
// Execute reference GEMM kernel
//
////////////////////////////////////////////////////////////
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
////////////////////////////////////////////////////////////
//
// Compare results
//
////////////////////////////////////////////////////////////
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
type_str_b, host_tensor_B,
type_str_d, host_tensor_D, host_reference_tensor_D);
bool success = relative_error <= 0.0;
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
#else
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}
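// Example invocation (a sketch; assumes the built executable is named after this file,
// 01_mma_sm100, and takes optional positional arguments Gemm_M Gemm_N Gemm_K):
//   $ ./01_mma_sm100 512 1024 256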

View File

@ -0,0 +1,671 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// CuTe Tutorial for SM100 Programming
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
//
// The tutorial series is split into five stages:
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <cstdio>
// Use Thrust to handle host/device allocations
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// Cutlass includes
#include <cutlass/half.h> // F16 data type
#include <cutlass/util/print_error.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of cluster launched
#include <cute/numeric/integral_constant.hpp> // Compile-time integral constants such as _1, _256, etc.
#include <cute/algorithm/cooperative_copy.hpp>
// Tutorial helpers
#include "example_utils.hpp"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tutorial 02: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
// - Matrices C and D are MxN, N-major (BLAS row-major)
//
// This GEMM kernel extends 01_mma_sm100.cu by adding Tensor Memory Access (TMA) and performs the following steps:
// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
//
// SM100 tcgen05.mma instructions operate as follows:
// - Read matrix A from SMEM or TMEM
// - Read matrix B from SMEM
// - Write accumulator to TMEM
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
//
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
// and the MMA's M and N dimensions.
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
//
// The MMA details:
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
// This example uses F16xF16 = F32 MMA where:
// TypeA = cutlass::half_t; // MMA A Data Type
// TypeB = cutlass::half_t; // MMA B Data Type
// TypeC = float; // MMA C Data Type
// TypeD = float; // MMA D Data Type
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// The shared memory buffers for A and B matrices.
template <class TypeA, // Tensor A data type
class TypeB, // Tensor B data type
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
struct SharedStorage
{
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
};
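// Compared to 01_mma_sm100.cu, the only structural change here is the extra tma_barrier:
// the TMA loads arrive on it so the MMA warp can wait for the A/B tiles to land in SMEM, while
// mma_barrier continues to track when the MMA has finished consuming those tiles.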
// The device kernel
template <class SharedStorage,
class ATensor, class BTensor, class CTensor, class DTensor,
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
class TmaAtomA, class TmaAtomB,
class Alpha, class Beta>
__global__ static
void
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
BTensor mB, // (Gemm_N, Gemm_K)
CTensor mC, // (Gemm_M, Gemm_N)
DTensor mD, // (Gemm_M, Gemm_N)
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
Alpha alpha, Beta beta)
{
// Step 1: The Prologue.
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename TiledMMA::AtomThrID{}));
// Construct the MMA grid coordinate from the CTA grid coordinate
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
blockIdx.y, // MMA-N coordinate
_); // MMA-K coordinate
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
// by this mma tile.
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
// * Tensor to partition
// * Tiler to use for partitioning
// * Coordinate to use for slicing the partitioned tensor
// * Projection to ignore unwanted modes of the Tiler and Coordinate
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
if (thread0()) {
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
} __syncthreads();
// The SMEM tensors
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
//
// Mma partitioning for A and B
//
// Note: Partitioned tensors use tXgY naming convention:
// tXgY -> The partitioning pattern tX applied to tensor gY
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
// MMA Fragment Allocation
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
// For tcgen05.mma operations:
// - Matrices A and B are sourced from SMEM
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
} __syncthreads();
// TMA Setup
//
// These are TMA partitionings, which have a dedicated custom partitioner.
// The Int<0> and Layout<_1> arguments indicate that the TMAs are not multicast.
// Any multicasting must conform to the tma_atom constructed with make_tma_atom on the host.
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
Int<0>{}, Layout<_1>{},
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
Int<0>{}, Layout<_1>{},
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
// Calculate total bytes that TMA will transfer each tile to track completion
int tma_transaction_bytes = sizeof(make_tensor_like(tAsA))
+ sizeof(make_tensor_like(tBsB));
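// Worked example (assuming the default 128x256x64 MMA tile and 16-bit A/B used in this tutorial):
// A tile: 128 * 64 * 2 bytes = 16384 bytes, B tile: 256 * 64 * 2 bytes = 32768 bytes,
// so tma_transaction_bytes = 16384 + 32768 = 49152 bytes per k_tile.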
if (thread0()) {
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
printf("TmaBytes: %d\n", tma_transaction_bytes);
} __syncthreads();
// Barrier Initialization
uint32_t elect_one_thr = cute::elect_one_sync();
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
// Barriers in SMEM initialized by a single thread.
if (elect_one_warp && elect_one_thr) {
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1);
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
}
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
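// Note: an mbarrier alternates between two phases; wait_barrier(bar, phase_bit) returns once the phase
// with that parity has completed, so the code flips its local phase bit after every wait to keep
// tracking the barrier across loop iterations.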
__syncthreads(); // Make sure all threads observe barrier initialization.
// Step 2: The Mainloop.
// Set the MMA accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
{
// Step 2a: Load A and B tiles
// TMA Load Operations:
// - Execute asynchronous TMA loads with single thread
// - Set transaction bytes and execute with barrier
if (elect_one_warp && elect_one_thr) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes);
copy(tma_atom_A.with(shared_storage.tma_barrier), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
copy(tma_atom_B.with(shared_storage.tma_barrier), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
}
// Step 2b: Execute the MMAs for this tile
// Wait for TMA loads to SMEM to complete
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
tma_barrier_phase_bit ^= 1;
// tcgen05.mma instructions require single-thread execution:
// - Only one warp performs the MMA-related loop operations
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
// - No explicit elect_one_sync region is needed from the user
if (elect_one_warp) {
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
// Ensure MMAs are completed; only then can we reuse the A and B SMEM.
cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
}
// Wait for MMAs to complete to avoid overwriting the A and B SMEM.
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue.
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
// Load C tensor GMEM -> RMEM
copy(tDgC, tDrC);
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
using AccType = typename decltype(tCtAcc)::value_type;
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
// Load TMEM -> RMEM
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
axpby(alpha, tDrAcc, beta, tDrC);
// Store RMEM -> GMEM
copy(tDrC, tDgD);
}
template <class TypeA, class LayoutA,
class TypeB, class LayoutB,
class TypeC, class LayoutC,
class TypeD, class LayoutD,
class Alpha, class Beta>
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
TypeB const* device_ptr_B, LayoutB layout_B,
TypeC const* device_ptr_C, LayoutC layout_C,
TypeD * device_ptr_D, LayoutD layout_D,
Alpha const alpha, Beta const beta)
{
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
// Represent the full tensors in global memory
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
// Get M, N, K dimensions of the GEMM we are running
auto Gemm_M = shape<0>(layout_A);
auto Gemm_N = shape<0>(layout_B);
auto Gemm_K = shape<1>(layout_A);
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
////////////////////////////////////////////////////////////
//
// Initialize the GEMM kernel parameters
//
////////////////////////////////////////////////////////////
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
// larger TiledMma from the given mma instruction.
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
128, 256, // Mma M and N dimensions
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
// We can also print and inspect the tiled_mma
print(tiled_mma);
// TiledMMA
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
// PermutationMNK: (_,_,_)
// MMA_Atom
// ThrID: _1:_0
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
// Define MMA tiler sizes (static)
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K=16.
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
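// Worked example: with the 128x256x16 instruction selected above, bM = 128, bN = 256, and
// bK = 16 * 4 = 64, so mma_tiler evaluates to (_128,_256,_64).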
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
return;
}
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
return;
}
//
// Determine the SMEM layouts:
//
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
// times the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts is very hard.
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
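// Note: Layout_K_SW128_Atom selects the 128-byte swizzled, K-major SMEM atom (printed below as Sw<3,4,3>),
// one of the swizzle modes that the tcgen05.mma SMEM descriptors accept.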
// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
//
// TMA Descriptor Creation (Host Side)
//
// The cluster shape and layout
auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
// Create TMA descriptors for A and B matrices
Copy_Atom tma_atom_A = make_tma_atom(
SM90_TMA_LOAD{}, // TMA Load Op
mA, // Source GMEM tensor
sA_layout, // Destination SMEM layout
select<0,2>(mma_tiler) // MK Tiler for TMA operation
);
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
// tma_atom_A: Copy_Atom
// ThrID: _1:_0
// ValLayoutSrc: (_1,_8192):(_0,_1)
// ValLayoutDst: (_1,_8192):(_0,_1)
// ValLayoutRef: (_1,_8192):(_0,_1)
// ValueType: 16b
Copy_Atom tma_atom_B = make_tma_atom(
SM90_TMA_LOAD{}, // TMA Load Op
mB, // Source GMEM tensor
sB_layout, // Destination SMEM layout
select<1,2>(mma_tiler) // NK Tiler for TMA operation
);
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
// tma_atom_B: Copy_Atom
// ThrID: _1:_0
// ValLayoutSrc: (_1,_16384):(_0,_1)
// ValLayoutDst: (_1,_16384):(_0,_1)
// ValLayoutRef: (_1,_16384):(_0,_1)
// ValueType: 16b
////////////////////////////////////////////////////////////
//
// Launch GEMM kernel
//
////////////////////////////////////////////////////////////
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
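// Worked example for the default 512x1024x256 problem: ceil_div(512,128) = 4 and ceil_div(1024,256) = 4;
// with a 1x1x1 cluster no rounding is needed, so dimGrid = (4, 4, 1).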
int smemBytes = sizeof(SMEMStorage);
auto* kernel_ptr = &gemm_device<SMEMStorage,
decltype(mA_tma), decltype(mB_tma), decltype(mC), decltype(mD),
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
decltype(tma_atom_A), decltype(tma_atom_B), // Includes the TMA descriptor.
Alpha, Beta>;
// Set kernel attributes (set SMEM)
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
mA_tma, mB_tma, mC, mD,
mma_tiler, tiled_mma, cluster_shape,
tma_atom_A, tma_atom_B,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major != 10 || props.minor > 1) {
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int Gemm_M = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &Gemm_M);
int Gemm_N = 1024;
if (argc >= 3)
sscanf(argv[2], "%d", &Gemm_N);
int Gemm_K = 256;
if (argc >= 4)
sscanf(argv[3], "%d", &Gemm_K);
////////////////////////////////////////////////////////////
//
// Create A, B, C, and D tensors
//
////////////////////////////////////////////////////////////
// Define the data types. A and B types are same for MMA instruction.
using TypeA = cutlass::half_t; // MMA A Data Type
auto type_str_a = "half_t";
using TypeB = cutlass::half_t; // MMA B Data Type
auto type_str_b = "half_t";
using TypeC = float; // MMA C Data Type
[[maybe_unused]] auto type_str_c = "float";
using TypeD = float; // MMA D Data Type
auto type_str_d = "float";
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
// A tensor MxK K-major (Layout T = Row-Major)
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
// B tensor NxK K-major (Layout N = Column-Major)
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
// C tensor MxN N-major (Layout T = Row-Major)
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// D tensor MxN N-major (Layout T = Row-Major)
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// Host allocations and host CuTe tensors for A, B, and C tensors.
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
// Note that we don't need a host_tensor for D yet.
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
// Initialize A, B, and C tensors with random values.
initialize_tensor(host_tensor_A);
initialize_tensor(host_tensor_B);
initialize_tensor(host_tensor_C);
// Copy A, B, and C tensors from host memory to device memory
thrust::device_vector<TypeA> device_A = host_A;
thrust::device_vector<TypeB> device_B = host_B;
thrust::device_vector<TypeC> device_C = host_C;
using Alpha = float;
using Beta = float;
Alpha alpha = 1.0f;
Beta beta = 0.0f;
// Setup input and output tensors, and the kernel parameters; and execute the kernel on device
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
device_B.data().get(), layout_B,
device_C.data().get(), layout_C,
device_D.data().get(), layout_D,
alpha, beta);
// Host allocation for D tensor and transfer D tensor from device to host
thrust::host_vector<TypeD> host_D = device_D;
// Create a non-owning CuTe tensor for D tensor
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
////////////////////////////////////////////////////////////
//
// Execute reference GEMM kernel
//
////////////////////////////////////////////////////////////
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
////////////////////////////////////////////////////////////
//
// Compare results
//
////////////////////////////////////////////////////////////
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
type_str_b, host_tensor_B,
type_str_d, host_tensor_D, host_reference_tensor_D);
bool success = relative_error <= 0.0;
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
#else
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}

View File

@ -0,0 +1,711 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// CuTe Tutorial for SM100 Programming
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
//
// The tutorial series is split into five stages:
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <cstdio>
// Use Thrust to handle host/device allocations
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// Cutlass includes
#include <cutlass/half.h> // F16 data type
#include <cutlass/util/print_error.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying the details of the launched cluster
#include <cute/numeric/integral_constant.hpp> // Compile-time integer constants such as _1, _256, etc.
#include <cute/algorithm/cooperative_copy.hpp>
// Tutorial helpers
#include "example_utils.hpp"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tutorial 03: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA
//
///////////////////////////////////////////////////////////////////////////////////////////////////
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
// - Matrices C and D are MxN, N-major (BLAS row-major)
//
// Key extensions from tutorial 02_mma_tma_sm100.cu:
// 1. Introduce ClusterShape for coordinated execution across thread blocks
// 2. Introduce TMA multicast
// 3. Enhanced TMA <-> MMA synchronization for cluster-wide operations
//
// This GEMM kernel will perform the following steps:
// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA load operations.
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
//
// SM100 tcgen05.mma instructions operate as follows:
// - Read matrix A from SMEM or TMEM
// - Read matrix B from SMEM
// - Write accumulator to TMEM
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
//
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
// and the MMA's M and N dimensions.
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
//
// The MMA details:
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
// This example uses F16xF16 = F32 MMA where:
// TypeA = cutlass::half_t; // MMA A Data Type
// TypeB = cutlass::half_t; // MMA B Data Type
// TypeC = float; // MMA C Data Type
// TypeD = float; // MMA D Data Type
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// The shared memory buffers for A and B matrices.
template <class TypeA, // Tensor A data type
class TypeB, // Tensor B data type
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
struct SharedStorage
{
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
};
// The device kernel
template <class SharedStorage,
class ATensor, class BTensor, class CTensor, class DTensor,
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
class TmaAtomA, class TmaAtomB,
class Alpha, class Beta>
__global__ static
void
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
BTensor mB, // (Gemm_N, Gemm_K)
CTensor mC, // (Gemm_M, Gemm_N)
DTensor mD, // (Gemm_M, Gemm_N)
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
Alpha alpha, Beta beta)
{
// Step 1: The Prologue.
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename TiledMMA::AtomThrID{}));
// Construct the MMA grid coordinate from the CTA grid coordinate
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
blockIdx.y, // MMA-N coordinate
_); // MMA-K coordinate
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
// by this mma tile.
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
// * Tensor to partition
// * Tiler to use for partitioning
// * Coordinate to use for slicing the partitioned tensor
// * Projection to ignore unwanted modes of the Tiler and Coordinate
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
if (thread0()) {
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
} __syncthreads();
// The SMEM tensors
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
//
// Mma partitioning for A and B
//
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
// MMA Fragment Allocation
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
// For tcgen05.mma operations:
// - Matrices A and B are sourced from SMEM
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
} __syncthreads();
// TMA Setup
//
// These are TMA partitionings, which have a dedicated custom partitioner.
// In this example, the TMA multicasts the loads across multiple CTAs.
// Loads of A are multicasted along the N dimension of the cluster_shape_MNK and
// Loads of B are multicasted along the M dimension of the cluster_shape_MNK.
// Any multicasting must conform to the tma_atom constructed with make_tma_atom on the host.
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
// Each CTA with the same m-coord will load a portion of A
// Each CTA with the same n-coord will load a portion of B
// Multicast behavior for CTA 1,2 in the cluster
// A multicast B multicast
// 0 1 2 3 0 1 2 3
// 0 - - - - 0 - - X -
// 1 X X X X 1 - - X -
// 2 - - - - 2 - - X -
// 3 - - - - 3 - - X -
// tma_multicast_mask_A = 0x2222
// tma_multicast_mask_B = 0x0F00
// mma_multicast_mask_C = 0x2F22
// Construct the CTA-in-Cluster coordinate for multicasting
auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster()));
// Project the cluster_layout for tma_A along the N-modes
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster
make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
// Project the cluster_layout for tma_B along the M-modes
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster
make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
// Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C
uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) |
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
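// With the 4x4x1 cluster configured on the host side, these masks follow the pattern illustrated above;
// e.g. CTA (1,2) computes tma_mcast_mask_a = 0x2222, tma_mcast_mask_b = 0x0F00, mma_mcast_mask_c = 0x2F22.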
// Calculate total bytes that TMA will transfer each tile to track completion
int tma_transaction_bytes = sizeof(make_tensor_like(tAsA))
+ sizeof(make_tensor_like(tBsB));
if (thread0()) {
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c);
} __syncthreads();
// Barrier Initialization
uint32_t elect_one_thr = cute::elect_one_sync();
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
// Barriers in SMEM initialized by a single thread.
if (elect_one_warp && elect_one_thr) {
// The number of CTAs that participate in the multicast operation with this CTA (for both A and B matrices)
int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1;
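// E.g. for this tutorial's 4x4x1 cluster with a 1SM MMA this is 4 + 4 - 1 = 7:
// the CTAs sharing this CTA's M or N cluster coordinate, with this CTA counted once.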
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants);
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
}
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
cute::cluster_sync(); // Make sure all threads across all CTAs in Cluster observe barrier initialization.
// Step 2: The Mainloop.
// Set the MMA accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
{
// Step 2a: Load A and B tiles
// TMA Load Operations:
// - Execute asynchronous TMA loads with single thread
// - Set transaction bytes and execute with barrier
if (elect_one_warp && elect_one_thr) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes);
copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
}
// Step 2b: Execute the MMAs for this tile
// Wait for TMA loads to SMEM to complete
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
tma_barrier_phase_bit ^= 1;
// tcgen05.mma instructions require single-thread execution:
// - Only one warp performs the MMA-related loop operations
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
// - No explicit elect_one_sync region is needed from the user
if (elect_one_warp) {
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
// Ensure MMAs are completed; only then can we reuse the A and B SMEM.
cutlass::arch::umma_arrive_multicast(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask.
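// Because multicast TMA loads issued by peer CTAs write into this CTA's SMEM (and vice versa), the MMA
// completion is signaled to every CTA in mma_mcast_mask_c so all of them know when the buffers they
// multicast into may be refilled.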
}
// Wait for MMAs to complete to avoid overwriting the A and B SMEM.
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue.
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
// Load C tensor GMEM -> RMEM
copy(tDgC, tDrC);
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
using AccType = typename decltype(tCtAcc)::value_type;
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
// Load TMEM -> RMEM
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
axpby(alpha, tDrAcc, beta, tDrC);
// Store RMEM -> GMEM
copy(tDrC, tDgD);
}
template <class TypeA, class LayoutA,
class TypeB, class LayoutB,
class TypeC, class LayoutC,
class TypeD, class LayoutD,
class Alpha, class Beta>
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
TypeB const* device_ptr_B, LayoutB layout_B,
TypeC const* device_ptr_C, LayoutC layout_C,
TypeD * device_ptr_D, LayoutD layout_D,
Alpha const alpha, Beta const beta)
{
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
// Represent the full tensors in global memory
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
// Get M, N, K dimensions of the GEMM we are running
auto Gemm_M = shape<0>(layout_A);
auto Gemm_N = shape<0>(layout_B);
auto Gemm_K = shape<1>(layout_A);
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
////////////////////////////////////////////////////////////
//
// Initialize the GEMM kernel parameters
//
////////////////////////////////////////////////////////////
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
// larger TiledMma from the given mma instruction.
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
128, 256, // Mma M and N dimensions
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
// We can also print and inspect the tiled_mma
print(tiled_mma);
// TiledMMA
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
// PermutationMNK: (_,_,_)
// MMA_Atom
// ThrID: _1:_0
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
// Define MMA tiler sizes (static)
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K=16.
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
return;
}
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
return;
}
//
// Determine the SMEM layouts:
//
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
// times the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts is very hard.
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
//
// TMA Descriptor Creation (Host Side)
//
// The cluster shape and layout
auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{});
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
Copy_Atom tma_atom_A = make_tma_atom(
SM90_TMA_LOAD_MULTICAST{}, // TMA load operation with multicast
mA, // Source GMEM tensor
sA_layout, // Destination SMEM layout
select<0,2>(mma_tiler), // MK Tiler for TMA operation
size<2>(cluster_layout_vmnk) // The number of CTAs in the N-mode for multicasting
);
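// With a multicast factor of size<2>(cluster_layout_vmnk) = 4, each CTA in a cluster row issues a TMA load
// for 1/4 of the A tile and the hardware broadcasts that slice into the SMEM of all 4 CTAs sharing the same
// M coordinate, so every CTA ends up holding the full MmaTile_M x MmaTile_K tile (sketch of the intended
// multicast pattern for this tutorial's 4x4x1 cluster).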
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
// tma_atom_A: Copy_Atom
// ThrID: _1:_0
// ValLayoutSrc: (_1,_8192):(_0,_1)
// ValLayoutDst: (_1,_8192):(_0,_1)
// ValLayoutRef: (_1,_8192):(_0,_1)
// ValueType: 16b
Copy_Atom tma_atom_B = make_tma_atom(
SM90_TMA_LOAD_MULTICAST{}, // TMA load operation with multicast
mB, // Source GMEM tensor
sB_layout, // Destination SMEM layout
select<1,2>(mma_tiler), // NK Tiler for TMA operation
size<1>(cluster_layout_vmnk) // The number of CTAs in the M-mode for multicasting
);
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
// tma_atom_B: Copy_Atom
// ThrID: _1:_0
// ValLayoutSrc: (_1,_16384):(_0,_1)
// ValLayoutDst: (_1,_16384):(_0,_1)
// ValLayoutRef: (_1,_16384):(_0,_1)
// ValueType: 16b
////////////////////////////////////////////////////////////
//
// Launch GEMM kernel
//
////////////////////////////////////////////////////////////
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
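// Worked example for the default 512x1024x256 problem: ceil_div(512,128) = 4 and ceil_div(1024,256) = 4;
// both already match the 4x4 cluster dimensions, so dimGrid = (4, 4, 1) and the launch is a single
// 4x4 cluster of 16 CTAs.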
int smemBytes = sizeof(SMEMStorage);
auto* kernel_ptr = &gemm_device<SMEMStorage,
decltype(mA_tma), decltype(mB_tma), decltype(mC), decltype(mD),
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
decltype(tma_atom_A), decltype(tma_atom_B), // Includes the TMA descriptor.
Alpha, Beta>;
// Set kernel attributes (set SMEM)
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
mA_tma, mB_tma, mC, mD,
mma_tiler, tiled_mma, cluster_shape,
tma_atom_A, tma_atom_B,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major != 10 || props.minor > 1) {
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int Gemm_M = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &Gemm_M);
int Gemm_N = 1024;
if (argc >= 3)
sscanf(argv[2], "%d", &Gemm_N);
int Gemm_K = 256;
if (argc >= 4)
sscanf(argv[3], "%d", &Gemm_K);
////////////////////////////////////////////////////////////
//
// Create A, B, C, and D tensors
//
////////////////////////////////////////////////////////////
// Define the data types. A and B types are same for MMA instruction.
using TypeA = cutlass::half_t; // MMA A Data Type
auto type_str_a = "half_t";
using TypeB = cutlass::half_t; // MMA B Data Type
auto type_str_b = "half_t";
using TypeC = float; // MMA C Data Type
[[maybe_unused]] auto type_str_c = "float";
using TypeD = float; // MMA D Data Type
auto type_str_d = "float";
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
// A tensor MxK K-major (Layout T = Row-Major)
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
// B tensor NxK K-major (Layout N = Column-Major)
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
// C tensor MxN N-major (Layout T = Row-Major)
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// D tensor MxN N-major (Layout T = Row-Major)
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// Host allocations and host CuTe tensors for A, B, and C tensors.
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
// Note that we don't need a host_tensor for D yet.
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
// Initialize A, B, and C tensors with random values.
initialize_tensor(host_tensor_A);
initialize_tensor(host_tensor_B);
initialize_tensor(host_tensor_C);
// Copy A, B, and C tensors from host memory to device memory
thrust::device_vector<TypeA> device_A = host_A;
thrust::device_vector<TypeB> device_B = host_B;
thrust::device_vector<TypeC> device_C = host_C;
using Alpha = float;
using Beta = float;
Alpha alpha = 1.0f;
Beta beta = 0.0f;
// Setup input and output tensors, and the kernel parameters; and execute the kernel on device
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
device_B.data().get(), layout_B,
device_C.data().get(), layout_C,
device_D.data().get(), layout_D,
alpha, beta);
// Host allocation for D tensor and transfer D tensor from device to host
thrust::host_vector<TypeD> host_D = device_D;
// Create a non-owning CuTe tensor for D tensor
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
////////////////////////////////////////////////////////////
//
// Execute reference GEMM kernel
//
////////////////////////////////////////////////////////////
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
////////////////////////////////////////////////////////////
//
// Compare results
//
////////////////////////////////////////////////////////////
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
type_str_b, host_tensor_B,
type_str_d, host_tensor_D, host_reference_tensor_D);
bool success = relative_error <= 0.0;
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
#else
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}

View File

@ -0,0 +1,716 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// CuTe Tutorial for SM100 Programming
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
//
// The tutorial series is split into five stages:
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
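//
// Build sketch (an assumption for convenience, not part of the tutorial text): these examples are normally
// built through the CUTLASS CMake flow, but a direct compile along the lines of
//   nvcc -std=c++17 -arch=sm_100a -I${CUTLASS_DIR}/include -I${CUTLASS_DIR}/tools/util/include \
//        04_mma_tma_2sm_sm100.cu -o 04_mma_tma_2sm_sm100
// should also work, where ${CUTLASS_DIR} is a placeholder for a local CUTLASS checkout and the CUDA
// toolkit must support the sm_100a target.
//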
#include <iostream>
#include <cstdio>
// Use Thrust to handle host/device allocations
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// Cutlass includes
#include <cutlass/half.h> // F16 data type
#include <cutlass/util/print_error.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/cluster_sm90.hpp>           // CuTe functions for querying details of the launched cluster
#include <cute/numeric/integral_constant.hpp>   // Compile-time integer constants such as _1, _256, etc.
#include <cute/algorithm/cooperative_copy.hpp>
// Tutorial helpers
#include "example_utils.hpp"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tutorial 04: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA
//
///////////////////////////////////////////////////////////////////////////////////////////////////
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
// - Matrices C and D are MxN, N-major (BLAS row-major)
//
// Key extensions to tutorial 03_mma_tma_multicast_sm100.cu:
// 1. Introduce 2SM tcgen05.mma instructions
// 2. Introduce 2SM TMA instructions
// 3. Demonstrate TMA multicast pattern specialized for 2SM instructions for loading A and B matrices
//
// This GEMM kernel will perform the following steps:
// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA.2SM load operations.
// 2. Perform matrix multiply-accumulate (MMA) operations using 2SM tcgen05.mma instruction.
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
//
// SM100 2SM tcgen05.mma instructions operate as follows:
// - Mma is launched by only one SM
// With 2SM MMA instructions, only 1 of the 2 CTAs collaborating on the MMA executes the instruction.
// We call the collaborating CTAs peer CTAs, and the CTA that executes the MMA instruction the leader CTA.
// - Read matrix A from SMEM or TMEM
// - Read matrix B from SMEM
// - Write accumulator to TMEM
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
//
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
// and the MMA's M and N dimensions.
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
//
// The MMA details:
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 256x256x16 MMA
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
// This example uses F16xF16 = F32 MMA where:
// TypeA = cutlass::half_t; // MMA A Data Type
// TypeB = cutlass::half_t; // MMA B Data Type
// TypeC = float; // MMA C Data Type
// TypeD = float; // MMA D Data Type
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
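//
// For reference, the computation above amounts to the following plain host-side loops (a sketch of what the
// reference_gemm helper from example_utils.hpp is expected to compute; exact rounding of the F16 inputs and
// the F32 accumulation order may differ):
//
//   for (int m = 0; m < Gemm_M; ++m)
//     for (int n = 0; n < Gemm_N; ++n) {
//       float acc = 0.0f;
//       for (int k = 0; k < Gemm_K; ++k)
//         acc += float(A[m * Gemm_K + k]) * float(B[n * Gemm_K + k]);  // A: MxK K-major, B: NxK K-major
//       D[m * Gemm_N + n] = alpha * acc + beta * C[m * Gemm_N + n];    // C, D: MxN N-major
//     }
//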
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// The shared memory buffers for A and B matrices.
template <class TypeA, // Tensor A data type
class TypeB, // Tensor B data type
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
struct SharedStorage
{
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
};
// The device kernel
template <class SharedStorage,
class ATensor, class BTensor, class CTensor, class DTensor,
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
class TmaAtomA, class TmaAtomB,
class Alpha, class Beta>
__global__ static
void
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
BTensor mB, // (Gemm_N, Gemm_K)
CTensor mC, // (Gemm_M, Gemm_N)
DTensor mD, // (Gemm_M, Gemm_N)
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
Alpha alpha, Beta beta)
{
// Step 1: The Prologue.
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename TiledMMA::AtomThrID{}));
// Construct the MMA grid coordinate from the CTA grid coordinate
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
blockIdx.y, // MMA-N coordinate
_); // MMA-K coordinate
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
// by this mma tile.
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
// * Tensor to partition
// * Tiler to use for partitioning
// * Coordinate to use for slicing the partitioned tensor
// * Projection to ignore unwanted modes of the Tiler and Coordinate
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
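  // Illustration (an equivalent formulation, not executed here): the Step<_1, X,_1> projection above makes the
  // A partitioning behave as if the N-mode of the tiler and coordinate were dropped by hand, roughly
  //   Tensor gA_alt = local_tile(mA, select<0,2>(mma_tiler), select<0,2>(mma_coord));
  // i.e. gA is the (MmaTile_M x MmaTile_K) slab of A for this MMA tile, with a trailing mode over all k-tiles.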
if (thread0()) {
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
} __syncthreads();
// The SMEM tensors
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
//
// Mma partitioning for A and B
//
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
// MMA Fragment Allocation
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
// For tcgen05.mma operations:
// - Matrices A and B are sourced from SMEM
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
} __syncthreads();
// TMA Setup
//
// These are TMA partitionings, which have a dedicated custom partitioner.
// In this example, the TMA multicasts the loads across multiple CTAs.
// Loads of A are multicasted along the N dimension of the cluster_shape_VMNK and
// Loads of B are multicasted along the M dimension of the cluster_shape_VMNK.
// Any multicasting must conform to the TMA atoms (tma_atom_A / tma_atom_B) constructed with make_tma_atom_[A|B]_sm100 on the host.
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
// The tma_partition reorders and offsets mode-0 according to the TMA atom and the multicast info.
// Each CTA with the same m-coord will load a portion of A
// Each CTA with the same n-coord will load a portion of B
// Computation of the multicast masks must take into account the Peer CTA for TMA.2SM
// Construct the CTA-in-Cluster coordinate for multicasting
auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster()));
// Project the cluster_layout for tma_A along the N-modes
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster
make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
// Project the cluster_layout for tma_B along the M-modes
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster
make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
// Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C
uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) |
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Calculate total bytes that TMA will transfer each tile to track completion, accounting for TMA.2SM
int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA))
+ size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB));
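  // Rough sanity check of the arithmetic above, assuming the layouts printed below: each CTA stages
  // 8192 halves of A (16 KiB) and 16384 halves of B (32 KiB) per k-tile, and the factor
  // size<0>(cluster_layout_vmnk) == 2 accounts for the TMA.2SM instruction also filling the peer CTA's SMEM,
  // so the leader's barrier expects roughly 2 * (16 KiB + 32 KiB) = 96 KiB per k-tile.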
if (thread0()) {
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c);
} __syncthreads();
// Barrier Initialization
auto elect_one_thr = cute::elect_one_sync();
auto elect_one_warp = (threadIdx.x / 32 == 0);
auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{};
// Barriers in SMEM should be initialized by a single thread.
if (elect_one_warp && elect_one_thr) {
// The number of CTAs that participate in multicast operations with this CTA (for both A and B matrices)
int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1;
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants);
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
}
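  // A worked instance of the formula above, assuming the 4x4x1 cluster configured by the host code
  // (cluster_layout_vmnk extents (2,2,4,1)): each CTA's mma_barrier then expects 2 + 4 - 1 = 5 arrivals,
  // one from every leader CTA that shares this CTA's MMA-M or MMA-N coordinate in the cluster.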
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
cute::cluster_sync(); // Make sure all CTAs in Cluster observe barrier init and TMEM alloc.
// Step 2: The Mainloop.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
{
// Step 2a: Load A and B tiles
// TMA Load Operations:
// - Execute asynchronous TMA loads with single thread
// - Both peer and leader CTAs initiate TMA loads
// - Set expected transaction bytes. For 2SM TMA instructions, the transaction bytes counts both CTAs.
// - Although TMAs are initiated by both peer and leader CTAs, the barrier is only set and waited on by the leader CTA.
// - Initiate asynchronous transfers with a multicast mask that includes all CTAs that participate in multicast.
if (elect_one_warp && elect_one_thr) { // TMA loads are executed by one thread
if (elect_one_cta) { // Only the leader CTA waits for TMA transactions
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); // Set the expected transaction bytes for the TMA loads
}
copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
}
// Step 2b: Execute the MMAs for this tile
if (elect_one_cta) {
// Wait for TMA loads to complete on leader CTAs
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
tma_barrier_phase_bit ^= 1;
// tcgen05.mma instructions require single-thread execution:
// - Only one warp performs the MMA-related loop operations
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
// - No explicit elect_one_sync region is needed from the user
if (elect_one_warp) {
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
// Ensure MMAs are completed; only then can we reuse the A and B SMEM.
cutlass::arch::umma_arrive_multicast_2x1SM(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask.
}
}
// Wait for MMAs to complete to avoid overwriting the A and B SMEM.
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue.
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
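  // Naming note (informal): SM100_TMEM_LOAD_32dp32b1x selects a tcgen05.ld shape of 32 TMEM datapaths of
  // 32-bit words with a 1x repetition; other TMEM load shapes are available as sibling copy atoms in CuTe.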
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
// Load C tensor GMEM -> RMEM
copy(tDgC, tDrC);
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
using AccType = typename decltype(tCtAcc)::value_type;
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
// Load TMEM -> RMEM
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
axpby(alpha, tDrAcc, beta, tDrC);
// Store RMEM -> GMEM
copy(tDrC, tDgD);
}
template <class TypeA, class LayoutA,
class TypeB, class LayoutB,
class TypeC, class LayoutC,
class TypeD, class LayoutD,
class Alpha, class Beta>
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
TypeB const* device_ptr_B, LayoutB layout_B,
TypeC const* device_ptr_C, LayoutC layout_C,
TypeD * device_ptr_D, LayoutD layout_D,
Alpha const alpha, Beta const beta)
{
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
// Represent the full tensors in global memory
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
// Get M, N, K dimensions of the GEMM we are running
auto Gemm_M = shape<0>(layout_A);
auto Gemm_N = shape<0>(layout_B);
auto Gemm_K = shape<1>(layout_A);
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
////////////////////////////////////////////////////////////
//
// Initialize the GEMM kernel parameters
//
////////////////////////////////////////////////////////////
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
// larger TiledMma from the given mma instruction.
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_2x1SM_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
256, 256, // Mma M and N dimensions
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
// We can also print and inspect the tiled_mma
print(tiled_mma);
// TiledMMA
// ThrLayoutVMNK: (_2,_1,_1,_1):(_1,_0,_0,_0)
// PermutationMNK: (_,_,_)
// MMA_Atom
// ThrID: _2:_1
// Shape_MNK: (_256,_256,_16) // MmaM, MmaN, MmaK (MmaK is constant for each instr.)
// LayoutA_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for A matrix
// LayoutB_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
// LayoutC_TV: (_2,(_128,_256)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for C matrix
// Define MMA tiler sizes (static)
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
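  // With the 2x1SM 256x256x16 F16 MMA selected above, this should evaluate to mma_tiler = (_256,_256,_64):
  // one MMA instruction per tile in M and N, and four K16 instructions stacked along K.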
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
return;
}
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
return;
}
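  // Worked instance of the checks above (assuming the default 512x1024x256 problem from main and the
  // (_256,_256,_64) MMA tiler): 256 divides 512, 256 divides 1024, and 64 divides 256, so both checks pass.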
//
// Determine the SMEM layouts:
//
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
// times the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts is very hard.
// * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
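  // Rough SMEM budget (assuming the layouts printed above): A needs 8192 halves (16 KiB) and B needs
  // 16384 halves (32 KiB), so sizeof(SMEMStorage) is about 48 KiB plus barrier storage, well within the
  // dynamic shared memory available per SM100 SM.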
//
// TMA Descriptor Creation (Host Side)
//
// The cluster shape and layout
auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{});
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
// SM100 interface for creating TMA loads.
Copy_Atom tma_atom_A = make_tma_atom_A_sm100(
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
mA, // Source GMEM tensor
sA_layout, // Destination SMEM layout
mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes.
tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning.
cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed.
// We have make_tma_atom_[A|B]_sm100, which determines the multicast mode.
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
// tma_atom_A: Copy_Atom
// ThrID: _2:_1
// ValLayoutSrc: (_2,_8192):(_8192,_1)
// ValLayoutDst: (_2,_8192):(_8192,_1)
// ValLayoutRef: (_2,_8192):(_8192,_1)
// ValueType: 16b
// SM100 interface for creating TMA loads.
Copy_Atom tma_atom_B = make_tma_atom_B_sm100(
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
mB, // Source GMEM tensor
sB_layout, // Destination SMEM layout
mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes.
tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning.
cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed.
// We have make_tma_atom_[A|B]_sm100, which determines the multicast mode.
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
// tma_atom_B: Copy_Atom
// ThrID: _2:_1
// ValLayoutSrc: (_2,_8192):(_8192,_1)
// ValLayoutDst: (_2,_8192):(_8192,_1)
// ValLayoutRef: (_2,_8192):(_8192,_1)
// ValueType: 16b
////////////////////////////////////////////////////////////
//
// Launch GEMM kernel
//
////////////////////////////////////////////////////////////
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
int smemBytes = sizeof(SMEMStorage);
auto* kernel_ptr = &gemm_device<SMEMStorage,
decltype(mA_tma), decltype(mB_tma), decltype(mC), decltype(mD),
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
decltype(tma_atom_A), decltype(tma_atom_B), // Includes the TMA descriptor.
Alpha, Beta>;
// Set kernel attributes (set SMEM)
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
mA_tma, mB_tma, mC, mD,
mma_tiler, tiled_mma, cluster_shape,
tma_atom_A, tma_atom_B,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major != 10 || props.minor > 1) {
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int Gemm_M = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &Gemm_M);
int Gemm_N = 1024;
if (argc >= 3)
sscanf(argv[2], "%d", &Gemm_N);
int Gemm_K = 256;
if (argc >= 4)
sscanf(argv[3], "%d", &Gemm_K);
////////////////////////////////////////////////////////////
//
// Create A, B, C, and D tensors
//
////////////////////////////////////////////////////////////
// Define the data types. A and B types are same for MMA instruction.
using TypeA = cutlass::half_t; // MMA A Data Type
auto type_str_a = "half_t";
using TypeB = cutlass::half_t; // MMA B Data Type
auto type_str_b = "half_t";
using TypeC = float; // MMA C Data Type
[[maybe_unused]] auto type_str_c = "float";
using TypeD = float; // MMA D Data Type
auto type_str_d = "float";
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
// A tensor MxK K-major (Layout T = Row-Major)
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
// B tensor NxK K-major (Layout N = Column-Major)
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
// C tensor MxN N-major (Layout T = Row-Major)
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// D tensor MxN N-major (Layout T = Row-Major)
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// Host allocations and host CuTe tensors for A, B, and C tensors.
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
// Note that we don't need a host_tensor for D yet.
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
// Initialize A, B, and C tensors with random values.
initialize_tensor(host_tensor_A);
initialize_tensor(host_tensor_B);
initialize_tensor(host_tensor_C);
// Copy A, B, and C tensors from host memory to device memory
thrust::device_vector<TypeA> device_A = host_A;
thrust::device_vector<TypeB> device_B = host_B;
thrust::device_vector<TypeC> device_C = host_C;
using Alpha = float;
using Beta = float;
Alpha alpha = 1.0f;
Beta beta = 0.0f;
// Set up the input and output tensors and the kernel parameters, then execute the kernel on the device
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
device_B.data().get(), layout_B,
device_C.data().get(), layout_C,
device_D.data().get(), layout_D,
alpha, beta);
// Host allocation for D tensor and transfer D tensor from device to host
thrust::host_vector<TypeD> host_D = device_D;
// Create a non-owning CuTe tensor for D tensor
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
////////////////////////////////////////////////////////////
//
// Execute reference GEMM kernel
//
////////////////////////////////////////////////////////////
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
////////////////////////////////////////////////////////////
//
// Compare results
//
////////////////////////////////////////////////////////////
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
type_str_b, host_tensor_B,
type_str_d, host_tensor_D, host_reference_tensor_D);
bool success = relative_error <= 0.0;
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
#else
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}


@ -0,0 +1,825 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// CuTe Tutorial for SM100 Programming
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
//
// The tutorial series is split into five stages:
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <cstdio>
// Use Thrust to handle host/device allocations
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// Cutlass includes
#include <cutlass/half.h> // F16 data type
#include <cutlass/util/print_error.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/cluster_sm90.hpp>           // CuTe functions for querying details of the launched cluster
#include <cute/numeric/integral_constant.hpp>   // Compile-time integer constants such as _1, _256, etc.
#include <cute/algorithm/cooperative_copy.hpp>
// Tutorial helpers
#include "example_utils.hpp"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tutorial 05: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue
//
///////////////////////////////////////////////////////////////////////////////////////////////////
// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
// - Matrices C and D are MxN, N-major (BLAS row-major)
//
// Key extensions to tutorial 04_mma_tma_2sm_sm100.cu:
// 1. Demonstrate using TMA instructions in the epilogue
//
// This GEMM kernel will perform the following steps:
// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA.2SM load operations.
// 2. Perform matrix multiply-accumulate (MMA) operations using 2SM tcgen05.mma instruction.
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
// 4. Read C matrix from global memory (GMEM) to shared memory (SMEM) with TMA.
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
// 6. Store D matrix from shared memory (SMEM) to global memory (GMEM) with TMA.
//
// SM100 2SM tcgen05.mma instructions operate as follows:
// - Mma is launched by only one SM
// With 2SM MMA instructions, only 1 of the 2 CTAs collaborating on the MMA executes the instruction.
// We call the collaborating CTAs peer CTAs, and the CTA that executes the MMA instruction the leader CTA.
// - Read matrix A from SMEM or TMEM
// - Read matrix B from SMEM
// - Write accumulator to TMEM
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
//
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
// and the MMA's M and N dimensions.
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial.
//
// The MMA details:
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 256x256x16 MMA
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
// This example uses F16xF16 = F32 MMA where:
// TypeA = cutlass::half_t; // MMA A Data Type
// TypeB = cutlass::half_t; // MMA B Data Type
// TypeC = float; // MMA C Data Type
// TypeD = float; // MMA D Data Type
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// The shared memory buffers for A, B, C, and D matrices.
template <class TypeA, // Tensor A data type
class TypeB, // Tensor B data type
class TypeC, // Tensor C data type
class TypeD, // Tensor D data type
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
class BSmemLayout, // (MmaB, NumMma_N, NumMma_K, ...)
class CSmemLayout, // (EpiTile_M, EpiTile_N)
class DSmemLayout> // (EpiTile_M, EpiTile_N)
struct SharedStorage
{
alignas(128) union {
alignas(128) struct {
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
} mainloop;
alignas(128) cute::ArrayEngine<TypeC, cute::cosize_v<CSmemLayout>> C;
alignas(128) cute::ArrayEngine<TypeD, cute::cosize_v<DSmemLayout>> D;
} tensors;
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(tensors.mainloop.A.begin()), ASmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(tensors.mainloop.B.begin()), BSmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sC() { return make_tensor(make_smem_ptr(tensors.C.begin()), CSmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sD() { return make_tensor(make_smem_ptr(tensors.D.begin()), DSmemLayout{}); }
};
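// Note on the union above: the mainloop's A/B staging buffers and the epilogue's C and D staging buffers
// alias the same SMEM bytes. This is safe because the mainloop has fully consumed A and B (tracked by
// mma_barrier) before the epilogue begins, and the epilogue orders its reuse of the C and D views with
// __syncthreads() and a TMA store fence, as seen in the kernel below.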
// The device kernel
template <class SharedStorage,
class ATensor, class BTensor, class CTensor, class DTensor,
class MmaTiler_MNK, class EpiTiler_MN, class TiledMMA, class ClusterShape_MNK,
class TmaAtomA, class TmaAtomB, class TmaAtomC, class TmaAtomD,
class Alpha, class Beta>
__global__ static
void
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
BTensor mB, // (Gemm_N, Gemm_K)
CTensor mC, // (Gemm_M, Gemm_N)
DTensor mD, // (Gemm_M, Gemm_N)
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
EpiTiler_MN epi_tiler_mn, // <EpiTile_M, EpiTile_N>
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,
CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B,
CUTE_GRID_CONSTANT TmaAtomC const tma_atom_C,
CUTE_GRID_CONSTANT TmaAtomD const tma_atom_D,
Alpha alpha, Beta beta)
{
// Step 1: The Prologue.
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename TiledMMA::AtomThrID{}));
// Construct the MMA grid coordinate from the CTA grid coordinate
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
blockIdx.y, // MMA-N coordinate
_); // MMA-K coordinate
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
// by this mma tile.
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
// * Tensor to partition
// * Tiler to use for partitioning
// * Coordinate to use for slicing the partitioned tensor
// * Projection to ignore unwanted modes of the Tiler and Coordinate
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
if (thread0()) {
print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0)
print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0)
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0)
print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0)
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
} __syncthreads();
// The SMEM tensors
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
//
// Mma partitioning for A and B
//
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
// MMA Fragment Allocation
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
// For tcgen05.mma operations:
// - Matrices A and B are sourced from SMEM
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
} __syncthreads();
// TMA Setup
//
// These are TMA partitionings, which have a dedicated custom partitioner.
// In this example, the TMA multicasts the loads across multiple CTAs.
// Loads of A are multicasted along the N dimension of the cluster_shape_VMNK and
// Loads of B are multicasted along the M dimension of the cluster_shape_VMNK.
// Any multicasting must conform to the TMA atoms (tma_atom_A / tma_atom_B) constructed with make_tma_atom_[A|B]_sm100 on the host.
// For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK.
// For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor
// into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK.
// Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy.
// The tma_partition reorders and offsets mode-0 according to the TMA atom and the multicast info.
// Each CTA with the same m-coord will load a portion of A
// Each CTA with the same n-coord will load a portion of B
// Computation of the multicast masks must take into account the Peer CTA for TMA.2SM
// Construct the CTA-in-Cluster coordinate for multicasting
auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster()));
// Project the cluster_layout for tma_A along the N-modes
auto [tAgA, tAsA] = tma_partition(tma_atom_A,
get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster
make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
// Project the cluster_layout for tma_B along the M-modes
auto [tBgB, tBsB] = tma_partition(tma_atom_B,
get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster
make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
// Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C
uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) |
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);
// Calculate total bytes that TMA will transfer each tile to track completion, accounting for TMA.2SM
int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA))
+ size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB));
if (thread0()) {
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c);
} __syncthreads();
// Barrier Initialization
auto elect_one_thr = cute::elect_one_sync();
auto elect_one_warp = (threadIdx.x / 32 == 0);
auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{};
// Barriers in SMEM should be initialized by a single thread.
if (elect_one_warp && elect_one_thr) {
// The number of CTAs that participate in multicast operations with this CTA (for both A and B matrices)
int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1;
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants);
cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1);
}
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
cute::cluster_sync(); // Make sure all CTAs in Cluster observe barrier init and TMEM alloc.
// Step 2: The Mainloop.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
{
// Step 2a: Load A and B tiles
// TMA Load Operations:
// - Execute asynchronous TMA loads with single thread
// - Both peer and leader CTAs initiate TMA loads
// - Set expected transaction bytes. For 2SM TMA instructions, the transaction bytes counts both CTAs.
// - Although TMAs are initiated by both peer and leader CTAs, the barrier is only set and waited on by the leader CTA.
// - Initiate asynchronous transfers with a multicast mask that includes all CTAs that participate in multicast.
if (elect_one_warp && elect_one_thr) { // TMA loads are executed by one thread
if (elect_one_cta) { // Only the leader CTA waits for TMA transactions
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); // Set the expected transaction bytes for the TMA loads
}
copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile
copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile
}
// Step 2b: Execute the MMAs for this tile
if (elect_one_cta) {
// Wait for TMA loads to complete on leader CTAs
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
tma_barrier_phase_bit ^= 1;
// tcgen05.mma instructions require single-thread execution:
// - Only one warp performs the MMA-related loop operations
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
// - No explicit elect_one_sync region is needed from the user
if (elect_one_warp) {
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
// Ensure MMAs are completed; only then can we reuse the A and B SMEM.
cutlass::arch::umma_arrive_multicast_2x1SM(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask.
}
}
// Wait for MMAs to complete to avoid overwriting the A and B SMEM.
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue.
// Apply rank-2 epilogue tiler to rank-2 MMA_V mode
auto epi_tiler_v = make_tile(epi_tiler_mn); // (EpiTile)
Tensor tAcc_epi = zipped_divide(tCtAcc, epi_tiler_v); // (EpiTile,NumTiles)
Tensor gC_epi = zipped_divide(tCgC, epi_tiler_v); // (EpiTile,NumTiles)
Tensor gD_epi = zipped_divide(tCgD, epi_tiler_v); // (EpiTile,NumTiles)
// Construct corresponding SMEM tensors
Tensor sC_epi = shared_storage.tensor_sC(); // (EpiTile)
Tensor sD_epi = shared_storage.tensor_sD(); // (EpiTile)
// Partition for TMA
auto [tGS_gC, tGS_sC] = tma_partition(tma_atom_C, sC_epi, gC_epi); // (GMEM -> SMEM)
auto [tSG_gD, tSG_sD] = tma_partition(tma_atom_D, sD_epi, gD_epi); // (SMEM -> GMEM)
// Reset transaction bytes for C load
tma_transaction_bytes = sizeof(make_tensor_like(tGS_sC));
// Partition for TMEM accumulators load (TMEM -> RMEM)
TiledCopy t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tAcc_epi(_,_0{}));
ThrCopy thr_t2r = t2r_copy.get_slice(threadIdx.x);
Tensor tTR_tAcc = thr_t2r.partition_S(tAcc_epi); // (TmemCpy,NumTmemCpy,NumTiles)
Tensor tTR_sC = thr_t2r.partition_D(sC_epi); // (TmemCpy,NumTmemCpy)
Tensor tTR_sD = thr_t2r.partition_D(sD_epi); // (TmemCpy,NumTmemCpy)
// Allocate register tensors
Tensor tTR_rC = make_tensor_like(tTR_sC); // (TmemCpy,NumTmemCpy)
Tensor tTR_rD = make_fragment_like(tTR_sD); // (TmemCpy,NumTmemCpy)
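  // Per epilogue tile, the loop below moves data as follows (a summary of the step comments inside it):
  //   C   : GMEM --TMA load--> SMEM --copy_aligned--> RMEM
  //   Acc : TMEM --tcgen05.ld--> RMEM
  //   D = beta * C + alpha * Acc  (axpby in registers)
  //   D   : RMEM --copy_aligned--> SMEM --TMA store--> GMEM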
// Loop over the epilogue tiles
CUTE_UNROLL
for (int epi_tile_idx = 0; epi_tile_idx < size<2>(tTR_tAcc); ++epi_tile_idx) {
// TMA Load C: GMEM -> SMEM
if (elect_one_warp && elect_one_thr) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes);
copy(tma_atom_C.with(shared_storage.tma_barrier, 0 /*no multicast*/), tGS_gC(_,epi_tile_idx), tGS_sC);
}
// All threads wait for C TMA load to complete
cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);
tma_barrier_phase_bit ^= 1;
// Load C: SMEM -> RMEM
copy_aligned(tTR_sC, tTR_rC);
// Load Acc: TMEM -> RMEM
copy(t2r_copy, tTR_tAcc(_,_,epi_tile_idx), tTR_rD);
// Compute D = beta * C + alpha * (A*B)
axpby(beta, tTR_rC, alpha, tTR_rD);
// Store D: RMEM -> SMEM
__syncthreads(); // Ensure C loads are finished before reusing smem (unnecessary if smem layouts match)
copy_aligned(tTR_rD, tTR_sD);
// TMA Store D: SMEM -> GMEM
tma_store_fence(); // Ensure D smem stores are visible to TMA
__syncthreads(); // Ensure all threads have issued fence
if (elect_one_warp && elect_one_thr) {
copy(tma_atom_D, tSG_sD, tSG_gD(_,epi_tile_idx));
tma_store_arrive(); // issuing thread commits D TMA store
tma_store_wait<0>(); // issuing thread waits for D TMA store to complete
}
__syncthreads(); // All threads sync with issuing thread
}
}
template <class TypeA, class LayoutA,
class TypeB, class LayoutB,
class TypeC, class LayoutC,
class TypeD, class LayoutD,
class Alpha, class Beta>
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
TypeB const* device_ptr_B, LayoutB layout_B,
TypeC const* device_ptr_C, LayoutC layout_C,
TypeD * device_ptr_D, LayoutD layout_D,
Alpha const alpha, Beta const beta)
{
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
// Represent the full tensors in global memory
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
// Get M, N, K dimensions of the GEMM we are running
auto Gemm_M = shape<0>(layout_A);
auto Gemm_N = shape<0>(layout_B);
auto Gemm_K = shape<1>(layout_A);
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
////////////////////////////////////////////////////////////
//
// Initialize the GEMM kernel parameters
//
////////////////////////////////////////////////////////////
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
// larger TiledMma from the given mma instruction.
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_2x1SM_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
256, 256, // Mma M and N dimensions
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
// We can also print and inspect the tiled_mma
print(tiled_mma);
// TiledMMA
// ThrLayoutVMNK: (_2,_1,_1,_1):(_1,_0,_0,_0)
// PermutationMNK: (_,_,_)
// MMA_Atom
// ThrID: _2:_1
// Shape_MNK: (_256,_256,_16) // MmaM, MmaN, MmaK (MmaK is constant for each instr.)
// LayoutA_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for A matrix
// LayoutB_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
// LayoutC_TV: (_2,(_128,_256)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for C matrix
// Define MMA tiler sizes (static)
auto bM = tile_size<0>(tiled_mma);             // MMA Tile M. We'll use 1 MMA per MMA Tile M.
auto bN = tile_size<1>(tiled_mma);             // MMA Tile N. We'll use 1 MMA per MMA Tile N.
auto bK = tile_size<2>(tiled_mma) * Int<4>{};  // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K=16.
auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K)
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
return;
}
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
return;
}
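#if 0
// Illustrative only: print the MMA tiler next to the TiledMma's tile shape to make the
// relationship checked above visible (the former must be a whole multiple of the latter).
print("mma_tiler:\t");             print(mma_tiler);             print("\n"); // (_256,_256,_64)
print("tile_shape(tiled_mma):\t"); print(tile_shape(tiled_mma)); print("\n"); // (_256,_256,_16)
#endif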
//
// Determine the SMEM layouts:
//
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes give the number of times
// the MMA instruction is repeated in the M/N mode and the K mode of the MMA tile, respectively.
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts by hand is difficult and error-prone.
// * CuTe provides UMMA::tile_to_mma_shape for SM100 to create swizzled layouts for the post-partitioned MMA shapes.
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
//
// Epilogue parameters
//
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_N) to post-partitioned ((MmaM,MmaN), NumMma_M, NumMma_N)
auto mma_shape_C = partition_shape_C(tiled_mma, make_shape(size<0>(mma_tiler), size<1>(mma_tiler)));
// For TMA epilogue performance it may be beneficial to iterate over the output in smaller tiles than the MMA tile
auto epi_tiler = make_tile(size<0,0>(mma_shape_C), size<0,1>(mma_shape_C) / Int<4>{}); // 4 TMA copies per CTA per MMA tile
// SMEM layouts for C and D should match the epilogue tile
auto sC_layout_mn = tile_to_shape(UMMA::Layout_K_SW128_Atom<TypeC>{}, // MMA K-major is equivalent to epilogue N-major
make_shape(size<0>(epi_tiler), size<1>(epi_tiler)));
auto sC_layout = group<0,2>(sC_layout_mn); // Group modes for tma_partition
auto sD_layout_mn = tile_to_shape(UMMA::Layout_K_SW128_Atom<TypeD>{}, // MMA K-major is equivalent to epilogue N-major
make_shape(size<0>(epi_tiler), size<1>(epi_tiler)));
auto sD_layout = group<0,2>(sD_layout_mn); // Group modes for tma_partition
print("sC_layout:\t"); print(sC_layout); print("\n"); // sC_layout: Sw<3,4,3> o smem_ptr[32b](unset) o ((_8,_16),(_32,_2)):((_32,_256),(_1,_4096))
print("sD_layout:\t"); print(sD_layout); print("\n"); // sD_layout: Sw<3,4,3> o smem_ptr[32b](unset) o ((_8,_16),(_32,_2)):((_32,_256),(_1,_4096))
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, TypeC, TypeD,
decltype(sA_layout), decltype(sB_layout),
decltype(sC_layout), decltype(sD_layout)>;
//
// TMA Descriptor Creation (Host Side)
//
// The cluster shape and layout
auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{});
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
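#if 0
// Illustrative only: the VMNK view folds the MMA's AtomThrID (2 CTAs per MMA) into the leading
// "V" mode, turning the 4x4x1 cluster into a ((_2),_2,_4,_1)-shaped layout.
print("cluster_layout_vmnk:\t"); print(cluster_layout_vmnk); print("\n");
#endif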
// SM100 interface for creating TMA loads.
Copy_Atom tma_atom_A = make_tma_atom_A_sm100(
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
mA, // Source GMEM tensor
sA_layout, // Destination SMEM layout
mma_tiler,             // MmaTiler_MNK. Unlike the Sm90 interface, the tiler carries the full M, N, and K modes.
tiled_mma,             // Sm100 also requires the TiledMma to perform the CTA-level partitioning.
cluster_layout_vmnk);  // ClusterLayout_VMNK. Unlike the Sm90 interface, the full cluster layout is passed rather than only the multicast mode.
// make_tma_atom_[A|B]_sm100 determines the multicast mode for its operand from the TiledMma and the cluster layout.
Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K)
print("tma_atom_A:\t"); print(tma_atom_A); print("\n");
// tma_atom_A: Copy_Atom
// ThrID: _2:_1
// ValLayoutSrc: (_2,_8192):(_8192,_1)
// ValLayoutDst: (_2,_8192):(_8192,_1)
// ValLayoutRef: (_2,_8192):(_8192,_1)
// ValueType: 16b
// SM100 interface for creating TMA loads.
Copy_Atom tma_atom_B = make_tma_atom_B_sm100(
SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction.
mB, // Source GMEM tensor
sB_layout, // Destination SMEM layout
mma_tiler,             // MmaTiler_MNK. Unlike the Sm90 interface, the tiler carries the full M, N, and K modes.
tiled_mma,             // Sm100 also requires the TiledMma to perform the CTA-level partitioning.
cluster_layout_vmnk);  // ClusterLayout_VMNK. Unlike the Sm90 interface, the full cluster layout is passed rather than only the multicast mode.
// make_tma_atom_[A|B]_sm100 determines the multicast mode for its operand from the TiledMma and the cluster layout.
Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K)
print("tma_atom_B:\t"); print(tma_atom_B); print("\n");
// tma_atom_B: Copy_Atom
// ThrID: _2:_1
// ValLayoutSrc: (_2,_8192):(_8192,_1)
// ValLayoutDst: (_2,_8192):(_8192,_1)
// ValLayoutRef: (_2,_8192):(_8192,_1)
// ValueType: 16b
Copy_Atom tma_atom_C = make_tma_atom(
SM90_TMA_LOAD{}, // TMA load operation
mC, // Source GMEM tensor
sC_layout, // Destination SMEM layout
epi_tiler); // MN Tiler for epilogue
Tensor mC_tma = tma_atom_C.get_tma_tensor(shape(mC)); // (Gemm_M, Gemm_N)
print("tma_atom_C:\t"); print(tma_atom_C); print("\n");
// tma_atom_C: Copy_Atom
// ThrID: _1:_0
// ValLayoutSrc: (_1,_4096):(_0,_1)
// ValLayoutDst: (_1,_4096):(_0,_1)
// ValLayoutRef: (_1,_4096):(_0,_1)
// ValueType: 32b
Copy_Atom tma_atom_D = make_tma_atom(
SM90_TMA_STORE{}, // TMA store operation
mD, // Destination GMEM tensor
sD_layout, // Source SMEM layout
epi_tiler); // MN Tiler for epilogue
Tensor mD_tma = tma_atom_D.get_tma_tensor(shape(mD)); // (Gemm_M, Gemm_N)
print("tma_atom_D:\t"); print(tma_atom_D); print("\n");
// tma_atom_D: Copy_Atom
// ThrID: _1:_0
// ValLayoutSrc: (_1,_4096):(_0,_1)
// ValLayoutDst: (_1,_4096):(_0,_1)
// ValLayoutRef: (_1,_4096):(_0,_1)
// ValueType: 32b
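// Note (expository): get_tma_tensor() returns a coordinate ("ArithTuple") tensor rather than a
// pointer-backed tensor; the kernel tiles and partitions it like any other tensor, and the TMA
// copy consumes the resulting coordinates when forming its transfers.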
////////////////////////////////////////////////////////////
//
// Launch GEMM kernel
//
////////////////////////////////////////////////////////////
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
int smemBytes = sizeof(SMEMStorage);
auto* kernel_ptr = &gemm_device<SMEMStorage,
decltype(mA_tma), decltype(mB_tma), decltype(mC_tma), decltype(mD_tma),
decltype(mma_tiler), decltype(epi_tiler), decltype(tiled_mma), decltype(cluster_shape),
decltype(tma_atom_A), decltype(tma_atom_B), decltype(tma_atom_C), decltype(tma_atom_D), // Includes the TMA descriptor.
Alpha, Beta>;
// Set kernel attributes (set SMEM)
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
mA_tma, mB_tma, mC_tma, mD_tma,
mma_tiler, epi_tiler, tiled_mma, cluster_shape,
tma_atom_A, tma_atom_B, tma_atom_C, tma_atom_D,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if ((props.major != 10) || (props.major == 10 && props.minor > 1)) {
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int Gemm_M = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &Gemm_M);
int Gemm_N = 1024;
if (argc >= 3)
sscanf(argv[2], "%d", &Gemm_N);
int Gemm_K = 256;
if (argc >= 4)
sscanf(argv[3], "%d", &Gemm_K);
////////////////////////////////////////////////////////////
//
// Create A, B, C, and D tensors
//
////////////////////////////////////////////////////////////
// Define the data types. A and B types are same for MMA instruction.
using TypeA = cutlass::half_t; // MMA A Data Type
auto type_str_a = "half_t";
using TypeB = cutlass::half_t; // MMA B Data Type
auto type_str_b = "half_t";
using TypeC = float; // MMA C Data Type
[[maybe_unused]] auto type_str_c = "float";
using TypeD = float; // MMA D Data Type
auto type_str_d = "float";
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
// A tensor MxK K-major (Layout T = Row-Major)
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
// B tensor NxK K-major (Layout N = Column-Major)
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
// C tensor MxN N-major (Layout T = Row-Major)
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// D tensor MxN N-major (Layout T = Row-Major)
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
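// Quick illustration (expository only) of what these strides mean: for the K-major layout_A,
// the linear offset of element (m, k) is m * Gemm_K + k.
#if 0
assert(layout_A(7, 3) == 7 * Gemm_K + 3);
assert(layout_C(5, 9) == 5 * Gemm_N + 9);
#endif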
// Host allocations and host CuTe tensors for A, B, and C tensors.
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
// Note that we don't need a host_tensor for D yet.
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
// Initialize A, B, and C tensors with random values.
initialize_tensor(host_tensor_A);
initialize_tensor(host_tensor_B);
initialize_tensor(host_tensor_C);
// Copy A, B, and C tensors from host memory to device memory
thrust::device_vector<TypeA> device_A = host_A;
thrust::device_vector<TypeB> device_B = host_B;
thrust::device_vector<TypeC> device_C = host_C;
using Alpha = float;
using Beta = float;
Alpha alpha = 1.0f;
Beta beta = 0.0f;
// Setup input and output tensors, and the kernel parameters; and execute the kernel on device
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
device_B.data().get(), layout_B,
device_C.data().get(), layout_C,
device_D.data().get(), layout_D,
alpha, beta);
// Host allocation for D tensor and transfer D tensor from device to host
thrust::host_vector<TypeD> host_D = device_D;
// Create a non-owning CuTe tensor for D tensor
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
////////////////////////////////////////////////////////////
//
// Execute reference GEMM kernel
//
////////////////////////////////////////////////////////////
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
////////////////////////////////////////////////////////////
//
// Compare results
//
////////////////////////////////////////////////////////////
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
type_str_b, host_tensor_B,
type_str_d, host_tensor_D, host_reference_tensor_D);
bool success = relative_error <= 0.0;
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
#else
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}

View File

@ -0,0 +1,54 @@
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
cute_tutorial_01_mma_sm100
01_mma_sm100.cu
)
cutlass_example_add_executable(
cute_tutorial_02_mma_tma_sm100
02_mma_tma_sm100.cu
)
cutlass_example_add_executable(
cute_tutorial_03_mma_tma_multicast_sm100
03_mma_tma_multicast_sm100.cu
)
cutlass_example_add_executable(
cute_tutorial_04_mma_tma_2sm_sm100
04_mma_tma_2sm_sm100.cu
)
cutlass_example_add_executable(
cute_tutorial_05_mma_tma_epi_sm100
05_mma_tma_epi_sm100.cu
)
endif()

View File

@ -0,0 +1,105 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/copy_sm90_desc.hpp>
template <class AccType,
class TensorA, class TensorB,
class TensorC, class TensorD,
class Alpha, class Beta>
void
reference_gemm(TensorA const& tensor_A, TensorB const& tensor_B,
TensorC const& tensor_C, TensorD & tensor_D,
Alpha alpha, Beta beta)
{
using namespace cute;
for (int m = 0; m < size<0>(tensor_D); ++m) {
for (int n = 0; n < size<1>(tensor_D); ++n) {
AccType c = AccType(0.f);
for (int k = 0; k < size<1>(tensor_A); ++k) {
c += tensor_A(m,k) * tensor_B(n,k);
}
tensor_D(m,n) = alpha * c + beta * tensor_C(m,n);
}
}
}
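// Note (expository): both operands are indexed K-major here -- tensor_A(m,k) and tensor_B(n,k) --
// so with A viewed as an M x K matrix and B as an N x K matrix this computes
// D = alpha * (A * B^T) + beta * C.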
template <class TensorA, class TensorB,
class TensorC, class TensorD,
class RefTensorD>
bool
compare_results(TensorA const& tensor_A, TensorB const& tensor_B,
TensorC const& tensor_C, TensorD const& tensor_D,
RefTensorD const& ref_tensor_D,
bool print_diff = false)
{
using namespace cute;
auto norm_A = matrix_inf_norm(tensor_A);
auto norm_B = matrix_inf_norm(tensor_B);
auto norm_C = matrix_inf_norm(tensor_C);
auto norm_D = matrix_inf_norm(tensor_D);
auto norm_ref_D = matrix_inf_norm(ref_tensor_D);
auto norm_diff = matrix_diff_inf_norm(tensor_D, ref_tensor_D);
if (print_diff) {
for (int m = 0; m < size<0>(tensor_D); ++m) {
for (int n = 0; n < size<1>(tensor_D); ++n) {
std::cout << m << "," << n << " : " << tensor_D(m,n) << " vs. " << ref_tensor_D(m,n) << std::endl;
}
}
}
std::cout << "norm (A) : " << norm_A.inf_norm << std::endl;
std::cout << "norm (B) : " << norm_B.inf_norm << std::endl;
std::cout << "norm (C) : " << norm_C.inf_norm << std::endl;
std::cout << "norm (D) : " << norm_D.inf_norm << std::endl;
std::cout << "norm (ref_D) : " << norm_ref_D.inf_norm << std::endl;
std::cout << "norm (D-ref_D) : " << norm_diff.inf_norm << std::endl;
return (!norm_A.found_nan) && (!norm_B.found_nan) &&
(!norm_C.found_nan) && (!norm_D.found_nan) && (!norm_ref_D.found_nan) && // There are no NaNs
(norm_A.inf_norm > 0.0) && (norm_B.inf_norm > 0.0) &&
(norm_C.inf_norm > 0.0) && (norm_D.inf_norm > 0.0) && (norm_ref_D.inf_norm > 0.0) && // Values in tensors aren't zeros
(norm_diff.inf_norm <= 0.0); // Diff (ref_D-D) == 0
}
template <class Tensor>
void
initialize_tensor(Tensor& tensor, cute::tuple<int, int> value_range = {-2, 2})
{
using DataType = typename Tensor::element_type;
auto [min, max] = value_range;
for (int i = 0; i < cute::size(tensor); i++) {
tensor(i) = DataType(int((max-min)*(rand() / double(RAND_MAX)) + min));
}
}
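// Minimal usage sketch (hypothetical buffer names, for illustration only):
//
//   thrust::host_vector<cute::half_t> buf(M * K);
//   auto t = cute::make_tensor(buf.data(), cute::make_layout(cute::make_shape(M, K)));
//   initialize_tensor(t);            // small integers in [-2, 2)
//   initialize_tensor(t, {0, 5});    // or an explicit value range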

View File

@ -0,0 +1,38 @@
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
cute_tutorial_wgmma_sm90
wgmma_sm90.cu
)
cutlass_example_add_executable(
cute_tutorial_wgmma_tma_sm90
wgmma_tma_sm90.cu
)

View File

@ -0,0 +1,611 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/cluster_launch.hpp"
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
using namespace cute;
template <class ElementA,
class ElementB,
class SmemLayoutA, // (M,K,P)
class SmemLayoutB> // (N,K,P)
struct SharedStorage
{
alignas(128) cute::ArrayEngine<ElementA, cosize_v<SmemLayoutA>> A;
alignas(128) cute::ArrayEngine<ElementB, cosize_v<SmemLayoutB>> B;
};
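// (Descriptive note) cute::ArrayEngine is raw, uninitialized storage; cosize_v<SmemLayout> is the
// number of elements the (possibly swizzled) layout addresses, so each buffer backs one full
// pipelined SMEM tensor.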
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TC, class CStride, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TC * C, CStride dC, TiledMma mma,
Alpha alpha, Beta beta)
{
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory tensors
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), ASmemLayout{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), BSmemLayout{}); // (BLK_N,BLK_K,PIPE)
//
// Partition the copying of A and B tiles across the threads
//
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor sA_ = as_position_independent_swizzle_tensor(sA);
Tensor tAsA = thr_copy_a.partition_D(sA_); // (CPY,CPY_M,CPY_K,PIPE)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor sB_ = as_position_independent_swizzle_tensor(sB);
Tensor tBsB = thr_copy_b.partition_D(sB_); // (CPY,CPY_N,CPY_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
//
// PREFETCH
//
// auto K_PIPE_MAX = size<3>(tAsA);
// // Total count of tiles
// int k_tile_count = size<3>(tAgA);
// // Current tile index in gmem to read from
// int k_tile_next = 0;
// // Start async loads for all pipes but the last
// CUTE_UNROLL
// for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) {
// copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe));
// copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe));
// cp_async_fence();
// --k_tile_count;
// if (k_tile_count > 0) { ++k_tile_next; }
// }
//
// Define A/B partitioning and C accumulators
//
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate registers for pipelining
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrA : "); print(tCrA); print("\n");
print("tCrB : "); print(tCrB); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// Total number of k-tiles
auto K_TILE_MAX = size<3>(tAgA);
// Number of pipelined k-tiles in smem
auto K_PIPE_MAX = size<3>(tAsA);
//
// PREFETCH
//
// Prefetch all but the last
CUTE_UNROLL
for (int k = 0; k < K_PIPE_MAX-1; ++k)
{
copy(copy_a, tAgA(_,_,_,k), tAsA(_,_,_,k));
copy(copy_b, tBgB(_,_,_,k), tBsB(_,_,_,k));
cp_async_fence();
}
// Clear the accumulators
clear(tCrC);
__syncthreads();
//
// PIPELINED MAIN LOOP
//
// Current pipe to read from
int k_pipe_read = 0;
// Current pipe to write to
int k_pipe_write = K_PIPE_MAX-1;
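// Expository note: with K_PIPE_MAX == 3 the (k_pipe_read, k_pipe_write) pair rotates
// (0,2) -> (1,0) -> (2,1) -> (0,2) -> ... across k_tiles, so from the second iteration on each
// iteration refills the stage that was consumed in the previous one.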
CUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
int k_tile_next = k_tile + (K_PIPE_MAX-1);
k_tile_next = (k_tile_next >= K_TILE_MAX) ? K_TILE_MAX-1 : k_tile_next;
//
// Copy gmem to smem for k_tile_write
//
copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe_write));
copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe_write));
cp_async_fence();
// Advance k_pipe_write
++k_pipe_write;
k_pipe_write = (k_pipe_write == K_PIPE_MAX) ? 0 : k_pipe_write;
//
// Compute on k_tile
//
// Wait on all cp.async -- optimize by pipelining to overlap GMEM reads
cp_async_wait<0>();
warpgroup_fence_operand(tCrC);
warpgroup_arrive();
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(mma, tCrA(_,_,_,k_pipe_read), tCrB(_,_,_,k_pipe_read), tCrC);
warpgroup_commit_batch();
// Wait for the WGMMA batch to complete so the SMEM stage just consumed can be safely overwritten next iteration
warpgroup_wait<0>();
warpgroup_fence_operand(tCrC);
// Advance k_pipe_read
++k_pipe_read;
k_pipe_read = (k_pipe_read == K_PIPE_MAX) ? 0 : k_pipe_read;
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for a NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TB>{}, make_shape(bN,bK,bP));
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
Layout<Shape<_16,_8>>{},  // Thr layout 16x8 m-major
Layout<Shape< _8,_1>>{}); // Val layout 8x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
Layout<Shape<_16,_8>>{},  // Thr layout 16x8 n-major
Layout<Shape< _8,_1>>{}); // Val layout 8x1 n-major
TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});
#if 0
print(copyA);
print(copyB);
print(tiled_mma);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(tiled_mma);
#endif
//
// Setup and Launch
//
// Launch parameter setup
dim3 dimBlock(size(tiled_mma));
dim3 dimCluster(1, 1, 1);
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
round_up(size(ceil_div(n, bN)), dimCluster.y));
int smemBytes = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
auto* kernel_ptr = &gemm_device<decltype(prob_shape), decltype(cta_tiler),
TA, decltype(dA), decltype(sA), decltype(copyA),
TB, decltype(dB), decltype(sB), decltype(copyB),
TC, decltype(dC), decltype(tiled_mma),
decltype(alpha), decltype(beta)>;
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
// Kernel Launch
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, tiled_mma,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom<TB>{}, make_shape(bN,bK,bP));
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
Layout<Shape< _1,_8>>{}); // Val layout 1x8
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
Layout<Shape< _1,_8>>{}); // Val layout 1x8
TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::K,GMMA::Major::K>{});
#if 0
print(copyA);
print(copyB);
print(tiled_mma);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(tiled_mma);
#endif
//
// Setup and Launch
//
// Launch parameter setup
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(tiled_mma));
dim3 dimCluster(1, 1, 1);
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
round_up(size(ceil_div(n, bN)), dimCluster.y));
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};
void const* kernel_ptr = reinterpret_cast<void const*>(
&gemm_device<decltype(prob_shape), decltype(cta_tiler),
TA, decltype(dA), decltype(sA), decltype(copyA),
TB, decltype(dB), decltype(sB), decltype(copyB),
TC, decltype(dC), decltype(tiled_mma),
decltype(alpha), decltype(beta)>);
CUTE_CHECK_ERROR(cudaFuncSetAttribute(
kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
// Kernel Launch
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, tiled_mma,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major < 8) {
std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
// Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED)
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = cute::half_t;
using TB = cute::half_t;
using TC = cute::half_t;
using TI = cute::half_t;
TI alpha = TI(1.0f);
TI beta = TI(0.0f);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
// Initialize the tensors
for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1));
for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1));
for (int j = 0; j < m*n; ++j) h_C[j] = TC(0);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
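// To summarize the leading-dimension setup above (it mirrors the column-major BLAS convention):
//   transA == 'N': A is MxK, M-major, ldA = m;   transA == 'T': A is MxK, K-major, ldA = k.
//   transB == 'N': B is NxK, K-major, ldB = k;   transB == 'T': B is NxK, N-major, ldB = n.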
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
#else
std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}

View File

@ -55,8 +55,8 @@ template <class ElementA,
class SmemLayoutB> // (N,K,P)
struct SharedStorage
{
array_aligned<ElementA, cosize_v<SmemLayoutA>> smem_A;
array_aligned<ElementB, cosize_v<SmemLayoutB>> smem_B;
alignas(128) cute::ArrayEngine<ElementA, cosize_v<SmemLayoutA>> A;
alignas(128) cute::ArrayEngine<ElementB, cosize_v<SmemLayoutB>> B;
uint64_t tma_barrier[size<2>(SmemLayoutA{})];
uint64_t mma_barrier[size<2>(SmemLayoutA{})];
@ -110,8 +110,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, SmemLayoutA, SmemLayoutB>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Partition the copying of A and B tiles
@ -132,8 +132,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
group_modes<0,2>(sB), group_modes<0,2>(gB)); // (TMA,k) and (TMA,PIPE)
// The TMA is responsible for copying everything in mode-0 of tAsA and tBsB
constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) +
CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB);
constexpr int tma_transaction_bytes = sizeof(make_tensor_like(tensor<0>(tAsA)))
+ sizeof(make_tensor_like(tensor<0>(tBsB)));
//
// PREFETCH
@ -171,7 +171,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
if ((warp_idx == 0) && lane_predicate)
{
// Set expected Tx Bytes after each reset / init
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], tma_transaction_bytes);
copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
}
@ -242,7 +242,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// Wait for Consumer to complete consumption
ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase());
// Set expected Tx Bytes after each reset / init
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], tma_transaction_bytes);
copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
++write_state;
@ -393,27 +393,25 @@ gemm_tn(int m, int n, int k,
//
// Launch parameter setup
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(tiled_mma));
dim3 dimCluster(2, 1, 1);
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
round_up(size(ceil_div(n, bN)), dimCluster.y));
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};
int smemBytes = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
void const* kernel_ptr = reinterpret_cast<void const*>(
&gemm_device<decltype(prob_shape), decltype(cta_tiler),
TA, decltype(sA), decltype(tmaA),
TB, decltype(sB), decltype(tmaB),
TC, decltype(dC), decltype(tiled_mma),
decltype(alpha), decltype(beta)>);
auto* kernel_ptr = &gemm_device<decltype(prob_shape), decltype(cta_tiler),
TA, decltype(sA), decltype(tmaA),
TB, decltype(sB), decltype(tmaB),
TC, decltype(dC), decltype(tiled_mma),
decltype(alpha), decltype(beta)>;
CUTE_CHECK_ERROR(cudaFuncSetAttribute(
kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
// Kernel Launch
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
prob_shape, cta_tiler,
A, tmaA,
B, tmaB,
@ -448,8 +446,10 @@ gemm(char transA, char transB, int m, int n, int k,
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaGetDeviceProperties(&props, current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
@ -461,7 +461,7 @@ int main(int argc, char** argv)
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED)
int m = 512;
if (argc >= 2)
@ -553,10 +553,8 @@ int main(int argc, char** argv)
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
#else
std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}

View File

@ -41,17 +41,27 @@
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
template <class ElementA,
class ElementB,
class SmemLayoutA,
class SmemLayoutB>
struct SharedStorage
{
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> A;
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> B;
};
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TA, class AStride, class ASmemLayout, class TiledCopyA, class S2RAtomA,
class TB, class BStride, class BSmemLayout, class TiledCopyB, class S2RAtomB,
class TC, class CStride, class CSmemLayout, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, S2RAtomA s2r_atom_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, S2RAtomB s2r_atom_b,
TC * C, CStride dC, CSmemLayout , TiledMma mma,
Alpha alpha, Beta beta)
{
@ -95,10 +105,11 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K,PIPE)
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE)
//
// Partition the copying of A and B tiles across the threads
@ -143,26 +154,35 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
//
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate registers for pipelining
Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0)); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0)); // (MMA,MMA_N,MMA_K)
Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N
// Clear the accumulators
clear(tCrC);
//
// Copy Atom retiling
//
TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x);
Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE)
Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K)
TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE)
Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K)
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
@ -187,12 +207,15 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrA : "); print(tCrA); print("\n");
print("tCrB : "); print(tCrB); print("\n");
print("tCrC : "); print(tCrC); print("\n");
print("tXsA : "); print(tXsA); print("\n");
print("tXrA : "); print(tXrA); print("\n");
print("tXsB : "); print(tXsB); print("\n");
print("tXrB : "); print(tXrB); print("\n");
}
#endif
@ -204,8 +227,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
int smem_pipe_write = K_PIPE_MAX-1;
// Pipe slice
Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);
Tensor tXsA_p = tXsA(_,_,_,smem_pipe_read);
Tensor tXsB_p = tXsB(_,_,_,smem_pipe_read);
// Size of the register pipeline
auto K_BLOCK_MAX = size<2>(tCrA);
@ -217,8 +240,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
__syncthreads();
// Prefetch the first rmem from the first k-tile
copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{}));
copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{}));
copy(s2r_atom_a, tXsA_p(_,_,Int<0>{}), tXrA(_,_,Int<0>{}));
copy(s2r_atom_b, tXsB_p(_,_,Int<0>{}), tXrB(_,_,Int<0>{}));
}
//
@ -243,8 +266,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
if (k_block == K_BLOCK_MAX - 1)
{
// Slice the smem_pipe_read smem
tCsA_p = tCsA(_,_,_,smem_pipe_read);
tCsB_p = tCsB(_,_,_,smem_pipe_read);
tXsA_p = tXsA(_,_,_,smem_pipe_read);
tXsB_p = tXsB(_,_,_,smem_pipe_read);
// Commit the smem for smem_pipe_read
cp_async_wait<K_PIPE_MAX-2>();
@ -253,8 +276,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next));
copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next));
copy(s2r_atom_a, tXsA_p(_,_,k_block_next), tXrA(_,_,k_block_next));
copy(s2r_atom_b, tXsB_p(_,_,k_block_next), tXrB(_,_,k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
@ -268,8 +291,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// Advance the smem pipe
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read;
smem_pipe_read = (smem_pipe_read == K_PIPE_MAX-1) ? 0 : smem_pipe_read+1;
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
@ -286,6 +308,126 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
axpby(alpha, tCrC, beta, tCgC);
}
template <class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
cute::half_t const* A, int ldA,
cute::half_t const* B, int ldB,
Beta beta,
cute::half_t * C, int ldC,
cudaStream_t stream = 0)
{
assert(false && "Not implemented");
}
// Setup params for a TN HGEMM
template <class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
cute::half_t const* A, int ldA,
cute::half_t const* B, int ldB,
Beta beta,
cute::half_t * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
// Swizzles for LDSM and 128b k-major loads
auto swizzle_atom = composition(Swizzle<3,3,3>{},
Layout<Shape <_8,Shape <_8, _8>>,
Stride<_8,Stride<_1,_64>>>{});
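// The layout atom covers an (8,64) block of half_t arranged k-major in 8-element
// (16 B) groups; Swizzle<3,3,3> XORs three bits of the 16 B chunk index so the
// 128 b cp.async stores and the ldmatrix loads below avoid smem bank conflicts.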
auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
auto sC = make_layout(make_shape(bM, bN));
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
Layout<Shape< _1,_8>>{}); // Val layout 1x8 k-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
Layout<Shape< _1,_8>>{}); // Val layout 1x8 k-major
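// 16x8 threads each moving 8 half_t (16 B) gives one 128 b cp.async per thread and
// stages a (16,64) block of A or B per copy step; 16x8 = 128 threads matches
// size(mmaC) below.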
TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{},
Layout<Shape<_2,_2>>{}, // 2x2x1 MMA Atoms
Tile<_32,_32,_16>{}); // 32x32x16 Tiled MMA for LDSM
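// The 2x2 arrangement of SM80_16x8x8 atoms would natively tile 32x16x8; the
// Tile<_32,_32,_16> argument enlarges the N and K extents so each warp's fragment
// footprint lines up with what an ldmatrix.x4 load delivers from the swizzled smem.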
//Copy_Atom<DefaultCopy, half_t> s2r_atom_A;
//Copy_Atom<UniversalCopy<half_t>, half_t> s2r_atom_A;
//Copy_Atom<SM75_U32x1_LDSM_N, half_t> s2r_atom_A;
//Copy_Atom<SM75_U32x2_LDSM_N, half_t> s2r_atom_A;
Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_A;
//Copy_Atom<DefaultCopy, half_t> s2r_atom_B;
//Copy_Atom<UniversalCopy<half_t>, half_t> s2r_atom_B;
//Copy_Atom<SM75_U32x1_LDSM_N, half_t> s2r_atom_B;
//Copy_Atom<SM75_U32x2_LDSM_N, half_t> s2r_atom_B;
Copy_Atom<SM75_U32x4_LDSM_N, half_t> s2r_atom_B;
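// The commented-out atoms form a progression: DefaultCopy/UniversalCopy move elements
// individually, while SM75_U32x{1,2,4}_LDSM_N issue ldmatrix for 1, 2, or 4 8x8 half
// matrices per instruction; U32x4 moves the most data per LDSM instruction.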
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
int smem_size = int(sizeof(SharedStorage<cute::half_t, cute::half_t, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
auto kernel_fptr = gemm_device<
decltype(prob_shape), decltype(cta_tiler),
cute::half_t, decltype(dA), decltype(sA), decltype(copyA), decltype(s2r_atom_A),
cute::half_t, decltype(dB), decltype(sB), decltype(copyB), decltype(s2r_atom_B),
cute::half_t, decltype(dC), decltype(sC), decltype(mmaC),
decltype(alpha), decltype(beta)>;
// Set L1 to be SMEM only
cudaFuncSetAttribute(
kernel_fptr,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaFuncSetAttribute(
kernel_fptr,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
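// sA and sB are each 128x64x3 half_t = 48 KiB, about 96 KiB of dynamic smem in total,
// which exceeds the 48 KiB default limit; the first attribute opts the kernel in to the
// larger allocation and the carveout hint prefers smem capacity over L1.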
kernel_fptr<<<dimGrid, dimBlock, smem_size, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA, s2r_atom_A,
B, dB, sB, copyB, s2r_atom_B,
C, dC, sC, mmaC,
alpha, beta);
}
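// A minimal (hypothetical) invocation of this TN path, assuming k-major A (m x k,
// ldA = k), k-major B (n x k, ldB = k), and column-major C (ldC = m):
//
//   thrust::device_vector<cute::half_t> d_A(m*k), d_B(n*k), d_C(m*n);
//   gemm_tn(m, n, k, alpha,
//           d_A.data().get(), k,
//           d_B.data().get(), k,
//           beta,
//           d_C.data().get(), m);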
// Setup params for a NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
@@ -347,13 +489,14 @@ gemm_nt(int m, int n, int k,
print_latex(mmaC);
#endif
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
A, dA, sA, copyA, AutoVectorizingCopy{},
B, dB, sB, copyB, AutoVectorizingCopy{},
C, dC, sC, mmaC,
alpha, beta);
}
@@ -423,13 +566,14 @@ gemm_tn(int m, int n, int k,
print_latex(mmaC);
#endif
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
A, dA, sA, copyA, AutoVectorizingCopy{},
B, dB, sB, copyB, AutoVectorizingCopy{},
C, dC, sC, mmaC,
alpha, beta);
}
@@ -470,6 +614,11 @@ int main(int argc, char** argv)
return 0;
}
std::cout << "Using device 0: " << props.name
<< " (SM" << props.major * 10 + props.minor
<< ", " << props.multiProcessorCount
<< ")" << std::endl;
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
@@ -490,13 +639,13 @@ int main(int argc, char** argv)
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
using TA = cute::half_t;
using TB = cute::half_t;
using TC = cute::half_t;
using TI = cute::half_t;
TI alpha = 1.0;
TI beta = 0.0;
TI alpha = static_cast<TI>(1.0f);
TI beta = static_cast<TI>(0.0f);
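// The explicit casts keep these initializations valid when TI is cute::half_t,
// whose numeric constructors are explicit, as well as when TI is float.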
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;

@@ -20,3 +20,35 @@
* [04_epilogue_visitor](/examples/python/04_epilogue_visitor.ipynb)
Shows how to fuse elementwise activation functions to GEMMs via the Python Epilogue Visitor interface
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```