diff --git a/include/cute/arch/copy_sm75.hpp b/include/cute/arch/copy_sm75.hpp index 0e4821b2..0c34bc73 100644 --- a/include/cute/arch/copy_sm75.hpp +++ b/include/cute/arch/copy_sm75.hpp @@ -60,6 +60,12 @@ #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 #endif +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) + #define CUTE_ARCH_MOVM_SM75_ACTIVATED 1 +#else + #define CUTE_ARCH_MOVM_SM75_ACTIVATED 0 +#endif + namespace cute { @@ -183,6 +189,25 @@ struct SM75_U16x8_LDSM_T } }; +struct SM75_U32x1_MOVM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t src, + uint32_t &dst) + { +#if CUTE_ARCH_MOVM_SM75_ACTIVATED + asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "=r"(dst) + : "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use movmatrix without CUTE_ARCH_MOVM_SM75_ACTIVATED."); +#endif + } +}; + // // Legacy LDSM interfaces that aren't very useful // diff --git a/include/cute/atom/copy_traits_sm75.hpp b/include/cute/atom/copy_traits_sm75.hpp index 416938b1..fdd00081 100644 --- a/include/cute/atom/copy_traits_sm75.hpp +++ b/include/cute/atom/copy_traits_sm75.hpp @@ -140,4 +140,21 @@ struct Copy_Traits using RefLayout = DstLayout; }; +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_32, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + } // end namespace cute diff --git a/test/unit/cute/turing/CMakeLists.txt b/test/unit/cute/turing/CMakeLists.txt index f6c6f64b..005feafd 100644 --- a/test/unit/cute/turing/CMakeLists.txt +++ b/test/unit/cute/turing/CMakeLists.txt @@ -29,4 +29,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_turing cooperative_gemm.cu + movm.cu ) diff --git a/test/unit/cute/turing/movm.cu b/test/unit/cute/turing/movm.cu new file mode 100644 index 00000000..932c3e7f --- /dev/null +++ b/test/unit/cute/turing/movm.cu @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +#include + +#include + +using namespace cute; + +__global__ void +movm_test_device(uint16_t* g_in, uint16_t* g_out) +{ + int tid = threadIdx.x; + + // load input gmem -> register + uint32_t reg = reinterpret_cast(g_in)[tid]; + + // do two movmatrix calls (transpose twice => identity) + uint32_t tmp = 0; + uint32_t dst = 0; + SM75_U32x1_MOVM_T::copy(reg, tmp); + SM75_U32x1_MOVM_T::copy(tmp, dst); + + // store result + reinterpret_cast(g_out)[tid] = dst; +} + +template +__global__ void +movm_test_device_cute(uint16_t* g_in, uint16_t* g_out, + TiledCopy tiled_copy, GmemLayout gmem_layout) +{ + using namespace cute; + + auto t_g_in = make_tensor(make_gmem_ptr(reinterpret_cast(g_in)), gmem_layout); + auto t_g_out = make_tensor(make_gmem_ptr(reinterpret_cast(g_out)), gmem_layout); + + int tid = threadIdx.x; + + auto thr_copy = tiled_copy.get_thread_slice(tid); + + auto tXgS = thr_copy.partition_S(t_g_in); + auto tXgD = thr_copy.partition_D(t_g_out); + + // Register tensors for intermediate and output data + auto tXrS = make_tensor(shape(tXgS)); // src + auto tXrT = make_tensor(shape(tXgS)); // tmp + auto tXrD = make_tensor(shape(tXgD)); // dst + clear(tXrS); + clear(tXrT); + clear(tXrD); + + // Load gmem -> registers + for (int i = 0; i < size(tXrS); ++i) { + tXrS(i) = tXgS(i); + } + + // do two movmatrix calls for identity + copy(tiled_copy, tXrS, tXrT); + copy(tiled_copy, tXrT, tXrD); + + // Store registers -> gmem + for (int i = 0; i < size(tXrD); ++i) { + tXgD(i) = tXrD(i); + } +} + +TEST(SM75_CuTe_Turing, Movm) +{ + constexpr int count = 1024; + + thrust::host_vector h_in(count); + for (int i = 0; i < count; ++i) { + h_in[i] = uint16_t(i); + } + thrust::device_vector d_in = h_in; + + // + // Direct MOVM + // + + { + thrust::device_vector d_out(count); + movm_test_device<<<1, 32>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data())); + thrust::host_vector h_out = d_out; + // applied movmatrix twice so result should equal input + for (int i = 0; i < 64; ++i) { + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("MOVM movm_test_device SUCCESS\n"); + } + + // + // CuTe MOVM + // + + { + thrust::device_vector d_out(count); + + auto gmem_layout = Layout, + Stride< _1,_32>>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout>{}, + Layout>{}); + + movm_test_device_cute<<<1, int(size(tiled_copy))>>>( + thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + tiled_copy, + gmem_layout); + thrust::host_vector h_out = d_out; + for (int i = 0; i < (size(gmem_layout)*2); ++i) { + EXPECT_EQ(h_out[i], h_in[i]); + } + CUTLASS_TRACE_HOST("CuTe MOVM SUCCESS\n"); + } + + CUTLASS_TRACE_HOST("PASS"); +}