190
include/cute/arch/cluster_sm90.hpp
Normal file
190
include/cute/arch/cluster_sm90.hpp
Normal file
@ -0,0 +1,190 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
|
||||
((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))))
|
||||
# define CUTE_ARCH_CLUSTER_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute {
|
||||
|
||||
// Non-blocking arrive on the cluster-wide barrier, relaxed variant
// (barrier.cluster.arrive.relaxed — no memory-ordering guarantee implied by the mnemonic).
CUTE_DEVICE void cluster_arrive_relaxed()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : );
#else
  asm volatile ("brkpt;\n" ::);  // Trap: called without SM90 cluster support compiled in.
#endif
}
|
||||
|
||||
// Non-blocking arrive on the cluster-wide barrier (default ordering).
CUTE_DEVICE void cluster_arrive()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  asm volatile("barrier.cluster.arrive.aligned;\n" : : );
#else
  asm volatile ("brkpt;\n" ::);  // Trap: called without SM90 cluster support compiled in.
#endif
}
|
||||
|
||||
// Blocks until all blocks in the cluster have arrived on the cluster barrier.
CUTE_DEVICE void cluster_wait()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  asm volatile("barrier.cluster.wait.aligned;\n" : : );
#else
  asm volatile ("brkpt;\n" ::);  // Trap: called without SM90 cluster support compiled in.
#endif
}
|
||||
|
||||
// Full cluster barrier: arrive, then wait for every block in the cluster.
CUTE_DEVICE void cluster_sync()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  cluster_arrive();
  cluster_wait();
#else
  asm volatile ("brkpt;\n" ::);  // Trap: called without SM90 cluster support compiled in.
#endif
}
|
||||
|
||||
// Returns the dim3 grid size in terms of number of clusters.
// NOTE(review): the fallback returns gridDim, which is measured in *blocks*,
// not clusters — correct only because a non-cluster launch has 1 block per cluster.
CUTE_DEVICE dim3 cluster_grid_dims()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  uint32_t x, y, z;
  // Read the %nclusterid special registers (grid extent in clusters).
  asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
  asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
  asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
  return {x, y, z};
#else
  return gridDim;
#endif
}
|
||||
|
||||
// Returns the dim3 cluster rank in the grid.
// Fallback returns blockIdx (each block is its own cluster when clusters are unavailable).
CUTE_DEVICE dim3 cluster_id_in_grid()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  uint32_t x, y, z;
  // Read the %clusterid special registers (this cluster's 3D rank).
  asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
  asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
  asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
  return {x, y, z};
#else
  return blockIdx;
#endif
}
|
||||
|
||||
// Returns the relative dim3 block rank local to the cluster.
// Fallback is {0,0,0}: a lone block is always rank 0 of its (size-1) cluster.
CUTE_DEVICE dim3 block_id_in_cluster()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  uint32_t x, y, z;
  // Read the %cluster_ctaid special registers (this CTA's 3D rank within its cluster).
  asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
  asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
  asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
  return {x, y, z};
#else
  return {0,0,0};
#endif
}
|
||||
|
||||
// Returns the dim3 cluster shape (number of CTAs per cluster in each dimension).
// Fallback is {1,1,1}: without cluster support every cluster is a single block.
CUTE_DEVICE dim3 cluster_shape()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
  uint32_t x, y, z;
  // Read the %cluster_nctaid special registers (cluster extent in CTAs).
  asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
  asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
  asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
  return {x, y, z};
#else
  return {1,1,1};
#endif
}
|
||||
|
||||
// Get 1D ctaid in a cluster.
|
||||
CUTLASS_DEVICE uint32_t block_rank_in_cluster()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t rank;
|
||||
asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
|
||||
return rank;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Set the destination block-ID in cluster for a given SMEM Address
|
||||
CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank)
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t result;
|
||||
asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n"
|
||||
: "=r"(result)
|
||||
: "r"(smemAddr), "r"(rank));
|
||||
return result;
|
||||
#else
|
||||
return smemAddr;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false.
CUTE_HOST_DEVICE uint32_t elect_one_sync()
{
#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED)
  uint32_t pred = 0;    // becomes 1 only in the elected lane
  uint32_t laneid = 0;  // receives the elected lane id (written but not returned)
  asm volatile(
    "{\n"
    ".reg .b32 %rx;\n"
    ".reg .pred %px;\n"
    " elect.sync %rx|%px, %2;\n"   // %2 = 0xFFFFFFFF: full-warp membership mask
    "@%px mov.s32 %1, 1;\n"        // predicated: only the elected lane sets pred
    " mov.s32 %0, %rx;\n"
    "}\n"
    : "+r"(laneid), "+r"(pred)
    : "r"(0xFFFFFFFF));
  return pred;
#elif defined(__CUDA_ARCH__)
  // Device fallback without elect.sync: deterministically pick lane 0 of each warp.
  return (threadIdx.x % 32) == 0;
#else
  // Host fallback: every caller counts as elected.
  return true;
#endif
}
|
||||
|
||||
} // end namespace cute
|
||||
71
include/cute/arch/copy.hpp
Normal file
71
include/cute/arch/copy.hpp
Normal file
@ -0,0 +1,71 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/util.hpp>
|
||||
#include <cute/numeric/uint128.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Direct Copy for any type
|
||||
//
|
||||
|
||||
template <class S, class D = S>
|
||||
struct UniversalCopy
|
||||
{
|
||||
using SRegisters = S[1];
|
||||
using DRegisters = D[1];
|
||||
|
||||
CUTE_HOST_DEVICE static constexpr void
|
||||
copy(S const& src,
|
||||
D & dst)
|
||||
{
|
||||
dst = src;
|
||||
}
|
||||
};
|
||||
|
||||
//
// Placeholder for the copy algorithm's default, auto-vectorizing behavior
//

// Tag type only: carries no copy() implementation. The 128-bit register
// aliases advertise the widest access the copy algorithm may vectorize to.
struct DefaultCopy
{
  using SRegisters = uint128_t[1];
  using DRegisters = uint128_t[1];
};

// Alias documenting the intent of the default policy.
using AutoVectorizingCopy = DefaultCopy;
|
||||
|
||||
} // end namespace cute
|
||||
215
include/cute/arch/copy_sm75.hpp
Normal file
215
include/cute/arch/copy_sm75.hpp
Normal file
@ -0,0 +1,215 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
|
||||
# define CUTE_ARCH_LDSM_SM75_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// ldmatrix x1 (non-transposed): loads one 8x8 matrix of b16 elements from
// shared memory into one 32-bit register per thread.
struct SM75_U32x1_LDSM_N
{
  using SRegisters = uint128_t[1];  // one 128-bit shared-memory operand
  using DRegisters = uint32_t[1];   // one 32-bit destination register

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // ldmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
                  : "=r"(dst)
                  : "r"(smem_int_ptr));
#else
    CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
|
||||
|
||||
// ldmatrix x2 (non-transposed): loads two 8x8 b16 matrices from shared memory
// into two 32-bit registers per thread.
struct SM75_U32x2_LDSM_N
{
  using SRegisters = uint128_t[1];  // one 128-bit shared-memory operand
  using DRegisters = uint32_t[2];   // two 32-bit destination registers

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // ldmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
                  : "=r"(dst0), "=r"(dst1)
                  : "r"(smem_int_ptr));
#else
    CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
|
||||
|
||||
// ldmatrix x4 (non-transposed): loads four 8x8 b16 matrices from shared memory
// into four 32-bit registers per thread.
struct SM75_U32x4_LDSM_N
{
  using SRegisters = uint128_t[1];  // one 128-bit shared-memory operand
  using DRegisters = uint32_t[4];   // four 32-bit destination registers

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // ldmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                  : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
                  : "r"(smem_int_ptr));
#else
    CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
|
||||
|
||||
// ldmatrix x1 transposed: loads one 8x8 b16 matrix from shared memory with a
// transposed thread<->element mapping, one 32-bit register per thread.
struct SM75_U16x2_LDSM_T
{
  using SRegisters = uint128_t[1];  // one 128-bit shared-memory operand
  using DRegisters = uint32_t[1];   // one 32-bit destination register

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // ldmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
                  : "=r"(dst)
                  : "r"(smem_int_ptr));
#else
    CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
|
||||
|
||||
// ldmatrix x2 transposed: loads two 8x8 b16 matrices from shared memory with a
// transposed thread<->element mapping, two 32-bit registers per thread.
struct SM75_U16x4_LDSM_T
{
  using SRegisters = uint128_t[1];  // one 128-bit shared-memory operand
  using DRegisters = uint32_t[2];   // two 32-bit destination registers

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // ldmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
                  : "=r"(dst0), "=r"(dst1)
                  : "r"(smem_int_ptr));
#else
    CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
|
||||
|
||||
// ldmatrix x4 transposed: loads four 8x8 b16 matrices from shared memory with a
// transposed thread<->element mapping, four 32-bit registers per thread.
struct SM75_U16x8_LDSM_T
{
  using SRegisters = uint128_t[1];  // one 128-bit shared-memory operand
  using DRegisters = uint32_t[4];   // four 32-bit destination registers

  CUTE_HOST_DEVICE static void
  copy(uint128_t const& smem_src,
       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
  {
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
    // ldmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
    asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                  : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
                  : "r"(smem_int_ptr));
#else
    CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
#endif
  }
};
|
||||
|
||||
//
|
||||
// Legacy LDSM interfaces that aren't very useful
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_ldsm(uint128_t const* const smem_ptr,
|
||||
T* rmem_ptr)
|
||||
{
|
||||
uint32_t* reg_ptr = reinterpret_cast<uint32_t*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_ldsm_trans(uint128_t const* const smem_ptr,
|
||||
T* rmem_ptr)
|
||||
{
|
||||
uint32_t* reg_ptr = reinterpret_cast<uint32_t*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
138
include/cute/arch/copy_sm80.hpp
Normal file
138
include/cute/arch/copy_sm80.hpp
Normal file
@ -0,0 +1,138 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/// Copy via cp.async with caching at all levels
/// Asynchronously copies sizeof(TS) bytes (4, 8, or 16) from global to shared
/// memory using cp.async.ca; completion is ordered via cp_async_fence/wait.
template <class TS, class TD = TS>
struct SM80_CP_ASYNC_CACHEALWAYS
{
  using SRegisters = TS[1];  // one global-memory source operand
  using DRegisters = TD[1];  // one shared-memory destination operand

  static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
  static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");

  CUTE_HOST_DEVICE static void
  copy(TS const& gmem_src,
       TD      & smem_dst)
  {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    TS const* gmem_ptr = &gmem_src;
    // cp.async takes a 32-bit shared-space destination address.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
                 :: "r"(smem_int_ptr),
                    "l"(gmem_ptr),
                    "n"(sizeof(TS)));  // "n": copy size must be a compile-time constant
#else
    CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
#endif
  }
};
|
||||
|
||||
/// Copy via cp.async with caching at global level
/// Same contract as SM80_CP_ASYNC_CACHEALWAYS, but uses cp.async.cg
/// (cache-global policy, bypassing lower cache levels per the .cg mnemonic).
template <class TS, class TD = TS>
struct SM80_CP_ASYNC_CACHEGLOBAL
{
  using SRegisters = TS[1];  // one global-memory source operand
  using DRegisters = TD[1];  // one shared-memory destination operand

  static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
  static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");

  CUTE_HOST_DEVICE static void
  copy(TS const& gmem_src,
       TD      & smem_dst)
  {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    TS const* gmem_ptr = &gmem_src;
    // cp.async takes a 32-bit shared-space destination address.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
                 :: "r"(smem_int_ptr),
                    "l"(gmem_ptr),
                    "n"(sizeof(TS)));  // "n": copy size must be a compile-time constant
#else
    CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
/// No-op when cp.async support is not compiled in (nothing to commit).
CUTE_HOST_DEVICE
void
cp_async_fence()
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
  asm volatile("cp.async.commit_group;\n" ::);
#endif
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Blocks until all but N previous cp.async.commit_group operations have committed.
/// N == 0 uses cp.async.wait_all (wait for everything); otherwise cp.async.wait_group N.
/// No-op when cp.async support is not compiled in.
template <int N>
CUTE_HOST_DEVICE
void
cp_async_wait()
{
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::);
  } else {
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));  // "n": N is an immediate
  }
#endif
}
|
||||
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cp_async_wait(Int<N>)
|
||||
{
|
||||
return cp_async_wait<N>();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
225
include/cute/arch/copy_sm90.hpp
Normal file
225
include/cute/arch/copy_sm90.hpp
Normal file
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
# define CUTE_ARCH_STSM_SM90_ENABLED
|
||||
# define CUTE_ARCH_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// stmatrix x1 (non-transposed): stores one 8x8 b16 matrix from one 32-bit
// register per thread into shared memory.
struct SM90_U32x1_STSM_N
{
  using SRegisters = uint32_t[1];   // one 32-bit source register
  using DRegisters = uint128_t[1];  // one 128-bit shared-memory destination

  CUTE_HOST_DEVICE static void
  copy(uint32_t const& src,
       uint128_t & smem_dst)
  {
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
    // stmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
                  :: "r"(smem_int_ptr),
                     "r"(src));
#else
    CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// stmatrix x2 (non-transposed): stores two 8x8 b16 matrices from two 32-bit
// registers per thread into shared memory.
struct SM90_U32x2_STSM_N
{
  using SRegisters = uint32_t[2];   // two 32-bit source registers
  using DRegisters = uint128_t[1];  // one 128-bit shared-memory destination

  CUTE_HOST_DEVICE static void
  copy(uint32_t const& src0, uint32_t const& src1,
       uint128_t& smem_dst)
  {
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
    // stmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
                  :: "r"(smem_int_ptr),
                     "r"(src0), "r"(src1));
#else
    CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// stmatrix x4 (non-transposed): stores four 8x8 b16 matrices from four 32-bit
// registers per thread into shared memory.
struct SM90_U32x4_STSM_N
{
  using SRegisters = uint32_t[4];   // four 32-bit source registers
  using DRegisters = uint128_t[1];  // one 128-bit shared-memory destination

  CUTE_HOST_DEVICE static void
  copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3,
       uint128_t& smem_dst)
  {
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
    // stmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                  :: "r"(smem_int_ptr),
                     "r"(src0), "r"(src1), "r"(src2), "r"(src3));
#else
    CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// stmatrix x1 transposed: stores one 8x8 b16 matrix with a transposed
// thread<->element mapping, one 32-bit register per thread.
struct SM90_U16x2_STSM_T
{
  using SRegisters = uint32_t[1];   // one 32-bit source register
  using DRegisters = uint128_t[1];  // one 128-bit shared-memory destination

  CUTE_HOST_DEVICE static void
  copy(uint32_t const& src,
       uint128_t& smem_dst)
  {
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
    // stmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n"
                  :: "r"(smem_int_ptr),
                     "r"(src));
#else
    CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// stmatrix x2 transposed: stores two 8x8 b16 matrices with a transposed
// thread<->element mapping, two 32-bit registers per thread.
struct SM90_U16x4_STSM_T
{
  using SRegisters = uint32_t[2];   // two 32-bit source registers
  using DRegisters = uint128_t[1];  // one 128-bit shared-memory destination

  CUTE_HOST_DEVICE static void
  copy(uint32_t const& src0, uint32_t const& src1,
       uint128_t& smem_dst)
  {
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
    // stmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n"
                  :: "r"(smem_int_ptr),
                     "r"(src0), "r"(src1));
#else
    CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// stmatrix x4 transposed: stores four 8x8 b16 matrices with a transposed
// thread<->element mapping, four 32-bit registers per thread.
struct SM90_U16x8_STSM_T
{
  using SRegisters = uint32_t[4];   // four 32-bit source registers
  using DRegisters = uint128_t[1];  // one 128-bit shared-memory destination

  CUTE_HOST_DEVICE static void
  copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3,
       uint128_t& smem_dst)
  {
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
    // stmatrix takes a 32-bit shared-space address, not a generic pointer.
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
    asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                  :: "r"(smem_int_ptr),
                     "r"(src0), "r"(src1), "r"(src2), "r"(src3));
#else
    CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
//
|
||||
// Legacy STSM interfaces that aren't very useful
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_stsm(T const* const rmem_ptr,
|
||||
uint128_t* const smem_ptr)
|
||||
{
|
||||
uint32_t const* reg_ptr = reinterpret_cast<uint32_t const*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_stsm_trans(T const* const rmem_ptr,
|
||||
uint128_t* const smem_ptr)
|
||||
{
|
||||
uint32_t const* reg_ptr = reinterpret_cast<uint32_t const*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
194
include/cute/arch/copy_sm90_desc.hpp
Normal file
194
include/cute/arch/copy_sm90_desc.hpp
Normal file
@ -0,0 +1,194 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
#include <cute/arch/copy_sm90.hpp>
|
||||
|
||||
#include <cute/container/alignment.hpp>
|
||||
#include <cute/container/bit_field.hpp>
|
||||
#include <cute/numeric/int.hpp> // to_Format<[u]intX>
|
||||
#include <cute/numeric/half.hpp> // to_Format<half_t>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Barriers are 64-bit of user-managed information used in broadly two types syncronization patterns
|
||||
/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels)
|
||||
/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction)
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Initialize barrier present in shared memory
|
||||
// Initialize a shared-memory mbarrier with the expected arrival count.
// NOTE(review): silently a no-op when CUTE_ARCH_TMA_SM90_ENABLED is not
// defined (there is no #else branch) -- callers must not rely on the barrier
// being initialized on other architectures.
CUTE_HOST_DEVICE
void
initialize_barrier(uint64_t& smem_barrier,                 // 64 bits user-manged barrier in smem
                   int thread_count = 1)                   // Thread count expected to arrive/wait on this barrier
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  // mbarrier PTX operates on a 32-bit shared-state-space address.
  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
  asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n"
    :: "r"(smem_int_ptr),
       "r"(thread_count));
#endif
}
|
||||
|
||||
// Set the number of bytes transfered per transaction
|
||||
// Arrive on the mbarrier and set the expected transaction byte count for the
// current phase (used to track TMA transaction completion).
// No-op when CUTE_ARCH_TMA_SM90_ENABLED is not defined.
CUTE_HOST_DEVICE
void
set_barrier_transaction_bytes(uint64_t& smem_barrier,      // 64 bits user-manged barrier in smem
                              uint32_t bytes)              // Number of bytes transfered by per TMA transaction
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
  // "_" discards the returned arrival token; only the expect_tx side effect
  // is wanted here.
  asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n"
    :: "r"(smem_int_ptr),
       "r"(bytes));
#endif
}
|
||||
|
||||
// Barrier wait
|
||||
// Block until the mbarrier's phase flips past `phase_bit`.
// Implemented as a PTX-level spin: mbarrier.try_wait.parity may time out
// without the phase having completed, so the loop retries until the
// predicate P1 reports completion. No-op without CUTE_ARCH_TMA_SM90_ENABLED.
CUTE_HOST_DEVICE
void
wait_barrier(uint64_t& smem_barrier,                       // 64 bits user-manged barrier in smem
             int phase_bit)                                // Current phase bit the barrier waiting to flip
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
  asm volatile(
    "{\n"
    ".reg .pred                P1;\n"
    "LAB_WAIT:\n"
    "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
    "@P1                       bra.uni DONE;\n"
    "bra.uni                   LAB_WAIT;\n"
    "DONE:\n"
    "}\n"
    :: "r"(smem_int_ptr),
       "r"(phase_bit));

#endif
}
|
||||
|
||||
// Barrier arrive
|
||||
// Signal one arrival on the mbarrier. The returned arrival token is stored
// into a scratch PTX register and discarded. No-op without
// CUTE_ARCH_TMA_SM90_ENABLED.
CUTE_HOST_DEVICE
void
arrive_barrier(uint64_t& smem_barrier)                     // 64 bits user-manged barrier in smem
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
  asm volatile(
    "{\n"
    ".reg .b64 state; \n"
    "mbarrier.arrive.shared.b64   state, [%0];\n"
    "}\n"
    :: "r"(smem_int_ptr));
#endif
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// TMA Descriptor and utilities
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace TMA {
|
||||
|
||||
// Shared-memory swizzle width for TMA transfers. The numeric values are
// chosen to line up with the CU_TENSOR_MAP_SWIZZLE_* mapping performed by
// to_CUtensorMapSwizzle below -- do not reorder.
enum class SmemSwizzleBits : uint8_t {
  DISABLE = 0,    // no swizzle
  B32 = 1,        // 32-byte swizzle
  B64 = 2,        // 64-byte swizzle
  B128 = 3,       // 128-byte swizzle
};
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
|
||||
// Map a C++ element type to the CUDA driver's CUtensorMapDataType enum used
// when encoding a TMA descriptor. Note int8_t and uint8_t both map to UINT8
// (the descriptor only encodes element width for 8-bit types).
// Unknown T fails at compile time via the dependent static_assert.
template <class T>
inline CUtensorMapDataType to_CUtensorMapDataType() {
  if constexpr (std::is_same<T,     int8_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;    } else
  if constexpr (std::is_same<T,    uint8_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;    } else
  if constexpr (std::is_same<T,   uint16_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16;   } else
  if constexpr (std::is_same<T,   uint32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32;   } else
  if constexpr (std::is_same<T,   uint64_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64;   } else
  if constexpr (std::is_same<T,    int32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32;    } else
  if constexpr (std::is_same<T,    int64_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64;    } else
  if constexpr (std::is_same<T,     half_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;  } else
  if constexpr (std::is_same<T,      float>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;  } else
  if constexpr (std::is_same<T,     double>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;  } else
  if constexpr (std::is_same<T, bfloat16_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
  if constexpr (std::is_same<T, tfloat32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
  // sizeof(T) < 0 is always false but dependent, so this only fires when the
  // fall-through branch is actually instantiated.
  { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
}
|
||||
|
||||
// Map SmemSwizzleBits to the driver's CUtensorMapSwizzle enum.
// The `default` label asserts in debug builds; with NDEBUG the assert is a
// no-op and control falls through to the DISABLE case, so an out-of-range
// value degrades to SWIZZLE_NONE rather than undefined behavior.
inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) {
  switch (t) {
    default:                    assert(false && "Unknown SmemSwizzleBits!");
    case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE;
    case SmemSwizzleBits::B32:     return CU_TENSOR_MAP_SWIZZLE_32B;
    case SmemSwizzleBits::B64:     return CU_TENSOR_MAP_SWIZZLE_64B;
    case SmemSwizzleBits::B128:    return CU_TENSOR_MAP_SWIZZLE_128B;
  }
}
|
||||
|
||||
#endif // (__CUDACC_VER_MAJOR__ >= 12)
|
||||
} // end namespace TMA
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
using TmaDescriptor = CUtensorMap;
|
||||
#else
|
||||
using TmaDescriptor = struct { char bytes[128]; };
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Initiates a TensorMap Prefetch
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Prefetch a TMA descriptor into the cache ahead of its first use.
// Unlike the barrier helpers above, this one hard-errors at runtime when the
// architecture macro is absent.
CUTE_HOST_DEVICE
void
prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
  // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param)
  asm volatile (
    "prefetch.tensormap [%0];"
    :
    : "l"(gmem_int_desc)
    : "memory");
#else
  CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
552
include/cute/arch/copy_sm90_tma.hpp
Normal file
552
include/cute/arch/copy_sm90_tma.hpp
Normal file
@ -0,0 +1,552 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
#include <cute/arch/copy_sm90.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// TMA load of a 1-D tile: global memory (described by the TMA descriptor at
// desc_ptr) -> shared memory at smem_ptr, starting at coordinate crd0.
// Completion is signaled on smem_mbar via mbarrier complete_tx::bytes, so the
// barrier's expected transaction bytes must have been set beforehand.
struct SM90_TMA_LOAD_1D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    // Operand order: [smem dst], [descriptor, {coords}], [mbarrier].
    asm volatile (
      "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1, {%3}], [%2];"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA load of a 2-D tile (see SM90_TMA_LOAD_1D for the operand contract);
// crd0..crd1 are the starting coordinates in the descriptor's tensor.
struct SM90_TMA_LOAD_2D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1, {%3, %4}], [%2];"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA load of a 3-D tile (see SM90_TMA_LOAD_1D for the operand contract).
struct SM90_TMA_LOAD_3D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1, {%3, %4, %5}], [%2];"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "r"(crd2)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA load of a 4-D tile (see SM90_TMA_LOAD_1D for the operand contract).
struct SM90_TMA_LOAD_4D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1, {%3, %4, %5, %6}], [%2];"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA load of a 5-D tile (see SM90_TMA_LOAD_1D for the operand contract).
struct SM90_TMA_LOAD_5D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Rank-generic entry point for TMA loads: overload resolution on the number
// of coordinates dispatches to the fixed-rank SM90_TMA_LOAD_nD atoms above.
struct SM90_TMA_LOAD
{
  // 1-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0)
  {
    return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0);
  }
  // 2-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
    return SM90_TMA_LOAD_2D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1);
  }
  // 3-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
    return SM90_TMA_LOAD_3D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2);
  }
  // 4-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
    return SM90_TMA_LOAD_4D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3);
  }
  // 5-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
    return SM90_TMA_LOAD_5D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3, crd4);
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Multicast variant of SM90_TMA_LOAD_1D: the loaded tile is additionally
// broadcast to the CTAs of the cluster selected by multicast_mask (16-bit
// mask; presumably one bit per CTA in the cluster -- confirm against the
// PTX ISA for cp.async.bulk.tensor .multicast::cluster).
struct SM90_TMA_LOAD_1D_MULTICAST
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    // Operand order: [smem dst], [descriptor, {coords}], [mbarrier], mask.
    asm volatile (
      "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%4}], [%2], %3;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "h"(multicast_mask),
        "r"(crd0)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Multicast variant of SM90_TMA_LOAD_2D (see SM90_TMA_LOAD_1D_MULTICAST).
struct SM90_TMA_LOAD_2D_MULTICAST
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%4, %5}], [%2], %3;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "h"(multicast_mask),
        "r"(crd0), "r"(crd1)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Multicast variant of SM90_TMA_LOAD_3D (see SM90_TMA_LOAD_1D_MULTICAST).
struct SM90_TMA_LOAD_3D_MULTICAST
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%4, %5, %6}], [%2], %3;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "h"(multicast_mask),
        "r"(crd0), "r"(crd1), "r"(crd2)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Multicast variant of SM90_TMA_LOAD_4D (see SM90_TMA_LOAD_1D_MULTICAST).
struct SM90_TMA_LOAD_4D_MULTICAST
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "h"(multicast_mask),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Multicast variant of SM90_TMA_LOAD_5D (see SM90_TMA_LOAD_1D_MULTICAST).
struct SM90_TMA_LOAD_5D_MULTICAST
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
      " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "h"(multicast_mask),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Rank-generic entry point for multicast TMA loads: overload resolution on
// the number of coordinates dispatches to the SM90_TMA_LOAD_nD_MULTICAST
// atoms above.
struct SM90_TMA_LOAD_MULTICAST
{
  // 1-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0)
  {
    return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0);
  }
  // 2-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
    return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1);
  }
  // 3-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
    return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2);
  }
  // 4-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
    return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3);
  }
  // 5-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
    return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4);
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_STORE : Initiates a TMA copy from shared memory to global memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// TMA store of a 1-D tile: shared memory at smem_ptr -> global memory
// described by the TMA descriptor, starting at coordinate crd0. Completion is
// tracked via the bulk_group mechanism -- pair with tma_store_arrive() /
// tma_store_wait<N>() below rather than an mbarrier.
struct SM90_TMA_STORE_1D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    // Operand order: [descriptor, {coords}], [smem src].
    asm volatile (
      "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr),
        "r"(crd0)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA store of a 2-D tile (see SM90_TMA_STORE_1D for the operand contract).
struct SM90_TMA_STORE_2D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr),
        "r"(crd0), "r"(crd1)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA store of a 3-D tile (see SM90_TMA_STORE_1D for the operand contract).
struct SM90_TMA_STORE_3D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr),
        "r"(crd0), "r"(crd1), "r"(crd2)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA store of a 4-D tile (see SM90_TMA_STORE_1D for the operand contract).
struct SM90_TMA_STORE_4D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// TMA store of a 5-D tile (see SM90_TMA_STORE_1D for the operand contract).
struct SM90_TMA_STORE_5D
{
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
      : "memory");
#else
    CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};
|
||||
|
||||
// Rank-generic entry point for TMA stores: overload resolution on the number
// of coordinates dispatches to the fixed-rank SM90_TMA_STORE_nD atoms above.
struct SM90_TMA_STORE
{
  // 1-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0)
  {
    return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0);
  }
  // 2-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
    return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1);
  }
  // 3-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
    return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2);
  }
  // 4-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
    return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3);
  }
  // 5-D
  CUTE_HOST_DEVICE static void
  copy(void const* const desc_ptr,
       void const* const smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
    return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4);
  }
};
|
||||
|
||||
// Indicate arrival of warp issuing TMA_STORE
|
||||
// Commit all previously-issued bulk async stores of this thread into a
// bulk-group (cp.async.bulk.commit_group); pair with tma_store_wait<N>().
CUTE_HOST_DEVICE static void
tma_store_arrive() {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  asm volatile("cp.async.bulk.commit_group;");
#else
  CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
|
||||
|
||||
// Wait on prior N (Count) TMA_STORE instructions to complete
|
||||
// Wait on previously committed bulk-groups until at most Count of them are
// still in flight (cp.async.bulk.wait_group.read). Count must be a
// compile-time constant: it is emitted via the "n" immediate constraint.
template<int Count>
CUTE_HOST_DEVICE static void
tma_store_wait() {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
  asm volatile(
    "cp.async.bulk.wait_group.read %0;"
    :
    : "n"(Count)
    : "memory");
#else
  CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
64
include/cute/arch/mma.hpp
Normal file
64
include/cute/arch/mma.hpp
Normal file
@ -0,0 +1,64 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/util.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
// Direct FMA for any type
//

// Generic MMA "atom" that computes d = a * b + c for arbitrary element
// types by forwarding to a free-function fma(d, a, b, c) found via ADL
// (falling back to cute::fma). Operand types B and C default so that a
// single type argument yields a homogeneous FMA.
template <class D, class A = D, class B = A, class C = D>
struct UniversalFMA
{
  // Register-fragment descriptors consumed by the MMA machinery:
  // one scalar of each operand type per thread.
  using DRegisters = D[1];
  using ARegisters = A[1];
  using BRegisters = B[1];
  using CRegisters = C[1];

  CUTE_HOST_DEVICE static constexpr void
  fma(D      & d,
      A const& a,
      B const& b,
      C const& c)
  {
    // Forward to an ADL/cute free function for these types
    using cute::fma;
    fma(d, a, b, c);
  }
};
|
||||
|
||||
} // end namespace cute
|
||||
87
include/cute/arch/mma_sm61.hpp
Normal file
87
include/cute/arch/mma_sm61.hpp
Normal file
@ -0,0 +1,87 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
|
||||
# define CUTE_ARCH_MMA_SM61_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// SM61 four-way byte dot product with 32-bit accumulate:
// d = c + sum_{i=0..3} a.byte[i] * b.byte[i], bytes treated as signed
// (per the .s32.s32 qualifiers on the dp4a instruction below).
struct SM61_DP4A
{
  using DRegisters = int32_t[1];
  using ARegisters = uint32_t[1];   // four packed 8-bit lanes in one 32-bit register
  using BRegisters = uint32_t[1];   // four packed 8-bit lanes in one 32-bit register
  using CRegisters = int32_t[1];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c)
  {
#if defined(CUTE_ARCH_MMA_SM61_ENABLED)
    asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
                 : "=r"(d)
                 : "r"(a), "r"(b), "r"(c));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED");
#endif
  }
};
|
||||
|
||||
struct SM61_DP2A
|
||||
{
|
||||
using DRegisters = int32_t[1];
|
||||
using ARegisters = uint32_t[1];
|
||||
using BRegisters = uint32_t[1];
|
||||
using CRegisters = int32_t[1];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM61_ENABLED)
|
||||
asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"
|
||||
: "=r"(d)
|
||||
: "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cute
|
||||
329
include/cute/arch/mma_sm70.hpp
Normal file
329
include/cute/arch/mma_sm70.hpp
Normal file
@ -0,0 +1,329 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
|
||||
# define CUTE_ARCH_MMA_SM70_SUPPORTED
|
||||
# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
|
||||
# define CUTE_ARCH_MMA_SM70_ENABLED
|
||||
# endif
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// SM70 MMA 884 F16F16F16
|
||||
//
|
||||
|
||||
// SM70 (Volta) m8n8k4 MMA, f16 accumulate, A row-major x B col-major (TN).
// Each uint32_t register carries a packed pair of f16 values, per the PTX
// f16 MMA register-fragment convention.
struct SM70_8x8x4_F16F16F16F16_TN
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[4];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16"
                 "{%0, %1, %2, %3},"
                 "{%4, %5},"
                 "{%6, %7},"
                 "{%8, %9, %10, %11};\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 :  "r"(a0),  "r"(a1),
                    "r"(b0),  "r"(b1),
                    "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SM70 (Volta) m8n8k4 MMA, f16 accumulate, A col-major x B row-major (NT).
// Each uint32_t register carries a packed pair of f16 values, per the PTX
// f16 MMA register-fragment convention.
struct SM70_8x8x4_F16F16F16F16_NT
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[4];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16"
                 "{%0, %1, %2, %3},"
                 "{%4, %5},"
                 "{%6, %7},"
                 "{%8, %9, %10, %11};\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 :  "r"(a0),  "r"(a1),
                    "r"(b0),  "r"(b1),
                    "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SM70 (Volta) m8n8k4 MMA, f16 accumulate, A col-major x B col-major (NN).
// Each uint32_t register carries a packed pair of f16 values, per the PTX
// f16 MMA register-fragment convention.
struct SM70_8x8x4_F16F16F16F16_NN
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[4];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16"
                 "{%0, %1, %2, %3},"
                 "{%4, %5},"
                 "{%6, %7},"
                 "{%8, %9, %10, %11};\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 :  "r"(a0),  "r"(a1),
                    "r"(b0),  "r"(b1),
                    "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SM70 (Volta) m8n8k4 MMA, f16 accumulate, A row-major x B row-major (TT).
// Each uint32_t register carries a packed pair of f16 values, per the PTX
// f16 MMA register-fragment convention.
struct SM70_8x8x4_F16F16F16F16_TT
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[4];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16"
                 "{%0, %1, %2, %3},"
                 "{%4, %5},"
                 "{%6, %7},"
                 "{%8, %9, %10, %11};\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 :  "r"(a0),  "r"(a1),
                    "r"(b0),  "r"(b1),
                    "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// SM70 MMA 884 F16F16F32
|
||||
//
|
||||
|
||||
// SM70 (Volta) m8n8k4 MMA, f16 inputs with f32 accumulate,
// A row-major x B col-major (TN). A/B registers carry packed f16 pairs;
// accumulator and result are eight f32 registers per thread.
struct SM70_8x8x4_F32F16F16F32_TN
{
  using DRegisters = float[8];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = float[8];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(float         & d0, float         & d1, float         & d2, float         & d3,
      float         & d4, float         & d5, float         & d6, float         & d7,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      float    const& c0, float    const& c1, float    const& c2, float    const& c3,
      float    const& c4, float    const& c5, float    const& c6, float    const& c7)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32"
                 "{%0, %1, %2, %3, %4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11},"
                 "{%12, %13, %14, %15, %16, %17, %18, %19};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
                   "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
                 :  "r"(a0),  "r"(a1),
                    "r"(b0),  "r"(b1),
                    "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3),
                    "f"(c4),  "f"(c5),  "f"(c6),  "f"(c7));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_NT
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_NN
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_TT
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
120
include/cute/arch/mma_sm75.hpp
Normal file
120
include/cute/arch/mma_sm75.hpp
Normal file
@ -0,0 +1,120 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
|
||||
# define CUTE_ARCH_MMA_SM75_SUPPORTED
|
||||
# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
|
||||
# define CUTE_ARCH_MMA_SM75_ENABLED
|
||||
# endif
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// SM75 MMA 1688 F16F16F32
|
||||
//
|
||||
|
||||
// SM75 (Turing) m16n8k8 MMA, f16 inputs with f32 accumulate,
// A row-major x B col-major (TN). A/B registers carry packed f16 pairs.
struct SM75_16x8x8_F32F16F16F32_TN
{
  using DRegisters = float[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[1];
  using CRegisters = float[4];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(float         & d0, float         & d1, float         & d2, float         & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0,
      float    const& c0, float    const& c1, float    const& c2, float    const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM75_ENABLED)
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
                 "{%0, %1, %2, %3},"
                 "{%4, %5},"
                 "{%6},"
                 "{%7, %8, %9, %10};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
                 :  "r"(a0),  "r"(a1),
                    "r"(b0),
                    "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// SM75 MMA 8816 S8S8S32
|
||||
//
|
||||
|
||||
// SM75 (Turing) m8n8k16 integer MMA, s8 inputs with s32 accumulate,
// A row-major x B col-major (TN). A/B registers carry four packed
// signed 8-bit lanes each.
struct SM75_8x8x16_S32S8S8S32_TN
{
  using DRegisters = uint32_t[2];
  using ARegisters = uint32_t[1];
  using BRegisters = uint32_t[1];
  using CRegisters = uint32_t[2];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1,
      uint32_t const& a0,
      uint32_t const& b0,
      uint32_t const& c0, uint32_t const& c1)
  {
#if defined(CUTE_ARCH_MMA_SM75_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32"
                 "{%0, %1},"
                 "{%2},"
                 "{%3},"
                 "{%4, %5};\n"
                 : "=r"(d0), "=r"(d1)
                 :  "r"(a0),
                    "r"(b0),
                    "r"(c0),  "r"(c1));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
2132
include/cute/arch/mma_sm80.hpp
Normal file
2132
include/cute/arch/mma_sm80.hpp
Normal file
File diff suppressed because it is too large
Load Diff
961
include/cute/arch/mma_sm90.hpp
Normal file
961
include/cute/arch/mma_sm90.hpp
Normal file
@ -0,0 +1,961 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||
# define CUTE_ARCH_MMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x4 TN
|
||||
// MMA 16x8x4 TN
// SM90 (Hopper) m16n8k4 double-precision MMA, A row-major x B col-major.
struct SM90_16x8x4_F64F64F64F64_TN
{
  using DRegisters = double[4];
  using ARegisters = double[2];
  using BRegisters = double[1];
  using CRegisters = double[4];

  CUTE_HOST_DEVICE static void
  fma(double      & d0, double      & d1, double      & d2, double      & d3,
      double const& a0, double const& a1,
      double const& b0,
      double const& c0, double const& c1, double const& c2, double const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
    asm volatile(
      "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64"
      "{%0, %1, %2, %3},"
      "{%4, %5},"
      "{%6},"
      "{%7, %8, %9, %10};\n"
      : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
      :  "d"(a0),  "d"(a1),
         "d"(b0),
         "d"(c0),  "d"(c1),  "d"(c2),  "d"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x8 TN
|
||||
// MMA 16x8x8 TN
// SM90 (Hopper) m16n8k8 double-precision MMA, A row-major x B col-major.
struct SM90_16x8x8_F64F64F64F64_TN
{
  using DRegisters = double[4];
  using ARegisters = double[4];
  using BRegisters = double[2];
  using CRegisters = double[4];

  CUTE_HOST_DEVICE static void
  fma(double      & d0, double      & d1, double      & d2, double      & d3,
      double const& a0, double const& a1, double const& a2, double const& a3,
      double const& b0, double const& b1,
      double const& c0, double const& c1, double const& c2, double const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
    asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64"
      "{%0, %1, %2, %3},"
      "{%4, %5, %6, %7},"
      "{%8, %9},"
      "{%10, %11, %12, %13};\n"
      : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
      :  "d"(a0),  "d"(a1),  "d"(a2),  "d"(a3),
         "d"(b0),  "d"(b1),
         "d"(c0),  "d"(c1),  "d"(c2),  "d"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x16 TN
|
||||
// MMA 16x8x16 TN
// SM90 (Hopper) m16n8k16 double-precision MMA, A row-major x B col-major.
struct SM90_16x8x16_F64F64F64F64_TN
{
  using DRegisters = double[4];
  using ARegisters = double[8];
  using BRegisters = double[4];
  using CRegisters = double[4];

  CUTE_HOST_DEVICE static void
  fma(double      & d0, double      & d1, double      & d2, double      & d3,
      double const& a0, double const& a1, double const& a2, double const& a3,
      double const& a4, double const& a5, double const& a6, double const& a7,
      double const& b0, double const& b1, double const& b2, double const& b3,
      double const& c0, double const& c1, double const& c2, double const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
    asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64"
      "{%0, %1, %2, %3},"
      "{%4, %5, %6, %7, %8, %9, %10, %11},"
      "{%12, %13, %14, %15},"
      "{%16, %17, %18, %19};\n"
      : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
      :  "d"(a0),  "d"(a1),  "d"(a2),  "d"(a3),
         "d"(a4),  "d"(a5),  "d"(a6),  "d"(a7),
         "d"(b0),  "d"(b1),  "d"(b2),  "d"(b3),
         "d"(c0),  "d"(c1),  "d"(c2),  "d"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x4 TN
|
||||
// MMA 16x8x4 TN
// SM90 complex<double> m16n8k4 MMA built from four real f64 MMAs.
// Computes d = a*b + c with complex multiplication expanded as:
//   re(d) = re(a)*re(b) - im(a)*im(b) + re(c)
//   im(d) = im(a)*re(b) + re(a)*im(b) + im(c)
// The four fma calls below must run in this order: the last two
// accumulate onto the partial results produced by the first two.
struct SM90_16x8x4_C64C64C64C64_TN
{
  using DRegisters = complex<double>[4];
  using ARegisters = complex<double>[2];
  using BRegisters = complex<double>[1];
  using CRegisters = complex<double>[4];

  CUTE_HOST_DEVICE static void
  fma(complex<double>      & d0, complex<double>      & d1,
      complex<double>      & d2, complex<double>      & d3,
      complex<double> const& a0, complex<double> const& a1,
      complex<double> const& b0,
      complex<double> const& c0, complex<double> const& c1,
      complex<double> const& c2, complex<double> const& c3)
  {
    // Because thrust::complex does not provide a mutable ref
    // to its parts, alias the real/imag halves through the standard
    // layout {double,double} of each complex value.
    double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
    double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
    double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
    double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
    double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
    double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
    double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
    double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];

    // d.real() = a.real() * b.real() + c.real();
    SM90_16x8x4_F64F64F64F64_TN::fma(
        rd0, rd1, rd2, rd3,
        a0.real(), a1.real(),
        b0.real(),
        c0.real(), c1.real(), c2.real(), c3.real());

    // d.imag() = a.imag() * b.real() + c.imag();
    SM90_16x8x4_F64F64F64F64_TN::fma(
        id0, id1, id2, id3,
        a0.imag(), a1.imag(),
        b0.real(),
        c0.imag(), c1.imag(), c2.imag(), c3.imag());

    // d.real() = -a.imag() * b.imag() + d.real();
    SM90_16x8x4_F64F64F64F64_TN::fma(
        rd0, rd1, rd2, rd3,
        -a0.imag(), -a1.imag(),
        b0.imag(),
        d0.real(), d1.real(), d2.real(), d3.real());

    // d.imag() = a.real() * b.imag() + d.imag();
    SM90_16x8x4_F64F64F64F64_TN::fma(
        id0, id1, id2, id3,
        a0.real(), a1.real(),
        b0.imag(),
        d0.imag(), d1.imag(), d2.imag(), d3.imag());
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x8 TN
// complex<double> MMA atom: D = A * B + C on an m16n8k8 tile, emulated with
// four real-valued SM90_16x8x8_F64F64F64F64_TN MMAs — one per term of the
// complex product (ar*br, ai*br, -ai*bi, ar*bi).
struct SM90_16x8x8_C64C64C64C64_TN
{
  using DRegisters = complex<double>[4];
  using ARegisters = complex<double>[4];
  using BRegisters = complex<double>[2];
  using CRegisters = complex<double>[4];

  CUTE_HOST_DEVICE static void
  fma(complex<double> & d0, complex<double> & d1,
      complex<double> & d2, complex<double> & d3,
      complex<double> const& a0, complex<double> const& a1,
      complex<double> const& a2, complex<double> const& a3,
      complex<double> const& b0, complex<double> const& b1,
      complex<double> const& c0, complex<double> const& c1,
      complex<double> const& c2, complex<double> const& c3)
  {
    // Because thrust::complex does not provide a mutable ref to its
    // real/imag parts, view each accumulator as double[2] ([0] = real,
    // [1] = imag) so the underlying real MMAs can write results in place.
    double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
    double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
    double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
    double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
    double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
    double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
    double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
    double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];

    // NOTE: the four calls below are order-dependent — the last two read
    // the d values produced by the first two as their accumulator inputs.

    // d.real() = a.real() * b.real() + c.real();
    SM90_16x8x8_F64F64F64F64_TN::fma(
        rd0, rd1, rd2, rd3,
        a0.real(), a1.real(), a2.real(), a3.real(),
        b0.real(), b1.real(),
        c0.real(), c1.real(), c2.real(), c3.real());

    // d.imag() = a.imag() * b.real() + c.imag();
    SM90_16x8x8_F64F64F64F64_TN::fma(
        id0, id1, id2, id3,
        a0.imag(), a1.imag(), a2.imag(), a3.imag(),
        b0.real(), b1.real(),
        c0.imag(), c1.imag(), c2.imag(), c3.imag());

    // d.real() = -a.imag() * b.imag() + d.real();
    SM90_16x8x8_F64F64F64F64_TN::fma(
        rd0, rd1, rd2, rd3,
        -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
        b0.imag(), b1.imag(),
        d0.real(), d1.real(), d2.real(), d3.real());

    // d.imag() = a.real() * b.imag() + d.imag();
    SM90_16x8x8_F64F64F64F64_TN::fma(
        id0, id1, id2, id3,
        a0.real(), a1.real(), a2.real(), a3.real(),
        b0.imag(), b1.imag(),
        d0.imag(), d1.imag(), d2.imag(), d3.imag());
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x16 TN
// complex<double> MMA atom: D = A * B + C on an m16n8k16 tile, emulated with
// four real-valued SM90_16x8x16_F64F64F64F64_TN MMAs — one per term of the
// complex product (ar*br, ai*br, -ai*bi, ar*bi).
struct SM90_16x8x16_C64C64C64C64_TN
{
  using DRegisters = complex<double>[4];
  using ARegisters = complex<double>[8];
  using BRegisters = complex<double>[4];
  using CRegisters = complex<double>[4];

  CUTE_HOST_DEVICE static void
  fma(complex<double> & d0, complex<double> & d1,
      complex<double> & d2, complex<double> & d3,
      complex<double> const& a0, complex<double> const& a1,
      complex<double> const& a2, complex<double> const& a3,
      complex<double> const& a4, complex<double> const& a5,
      complex<double> const& a6, complex<double> const& a7,
      complex<double> const& b0, complex<double> const& b1,
      complex<double> const& b2, complex<double> const& b3,
      complex<double> const& c0, complex<double> const& c1,
      complex<double> const& c2, complex<double> const& c3)
  {
    // Because thrust::complex does not provide a mutable ref to its
    // real/imag parts, view each accumulator as double[2] ([0] = real,
    // [1] = imag) so the underlying real MMAs can write results in place.
    double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
    double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
    double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
    double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
    double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
    double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
    double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
    double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];

    // NOTE: the four calls below are order-dependent — the last two read
    // the d values produced by the first two as their accumulator inputs.

    // d.real() = a.real() * b.real() + c.real();
    SM90_16x8x16_F64F64F64F64_TN::fma(
        rd0, rd1, rd2, rd3,
        a0.real(), a1.real(), a2.real(), a3.real(),
        a4.real(), a5.real(), a6.real(), a7.real(),
        b0.real(), b1.real(), b2.real(), b3.real(),
        c0.real(), c1.real(), c2.real(), c3.real());

    // d.imag() = a.imag() * b.real() + c.imag();
    SM90_16x8x16_F64F64F64F64_TN::fma(
        id0, id1, id2, id3,
        a0.imag(), a1.imag(), a2.imag(), a3.imag(),
        a4.imag(), a5.imag(), a6.imag(), a7.imag(),
        b0.real(), b1.real(), b2.real(), b3.real(),
        c0.imag(), c1.imag(), c2.imag(), c3.imag());

    // d.real() = -a.imag() * b.imag() + d.real();
    SM90_16x8x16_F64F64F64F64_TN::fma(
        rd0, rd1, rd2, rd3,
        -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
        -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(),
        b0.imag(), b1.imag(), b2.imag(), b3.imag(),
        d0.real(), d1.real(), d2.real(), d3.real());

    // d.imag() = a.real() * b.imag() + d.imag();
    SM90_16x8x16_F64F64F64F64_TN::fma(
        id0, id1, id2, id3,
        a0.real(), a1.real(), a2.real(), a3.real(),
        a4.real(), a5.real(), a6.real(), a7.real(),
        b0.imag(), b1.imag(), b2.imag(), b3.imag(),
        d0.imag(), d1.imag(), d2.imag(), d3.imag());
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <cute/arch/mma_sm90_desc.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
namespace GMMA {
|
||||
|
||||
// Selects, at compile time, an SM90 GMMA operator with BOTH operands sourced
// from shared memory ("SS"). Dispatch is on (ElementA, ElementB, ElementC)
// and the N extent of TileShape_MNK: within each type configuration the
// ladder picks the largest GMMA N (256 down to 8) that divides Tile_N.
// All Tile_M/Tile_K divisibility requirements are enforced via static_assert,
// so an unsupported configuration is a compile error rather than a fallback.
// NOTE(review): the SM90_64x*x*_* operator types are declared in
// cute/arch/mma_sm90_gmma.hpp (included above) — presumed, not visible here.
template<
  class ElementA,
  class ElementB,
  class ElementC,
  class TileShape_MNK,
  GMMA::Major MajorA = GMMA::Major::K,
  GMMA::Major MajorB = GMMA::Major::K,
  auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
               // But most commonly leave empty for defaults
>
CUTE_HOST_DEVICE constexpr
auto
ss_op_selector()
{
  static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
  static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
  static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
  auto Tile_N = size<1>(TileShape_MNK{});

  // FP16 accumulator
  if constexpr (std::is_same_v<ElementC, half_t>) {
    static_assert(std::is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
    static_assert(std::is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
    static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");

    // Dispatch against the Tile N mode size
    if constexpr (Tile_N % 256 == 0) {
      return SM90_64x256x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 192 == 0) {
      return SM90_64x192x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 128 == 0) {
      return SM90_64x128x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 96 == 0) {
      return SM90_64x96x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 64 == 0) {
      return SM90_64x64x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 32 == 0) {
      return SM90_64x32x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 16 == 0) {
      return SM90_64x16x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 8 == 0) {
      return SM90_64x8x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
    }
    else {
      static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
    }
  }

  // FP32 accumulator
  else if constexpr (std::is_same_v<ElementC, float>) {

    // FP16 inputs
    if constexpr (std::is_same_v<ElementA, half_t>) {
      static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
      static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // BF16 inputs
    else if constexpr (std::is_same_v<ElementA, bfloat16_t>) {
      static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
      static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // TF32 inputs
    // TF32 GMMAs exist only in TN form, hence the K-major requirements and
    // the _TN operator names (which take no Major template parameters).
    else if constexpr (std::is_same_v<ElementA, tfloat32_t>) {
      static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
      static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
      static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
      static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x8_F32TF32TF32_SS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }
    else {
      // sizeof(ElementA) == 0 is always false; dependent form defers the
      // assert until this branch is actually instantiated.
      static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
    }
  }

  // S32 accumulator
  // Integer GMMAs exist only in TN form with K = 32.
  else if constexpr (std::is_same_v<ElementC, int32_t>) {
    static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
    static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
    static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

    // ElementA == int8_t && ElementB == int8_t
    if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, int8_t>) {
      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32S8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32S8S8_SS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // ElementA == int8_t && ElementB == uint8_t
    else if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, uint8_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32S8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32S8U8_SS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // ElementA == uint8_t && ElementB == int8_t
    else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, int8_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32U8S8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32U8S8_SS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // ElementA == uint8_t && ElementB == uint8_t
    else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, uint8_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32U8U8_SS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32U8U8_SS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }
  }

  // Unknown accumulator type
  else {
    static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
  }
}
|
||||
|
||||
// Selects, at compile time, an SM90 GMMA operator with the A operand sourced
// from Registers and B from Shared memory ("RS"). Same dispatch scheme as
// ss_op_selector: branch on (ElementA, ElementB, ElementC), then pick the
// largest GMMA N (256 down to 8) dividing the N extent of TileShape_MNK.
// Register-sourced A additionally requires a K-major A layout (asserted below).
// NOTE(review): the SM90_64x*x*_*_RS* operator types are declared in
// cute/arch/mma_sm90_gmma.hpp (included above) — presumed, not visible here.
template<
  class ElementA,
  class ElementB,
  class ElementC,
  class TileShape_MNK,
  GMMA::Major MajorA = GMMA::Major::K,
  GMMA::Major MajorB = GMMA::Major::K,
  auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
               // But most commonly leave empty for defaults
>
CUTE_HOST_DEVICE constexpr
auto
rs_op_selector()
{
  static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
  static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
  static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
  static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout.");
  auto Tile_N = size<1>(TileShape_MNK{});

  // FP16 accumulator
  if constexpr (std::is_same_v<ElementC, half_t>) {
    static_assert(std::is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
    static_assert(std::is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
    static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");

    // Dispatch against the Tile N mode size
    if constexpr (Tile_N % 256 == 0) {
      return SM90_64x256x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 192 == 0) {
      return SM90_64x192x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 128 == 0) {
      return SM90_64x128x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 96 == 0) {
      return SM90_64x96x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 64 == 0) {
      return SM90_64x64x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 32 == 0) {
      return SM90_64x32x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 16 == 0) {
      return SM90_64x16x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else if constexpr (Tile_N % 8 == 0) {
      return SM90_64x8x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
    }
    else {
      static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
    }
  }

  // FP32 accumulator
  else if constexpr (std::is_same_v<ElementC, float>) {
    static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
    static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");

    // FP16 inputs
    if constexpr (std::is_same_v<ElementA, half_t>) {
      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // BF16 inputs
    else if constexpr (std::is_same_v<ElementA, bfloat16_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // TF32 inputs
    // TF32 GMMAs exist only in TN form, hence the _TN operator names
    // (which take no Major template parameters).
    else if constexpr (std::is_same_v<ElementA, tfloat32_t>) {
      static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
      static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x8_F32TF32TF32_RS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    else {
      // sizeof(ElementA) == 0 is always false; dependent form defers the
      // assert until this branch is actually instantiated.
      static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
    }
  }

  // S32 accumulator
  // Integer GMMAs exist only in TN form with K = 32.
  else if constexpr (std::is_same_v<ElementC, int32_t>) {
    static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
    static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

    // ElementA == int8_t && ElementB == int8_t
    if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, int8_t>) {
      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32S8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32S8S8_RS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // ElementA == int8_t && ElementB == uint8_t
    else if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, uint8_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32S8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32S8U8_RS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // ElementA == uint8_t && ElementB == int8_t
    else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, int8_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32U8S8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32U8S8_RS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }

    // ElementA == uint8_t && ElementB == uint8_t
    else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, uint8_t>) {
      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");

      if constexpr (Tile_N % 256 == 0) {
        return SM90_64x256x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 192 == 0) {
        return SM90_64x192x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 128 == 0) {
        return SM90_64x128x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 96 == 0) {
        return SM90_64x96x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 64 == 0) {
        return SM90_64x64x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 32 == 0) {
        return SM90_64x32x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 16 == 0) {
        return SM90_64x16x32_S32U8U8_RS_TN<Args...>{};
      }
      else if constexpr (Tile_N % 8 == 0) {
        return SM90_64x8x32_S32U8U8_RS_TN<Args...>{};
      }
      else {
        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
      }
    }
  }

  // Unknown accumulator type
  else {
    static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
  }
}
|
||||
} // end namespace GMMA
|
||||
} // end namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// (diff-viewer artifact preserved from extraction: beginning of new file
//  include/cute/arch/mma_sm90_desc.hpp, 131 lines, hunk header "@@ -0,0 +1,131 @@")
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||
# define CUTE_ARCH_MMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// GMMA Descriptor and utilities
|
||||
|
||||
// GMMA enums and utilities
|
||||
namespace GMMA
|
||||
{
|
||||
|
||||
// Layout-type field of a GMMA shared-memory descriptor (bits [62,64) of
// GmmaDescriptor below).  Encodes the shared-memory swizzling mode:
// SWIZZLE_NONE = 0, SWIZZLE_128B = 1, SWIZZLE_64B = 2, SWIZZLE_32B = 3.
enum class LayoutType : uint8_t {
  INTERLEAVE = 0,  // no swizzle (SWIZZLE_NONE)
  B128 = 1,        // 128-byte swizzle
  B64 = 2,         // 64-byte swizzle
  B32 = 3,         // 32-byte swizzle
};
|
||||
|
||||
// Human-readable name of a GMMA::LayoutType enumerator, or nullptr if the
// value does not match any enumerator.
CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) {
  if (t == LayoutType::INTERLEAVE) { return "INTERLEAVE"; }
  if (t == LayoutType::B128)       { return "B128"; }
  if (t == LayoutType::B64)        { return "B64"; }
  if (t == LayoutType::B32)        { return "B32"; }
  return nullptr;
}
|
||||
|
||||
// Output operator for all enums in this namespace
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) {
|
||||
char const* s = to_string(t);
|
||||
if (s) {
|
||||
std::operator<<(os, s); // Explicit call to avoid ambiguity
|
||||
} else {
|
||||
os.setstate(std::ios_base::failbit);
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
} // end namespace GMMA
|
||||
|
||||
// 64-bit GMMA shared-memory matrix descriptor.  Describes the location and
// layout of an operand tile in shared memory for warpgroup MMA.  Accessible
// as a single u64, two u32 halves, four u16 quarters, or through the named
// bitfields below.
union GmmaDescriptor
{
  uint64_t desc_;      // whole descriptor
  uint32_t reg32_[2];  // low/high 32-bit words
  uint16_t reg16_[4];  // four 16-bit words

  // Bitfield implementation avoids the need for shifts in assignment
  struct {
    // start_address, bit [0,14), 4LSB not included
    uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
    // leading dimension byte offset, bit [16,30), 4LSB not included
    // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED
    // Unused for all SWIZZLE_* layouts (and assumed to be 1)
    // For T: This is the stride from the first 8 rows to the next 8 rows.
    uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
    // stride dimension byte offset, bit [32,46), 4LSB not included
    // For N: This is the stride from the first 8 rows to the next 8 rows.
    // For T: This is the stride from the first 8 cols to the next 8 cols.
    uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
    // base_offset, bit [49,52)
    // Valid only for SWIZZLE_128B and SWIZZLE_64B
    uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused
    // layout type, bit [62,64)
    // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
    uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8)
  };

  // Decay to a uint64_t (the raw descriptor value passed to the MMA ops).
  CUTE_HOST_DEVICE constexpr
  operator uint64_t() const noexcept { return desc_; }

  // Printer: dumps the raw descriptor followed by each decoded field.
  // NOTE(review): "%016lx" assumes unsigned long is 64 bits (LP64 targets);
  // confirm for platforms where long is 32-bit.
  CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t)
  {
    printf("GmmaDescriptor: 0x%016lx\n", t.desc_);
    printf(" start_addr : 0x%04x\n", t.start_address_);
    printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_);
    printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_);
    printf(" base_offset: 0x%01x\n", t.base_offset_);
    printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast<GMMA::LayoutType>(t.layout_type_)));
  }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
12265
include/cute/arch/mma_sm90_gmma.hpp
Normal file
12265
include/cute/arch/mma_sm90_gmma.hpp
Normal file
File diff suppressed because it is too large
Load Diff
178
include/cute/arch/util.hpp
Normal file
178
include/cute/arch/util.hpp
Normal file
@ -0,0 +1,178 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
|
||||
#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
||||
extern "C" {
|
||||
// This NVVM intrinsic is subject to change in future versions of CUDA.
|
||||
// Clients should not call it directly.
|
||||
CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/// CUTE helper to cast SMEM pointer to unsigned
///
/// Converts a generic-address-space pointer into the 32-bit shared-memory
/// address form required by inline-PTX shared-memory instructions.  The
/// implementation is chosen at preprocessing time based on compiler and
/// CUDA toolkit version; on non-device compilations the fallback branch
/// prints an error and returns 0.
CUTE_HOST_DEVICE
uint32_t
cast_smem_ptr_to_uint(void const* const ptr)
{
  // We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
  // the previous internal intrinsics if they are available.
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
  //
  // This NVVM intrinsic converts an address in shared memory to a plain
  // unsigned integer. This is necessary to pass to shared memory instructions
  // in inline PTX.
  //
  // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
  //
  //__device__ size_t __cvta_generic_to_shared(void* ptr);

  // CUDA 11+: use the public CVTA intrinsic (truncating the 64-bit result).
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)

  // CUDA 10.2: use the internal NVVM intrinsic declared in the extern "C"
  // block at the top of this header.
  return __nvvm_get_smem_pointer(ptr);

#elif defined(__CUDA_ARCH__)

  // Older toolkits / clang device compile: perform the conversion directly
  // in inline PTX (cvta.to.shared followed by a 64->32 bit narrowing).
  uint32_t smem_ptr;

  asm(
  "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
    : "=r"(smem_ptr) : "l"(ptr));

  return smem_ptr;

#else

  // Host (or otherwise unsupported) compilation: no shared-memory address
  // space exists; report misuse and return a dummy value.
  (void) ptr;
  printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n");
  return 0;

#endif
}
|
||||
|
||||
//
|
||||
// Utility for pointer interfaces
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Fn,
|
||||
class PtrS, int... Is,
|
||||
class PtrD, int... Id>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn,
|
||||
PtrS&& s, int_sequence<Is...>,
|
||||
PtrD&& d, int_sequence<Id...>)
|
||||
{
|
||||
return fn(s[Is]..., d[Id]...);
|
||||
}
|
||||
|
||||
template <class Fn,
|
||||
class PtrA, int... Ia,
|
||||
class PtrB, int... Ib,
|
||||
class PtrC, int... Ic>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn,
|
||||
PtrA&& a, int_sequence<Ia...>,
|
||||
PtrB&& b, int_sequence<Ib...>,
|
||||
PtrC&& c, int_sequence<Ic...>)
|
||||
{
|
||||
return fn(a[Ia]..., b[Ib]..., c[Ic]...);
|
||||
}
|
||||
|
||||
template <class Fn,
|
||||
class PtrD, int... Id,
|
||||
class PtrA, int... Ia,
|
||||
class PtrB, int... Ib,
|
||||
class PtrC, int... Ic>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn,
|
||||
PtrD&& d, int_sequence<Id...>,
|
||||
PtrA&& a, int_sequence<Ia...>,
|
||||
PtrB&& b, int_sequence<Ib...>,
|
||||
PtrC&& c, int_sequence<Ic...>)
|
||||
{
|
||||
return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <int SRegCount, int DRegCount,
|
||||
class Fn, class PtrS, class PtrD>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn, PtrS&& s, PtrD&& d)
|
||||
{
|
||||
return detail::explode(fn,
|
||||
s, make_int_sequence<SRegCount>{},
|
||||
d, make_int_sequence<DRegCount>{});
|
||||
}
|
||||
|
||||
template <int ARegCount, int BRegCount, int CRegCount,
|
||||
class Fn, class PtrA, class PtrB, class PtrC>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c)
|
||||
{
|
||||
return detail::explode(fn,
|
||||
a, make_int_sequence<ARegCount>{},
|
||||
b, make_int_sequence<BRegCount>{},
|
||||
c, make_int_sequence<CRegCount>{});
|
||||
}
|
||||
|
||||
template <int DRegCount, int ARegCount, int BRegCount, int CRegCount,
|
||||
class Fn, class PtrD, class PtrA, class PtrB, class PtrC>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c)
|
||||
{
|
||||
return detail::explode(fn,
|
||||
d, make_int_sequence<DRegCount>{},
|
||||
a, make_int_sequence<ARegCount>{},
|
||||
b, make_int_sequence<BRegCount>{},
|
||||
c, make_int_sequence<CRegCount>{});
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
Reference in New Issue
Block a user