cutlass 3.9 update (#2255)
* cutlass 3.9 update * rebase * fixes out of shared memory for blockwise Blackwell * doc format * fix issue 2253 * disable host ref by default * fix sm120 smem capacity --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -33,6 +33,7 @@
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/container/type_list.hpp>
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
@ -277,34 +278,13 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f)
|
||||
// find and find_if
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f, seq<I,Is...>)
|
||||
{
|
||||
if constexpr (decltype(f(get<I>(t)))::value) {
|
||||
return cute::C<I>{};
|
||||
} else
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return cute::C<I+1>{};
|
||||
} else {
|
||||
return find_if(t, f, seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::find_if(t, f, tuple_seq<T>{});
|
||||
return detail::tapply(t, f, [] (auto... a) { return cute::C<find_true_v<decltype(a)::value...>>{}; }, tuple_seq<T>{});
|
||||
} else {
|
||||
return cute::C<decltype(f(t))::value ? 0 : 1>{};
|
||||
}
|
||||
@ -326,7 +306,7 @@ auto
|
||||
any_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
|
||||
return detail::tapply(t, f, [] (auto... a) { return (false_type{} || ... || a); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
@ -340,7 +320,7 @@ auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
|
||||
return detail::tapply(t, f, [] (auto... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
|
||||
|
||||
@ -72,6 +72,27 @@
|
||||
# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED))
|
||||
# define CUTE_ARCH_TMA_SM90_ENABLED
|
||||
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM90_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
|
||||
# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED))
|
||||
# define CUTE_ARCH_TMA_SM90_ENABLED
|
||||
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED))
|
||||
# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED
|
||||
#endif
|
||||
@ -91,8 +112,11 @@
|
||||
#endif
|
||||
|
||||
// {add, mul, fma}.f32x2 PTX
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED))
|
||||
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
// Enable CuTe MMA Atoms
|
||||
# define CUTE_ARCH_FFMA2_SM100_ENABLED
|
||||
// Enable f32x2 PTX generation
|
||||
# define CUTE_ARCH_FLOAT2_MATH_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)
|
||||
@ -109,3 +133,37 @@
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
|
||||
# define CUTE_ARCH_LDSM_SM100A_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM100A_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_TMEM_ENABLED
|
||||
# define CUTE_ARCH_TMA_SM100_ENABLED
|
||||
# define CUTE_ARCH_FLOAT2_MATH_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)
|
||||
# define CUTE_ARCH_LDSM_SM100A_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM100A_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_TMEM_ENABLED
|
||||
# define CUTE_ARCH_TMA_SM100_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)
|
||||
# define CUTE_ARCH_LDSM_SM100A_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM100A_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\
|
||||
defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\
|
||||
defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED))
|
||||
# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9))
|
||||
# define CUTE_ARCH_LOAD256_SM100A_ENABLED
|
||||
# define CUTE_ARCH_STORE256_SM100A_ENABLED
|
||||
# endif
|
||||
#endif
|
||||
|
||||
// {add, mul, fma}.f32x2 PTX
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
|
||||
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
@ -28,10 +28,6 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/config.hpp>
|
||||
@ -316,17 +312,14 @@ struct SM100_U8x16_STSM_T
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// UTCCP PTX definitions
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
namespace SM100::TMEM::UTCCP {
|
||||
|
||||
// 128 data path lanes, 256-bit pattern, 1cta mode
|
||||
struct SM100_UTCCP_128dp256bit_1cta
|
||||
{
|
||||
@ -558,21 +551,19 @@ struct SM100_UTCCP_2x64dp128bitlw0123_2cta
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
} // end namespace SM100::TMEM::UTCCP
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
namespace SM100::TMEM::LOAD {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// TMEM_LOAD PTX definitions
|
||||
// TMEM LOAD PTX definitions
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -3945,7 +3936,6 @@ struct SM100_TMEM_LOAD_32dp32b128x
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// 32 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read
|
||||
@ -4065,9 +4055,21 @@ struct SM100_TMEM_LOAD_32dp32b128x_16b
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace SM100::TMEM::LOAD
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace SM100::TMEM::STORE {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// TMEM_STORE PTX definitions
|
||||
// TMEM STORE PTX definitions
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -4086,8 +4088,8 @@ struct SM100_TMEM_STORE_16dp256b1x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -4110,8 +4112,8 @@ struct SM100_TMEM_STORE_16dp256b1x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -4136,8 +4138,8 @@ struct SM100_TMEM_STORE_16dp256b2x
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -4163,8 +4165,8 @@ struct SM100_TMEM_STORE_16dp256b2x_16b
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -4194,8 +4196,8 @@ struct SM100_TMEM_STORE_16dp256b4x
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4227,8 +4229,8 @@ struct SM100_TMEM_STORE_16dp256b4x_16b
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4268,8 +4270,8 @@ struct SM100_TMEM_STORE_16dp256b8x
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4313,8 +4315,8 @@ struct SM100_TMEM_STORE_16dp256b8x_16b
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4374,8 +4376,8 @@ struct SM100_TMEM_STORE_16dp256b16x
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4443,8 +4445,8 @@ struct SM100_TMEM_STORE_16dp256b16x_16b
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4544,8 +4546,8 @@ struct SM100_TMEM_STORE_16dp256b32x
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -4661,8 +4663,8 @@ struct SM100_TMEM_STORE_16dp256b32x_16b
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -4716,8 +4718,8 @@ struct SM100_TMEM_STORE_16dp128b1x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.b32"
|
||||
"[%0],"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -4740,8 +4742,8 @@ struct SM100_TMEM_STORE_16dp128b1x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -4764,8 +4766,8 @@ struct SM100_TMEM_STORE_16dp128b2x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -4788,8 +4790,8 @@ struct SM100_TMEM_STORE_16dp128b2x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -4814,8 +4816,8 @@ struct SM100_TMEM_STORE_16dp128b4x
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -4841,8 +4843,8 @@ struct SM100_TMEM_STORE_16dp128b4x_16b
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -4872,8 +4874,8 @@ struct SM100_TMEM_STORE_16dp128b8x
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4905,8 +4907,8 @@ struct SM100_TMEM_STORE_16dp128b8x_16b
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4946,8 +4948,8 @@ struct SM100_TMEM_STORE_16dp128b16x
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -4991,8 +4993,8 @@ struct SM100_TMEM_STORE_16dp128b16x_16b
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5052,8 +5054,8 @@ struct SM100_TMEM_STORE_16dp128b32x
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5121,8 +5123,8 @@ struct SM100_TMEM_STORE_16dp128b32x_16b
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5222,8 +5224,8 @@ struct SM100_TMEM_STORE_16dp128b64x
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -5339,8 +5341,8 @@ struct SM100_TMEM_STORE_16dp128b64x_16b
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -5394,8 +5396,8 @@ struct SM100_TMEM_STORE_16dp64b1x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.b32"
|
||||
"[%0],"
|
||||
"{%1};\n"
|
||||
:
|
||||
"{%1};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -5418,8 +5420,8 @@ struct SM100_TMEM_STORE_16dp64b1x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1};\n"
|
||||
:
|
||||
"{%1};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -5442,8 +5444,8 @@ struct SM100_TMEM_STORE_16dp64b2x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.b32"
|
||||
"[%0],"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -5466,8 +5468,8 @@ struct SM100_TMEM_STORE_16dp64b2x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -5490,8 +5492,8 @@ struct SM100_TMEM_STORE_16dp64b4x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -5514,8 +5516,8 @@ struct SM100_TMEM_STORE_16dp64b4x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -5540,8 +5542,8 @@ struct SM100_TMEM_STORE_16dp64b8x
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -5567,8 +5569,8 @@ struct SM100_TMEM_STORE_16dp64b8x_16b
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -5598,8 +5600,8 @@ struct SM100_TMEM_STORE_16dp64b16x
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5631,8 +5633,8 @@ struct SM100_TMEM_STORE_16dp64b16x_16b
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5672,8 +5674,8 @@ struct SM100_TMEM_STORE_16dp64b32x
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5717,8 +5719,8 @@ struct SM100_TMEM_STORE_16dp64b32x_16b
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5778,8 +5780,8 @@ struct SM100_TMEM_STORE_16dp64b64x
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5847,8 +5849,8 @@ struct SM100_TMEM_STORE_16dp64b64x_16b
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -5948,8 +5950,8 @@ struct SM100_TMEM_STORE_16dp64b128x
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -6065,8 +6067,8 @@ struct SM100_TMEM_STORE_16dp64b128x_16b
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -6120,8 +6122,8 @@ struct SM100_TMEM_STORE_16dp32b1x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.b32"
|
||||
"[%0] , 1,"
|
||||
"{%1};\n"
|
||||
:
|
||||
"{%1};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6144,8 +6146,8 @@ struct SM100_TMEM_STORE_16dp32b1x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.unpack::16b.b32"
|
||||
"[%0] , 2,"
|
||||
"{%1};\n"
|
||||
:
|
||||
"{%1};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6168,8 +6170,8 @@ struct SM100_TMEM_STORE_16dp32b2x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.b32"
|
||||
"[%0] , 2,"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6192,8 +6194,8 @@ struct SM100_TMEM_STORE_16dp32b2x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.unpack::16b.b32"
|
||||
"[%0] , 4,"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6216,8 +6218,8 @@ struct SM100_TMEM_STORE_16dp32b4x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.b32"
|
||||
"[%0] , 4,"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6240,8 +6242,8 @@ struct SM100_TMEM_STORE_16dp32b4x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.unpack::16b.b32"
|
||||
"[%0] , 8,"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6266,8 +6268,8 @@ struct SM100_TMEM_STORE_16dp32b8x
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.b32"
|
||||
"[%0] , 8,"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -6293,8 +6295,8 @@ struct SM100_TMEM_STORE_16dp32b8x_16b
|
||||
asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.unpack::16b.b32"
|
||||
"[%0] , 16,"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -6324,8 +6326,8 @@ struct SM100_TMEM_STORE_16dp32b16x
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -6357,8 +6359,8 @@ struct SM100_TMEM_STORE_16dp32b16x_16b
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -6398,8 +6400,8 @@ struct SM100_TMEM_STORE_16dp32b32x
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -6443,8 +6445,8 @@ struct SM100_TMEM_STORE_16dp32b32x_16b
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -6504,8 +6506,8 @@ struct SM100_TMEM_STORE_16dp32b64x
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -6573,8 +6575,8 @@ struct SM100_TMEM_STORE_16dp32b64x_16b
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -6674,8 +6676,8 @@ struct SM100_TMEM_STORE_16dp32b128x
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -6791,8 +6793,8 @@ struct SM100_TMEM_STORE_16dp32b128x_16b
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -6846,8 +6848,8 @@ struct SM100_TMEM_STORE_32dp32b1x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.b32"
|
||||
"[%0],"
|
||||
"{%1};\n"
|
||||
:
|
||||
"{%1};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6870,8 +6872,8 @@ struct SM100_TMEM_STORE_32dp32b1x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1};\n"
|
||||
:
|
||||
"{%1};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6894,8 +6896,8 @@ struct SM100_TMEM_STORE_32dp32b2x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.b32"
|
||||
"[%0],"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6918,8 +6920,8 @@ struct SM100_TMEM_STORE_32dp32b2x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
"{%1, %2};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6942,8 +6944,8 @@ struct SM100_TMEM_STORE_32dp32b4x
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6966,8 +6968,8 @@ struct SM100_TMEM_STORE_32dp32b4x_16b
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
"{%1, %2, %3, %4};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) );
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
|
||||
@ -6992,8 +6994,8 @@ struct SM100_TMEM_STORE_32dp32b8x
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -7019,8 +7021,8 @@ struct SM100_TMEM_STORE_32dp32b8x_16b
|
||||
asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.unpack::16b.b32"
|
||||
"[%0],"
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
"%5, %6, %7, %8};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3),
|
||||
"r"(src4), "r"(src5), "r"(src6), "r"(src7) );
|
||||
#else
|
||||
@ -7050,8 +7052,8 @@ struct SM100_TMEM_STORE_32dp32b16x
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -7083,8 +7085,8 @@ struct SM100_TMEM_STORE_32dp32b16x_16b
|
||||
"{%1, %2, %3, %4,"
|
||||
"%5, %6, %7, %8,"
|
||||
"%9, %10, %11, %12,"
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
"%13, %14, %15, %16};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -7124,8 +7126,8 @@ struct SM100_TMEM_STORE_32dp32b32x
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -7169,8 +7171,8 @@ struct SM100_TMEM_STORE_32dp32b32x_16b
|
||||
"%17, %18, %19, %20,"
|
||||
"%21, %22, %23, %24,"
|
||||
"%25, %26, %27, %28,"
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
"%29, %30, %31, %32};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -7230,8 +7232,8 @@ struct SM100_TMEM_STORE_32dp32b64x
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -7299,8 +7301,8 @@ struct SM100_TMEM_STORE_32dp32b64x_16b
|
||||
"%49, %50, %51, %52,"
|
||||
"%53, %54, %55, %56,"
|
||||
"%57, %58, %59, %60,"
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
"%61, %62, %63, %64};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03),
|
||||
"r"(src04), "r"(src05), "r"(src06), "r"(src07),
|
||||
"r"(src08), "r"(src09), "r"(src10), "r"(src11),
|
||||
@ -7400,8 +7402,8 @@ struct SM100_TMEM_STORE_32dp32b128x
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -7517,8 +7519,8 @@ struct SM100_TMEM_STORE_32dp32b128x_16b
|
||||
"%113, %114, %115, %116,"
|
||||
"%117, %118, %119, %120,"
|
||||
"%121, %122, %123, %124,"
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
"%125, %126, %127, %128};\n"
|
||||
:
|
||||
: "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003),
|
||||
"r"(src004), "r"(src005), "r"(src006), "r"(src007),
|
||||
"r"(src008), "r"(src009), "r"(src010), "r"(src011),
|
||||
@ -7561,7 +7563,8 @@ struct SM100_TMEM_STORE_32dp32b128x_16b
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
} // namespace SM100::TMEM::STORE
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -29,7 +29,6 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
|
||||
#pragma once
|
||||
@ -37,6 +36,48 @@
|
||||
#include <cute/arch/config.hpp>
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
#include <cute/arch/simd_sm100.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
struct SM100_2x1x1_F32F32F32F32 {
|
||||
using DRegisters = float2[1];
|
||||
using ARegisters = float2[1];
|
||||
using BRegisters = float[1];
|
||||
using CRegisters = float2[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float2 & d01,
|
||||
float2 const& a01,
|
||||
float const& b0,
|
||||
float2 const& c01)
|
||||
{
|
||||
#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED)
|
||||
cute::fma(d01, a01, make_float2(b0, b0), c01);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_2x1x1_F32F32F32F32 without CUTE_ARCH_FLOAT2_MATH_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_1x2x1_F32F32F32F32 {
|
||||
using DRegisters = float2[1];
|
||||
using ARegisters = float[1];
|
||||
using BRegisters = float2[1];
|
||||
using CRegisters = float2[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float2 & d01,
|
||||
float const& a0,
|
||||
float2 const& b01,
|
||||
float2 const& c01)
|
||||
{
|
||||
#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED)
|
||||
cute::fma(d01, make_float2(a0, a0), b01, c01);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_1x2x1_F32F32F32F32 without CUTE_ARCH_FFMA2_SM100_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cute
|
||||
|
||||
@ -28,19 +28,34 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/config.hpp>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/atom/copy_traits_sm100.hpp>
|
||||
|
||||
#include <cutlass/pipeline/sm90_pipeline.hpp>
|
||||
#include <cute/arch/util.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/pointer.hpp>
|
||||
|
||||
namespace cute::TMEM {
|
||||
|
||||
//
|
||||
// TMEM Addressing Constants
|
||||
//
|
||||
|
||||
// 128 DP x 512 COL x uint32_t-addressing
|
||||
using MAX_CAPACITY_BITS = Int<128*512*32>;
|
||||
|
||||
// TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t)
|
||||
using DP_b = cute::constant<int32_t, (1 << 21)>;
|
||||
|
||||
// TMEM DP stride in type-T addressing
|
||||
template <class T = uint32_t>
|
||||
using DP = cute::constant<int32_t, shiftl((1 << 16), tmem_ptr<T>::OffsetShift)>;
|
||||
|
||||
//
|
||||
// TMEM Allocators
|
||||
//
|
||||
|
||||
// All operations of this class require that only a single warp uniformly participates
|
||||
class Allocator1Sm {
|
||||
public:
|
||||
@ -77,8 +92,8 @@ public:
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t"
|
||||
"}"
|
||||
:
|
||||
"}"
|
||||
:
|
||||
: "r"(tmem_ptr), "r"(num_columns));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
@ -130,7 +145,7 @@ public:
|
||||
}
|
||||
|
||||
/**
|
||||
* Frees the TMEM corresponding to the pointer and slice count provided.
|
||||
* Frees the TMEM corresponding to the pointer and slice count provided.
|
||||
* Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices.
|
||||
* @param tmem_ptr Base address of the TMEM address space being freed.
|
||||
* @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2.
|
||||
@ -146,8 +161,8 @@ public:
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t"
|
||||
"}"
|
||||
:
|
||||
"}"
|
||||
:
|
||||
: "r"(tmem_ptr), "r"(num_columns));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -647,7 +647,7 @@ make_tma_atom_im2col(CopyOp,
|
||||
gtensor_cwhdn,
|
||||
range_c,
|
||||
range_whdn,
|
||||
detail::get_swizzle_portion(slayout),
|
||||
get_swizzle_portion(slayout),
|
||||
tma_layout_vt,
|
||||
lower_corner_whd,
|
||||
upper_corner_whd,
|
||||
|
||||
@ -37,10 +37,13 @@
|
||||
#include <cute/arch/mma_sm100.hpp>
|
||||
#include <cute/arch/mma_sm100_desc.hpp>
|
||||
#include <cute/arch/mma_sm100_umma.hpp>
|
||||
#include <cute/atom/copy_traits_sm100.hpp> // cute::TMEM::
|
||||
#include <cute/arch/tmem_allocator_sm100.hpp> // cute::TMEM::
|
||||
|
||||
#include <cute/atom/mma_traits.hpp>
|
||||
#include <cute/atom/mma_traits_sm90_gmma.hpp> // cute::GMMA::
|
||||
#include <cute/atom/mma_traits_sm90_gmma_sparse.hpp> // cute::GMMA::
|
||||
#include <cute/atom/copy_traits_sm100.hpp> // UTCCP smem desc
|
||||
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
|
||||
// Check that aggregate initialization in .with() initializes all fields
|
||||
@ -417,6 +420,9 @@ constexpr auto get_utccp_smem_desc_tensor(Tensor<TEngine, TLayout> const& smem_u
|
||||
|
||||
namespace UMMA {
|
||||
|
||||
// Import TMEM constants
|
||||
namespace TMEM = cute::TMEM;
|
||||
|
||||
enum class TmemAllocMode {
|
||||
// Default allocation mode.
|
||||
// If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be
|
||||
@ -3053,7 +3059,7 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
|
||||
static_assert(cute::sizeof_bits_v<a_type> <= 8 && cute::sizeof_bits_v<b_type> <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types");
|
||||
static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA.");
|
||||
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256.");
|
||||
|
||||
|
||||
using FrgTypeA = UMMA::smem_desc<a_major>;
|
||||
using FrgTypeB = UMMA::smem_desc<b_major>;
|
||||
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
|
||||
|
||||
@ -51,8 +51,8 @@
|
||||
// but do _not_ include references like int& or float&.
|
||||
// (See std::tie for an example of a tuple of references.)
|
||||
//
|
||||
// Standard-layout types preserve ABI across host-device boundaries.
|
||||
// They are safe to use as device kernel parameters.
|
||||
// Standard-layout types preserve ABI across host-device boundaries. They are safe to use as device kernel parameters.
|
||||
// The standard-layout requirement prevents a more common EBO-based implemented of cute::tuple.
|
||||
//
|
||||
// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
|
||||
// the conversion SFINAE, special overloading, and avoiding cvref template types.
|
||||
@ -62,12 +62,15 @@
|
||||
namespace cute
|
||||
{
|
||||
|
||||
namespace detail
|
||||
template <class... T>
|
||||
struct tuple;
|
||||
|
||||
namespace eso
|
||||
{
|
||||
|
||||
// ESO stands for "empty structure optimization."
|
||||
// We use this technique to ensure that cute::tuple
|
||||
// doesn't waste space storing template arguments that have no data (like integral_constant).
|
||||
// We use this technique to ensure that cute::tuple doesn't waste space
|
||||
// storing template arguments that have no data (like integral_constant).
|
||||
// Empty types in the template argument list are not even constructed,
|
||||
// and do not have unique element addresses. Calling `get`
|
||||
// constructs and returns an instance of an empty type on demand.
|
||||
@ -131,94 +134,92 @@ struct ESO<false, false, First, Rest...> {
|
||||
};
|
||||
|
||||
// Get Nth value from ESO
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
template <class R, size_t N, class S>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>>>
|
||||
getv(ESO<F, R, T, Rest...> const&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>> const&>
|
||||
getv(ESO<F, R, T, Rest...> const& s)
|
||||
R
|
||||
getr(S&& s) noexcept
|
||||
{
|
||||
if constexpr (N == 0) {
|
||||
return static_cast<T const&>(s.first_);
|
||||
return static_cast<S&&>(s).first_;
|
||||
} else {
|
||||
return getv<N-1>(s.rest_);
|
||||
return getr<R,N-1>(static_cast<S&&>(s).rest_);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
// Compilers disagree on decltype(auto), so these implementations avoid it at cost
|
||||
template <size_t N, bool F, bool R, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>> &>
|
||||
getv(ESO<F, R, T, Rest...>& s)
|
||||
cute::conditional_t<cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value,
|
||||
cute::tuple_element_t<N, cute::tuple<T...>>,
|
||||
cute::tuple_element_t<N, cute::tuple<T...>> const&>
|
||||
getv_cr(ESO<F, R, T...> const& s) noexcept
|
||||
{
|
||||
if constexpr (N == 0) {
|
||||
return static_cast<T&>(s.first_);
|
||||
if constexpr (cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value) {
|
||||
return {};
|
||||
} else {
|
||||
return getv<N-1>(s.rest_);
|
||||
return getr<cute::tuple_element_t<N, cute::tuple<T...>> const&, N>(s);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <size_t N, bool F, bool R, class T, class... Rest>
|
||||
template <size_t N, bool F, bool R, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<T, Rest...>>>::value,
|
||||
cute::tuple_element_t<N, cute::type_list<T, Rest...>> &&>
|
||||
getv(ESO<F, R, T, Rest...>&& s)
|
||||
cute::conditional_t<cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value,
|
||||
cute::tuple_element_t<N, cute::tuple<T...>>,
|
||||
cute::tuple_element_t<N, cute::tuple<T...>> &>
|
||||
getv_r(ESO<F, R, T...>& s) noexcept
|
||||
{
|
||||
if constexpr (N == 0) {
|
||||
return static_cast<T&&>(s.first_);
|
||||
if constexpr (cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value) {
|
||||
return {};
|
||||
} else {
|
||||
return getv<N-1>(static_cast<ESO_t<Rest...>&&>(s.rest_));
|
||||
return getr<cute::tuple_element_t<N, cute::tuple<T...>> &, N>(s);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class X, size_t N,
|
||||
bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
|
||||
template <size_t N, bool F, bool R, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& t) noexcept
|
||||
cute::conditional_t<cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value,
|
||||
cute::tuple_element_t<N, cute::tuple<T...>>,
|
||||
cute::tuple_element_t<N, cute::tuple<T...>> &&>
|
||||
getv_rr(ESO<F, R, T...>&& s) noexcept
|
||||
{
|
||||
if constexpr (cute::is_same_v<X, First>) {
|
||||
return C<N>{};
|
||||
} else
|
||||
if constexpr (sizeof...(Rest) == 0) {
|
||||
return C<N+1>{};
|
||||
} else
|
||||
if constexpr (IsRestEmpty) {
|
||||
return cute::detail::findt<X, N+1>(ESO_t<Rest...>{});
|
||||
if constexpr (cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value) {
|
||||
return {};
|
||||
} else {
|
||||
return cute::detail::findt<X, N+1>(t.rest_);
|
||||
return getr<cute::tuple_element_t<N, cute::tuple<T...>> &&, N>(static_cast<ESO<F, R, T...>&&>(s));
|
||||
}
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
} // end namespace eso
|
||||
|
||||
template <class... T>
|
||||
struct tuple : detail::ESO_t<T...>
|
||||
struct tuple : eso::ESO_t<T...>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple() {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple(T const&... t) : detail::ESO_t<T...>(t...) {}
|
||||
tuple(T const&... t) : eso::ESO_t<T...>(t...) {}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct tuple<> {};
|
||||
|
||||
//
|
||||
// make_tuple (value-based implementation)
|
||||
//
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple<T...>
|
||||
make_tuple(T const&... t)
|
||||
{
|
||||
return {t...};
|
||||
}
|
||||
|
||||
// Returns the element in the ith position of the tuple
|
||||
template <size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -226,7 +227,7 @@ decltype(auto)
|
||||
get(tuple<T...> const& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(t);
|
||||
return eso::getv_cr<I>(t);
|
||||
}
|
||||
|
||||
template <size_t I, class... T>
|
||||
@ -235,7 +236,7 @@ decltype(auto)
|
||||
get(tuple<T...>& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(t);
|
||||
return eso::getv_r<I>(t);
|
||||
}
|
||||
|
||||
template <size_t I, class... T>
|
||||
@ -244,22 +245,22 @@ decltype(auto)
|
||||
get(tuple<T...>&& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(static_cast<detail::ESO_t<T...>&&>(t));
|
||||
return eso::getv_rr<I>(static_cast<eso::ESO_t<T...>&&>(t));
|
||||
}
|
||||
|
||||
// Returns the position of type X (as a static integer) in the tuple
|
||||
// type's argument list. X must be unique in the argument list.
|
||||
// Returns the first position of type X (as a static integer) in the tuple
|
||||
// type's argument list.
|
||||
template <class X, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find(tuple<T...> const& t) noexcept
|
||||
find(tuple<T...> const&) noexcept
|
||||
{
|
||||
return detail::findt<X, 0>(t);
|
||||
return cute::C<find_true_v<cute::is_same_v<X,T>...>>{};
|
||||
}
|
||||
|
||||
//
|
||||
// Custom is_tuple trait simply checks the existence of tuple_size
|
||||
// and assumes std::get<I>(.), std::tuple_element<I,.>
|
||||
// and assumes get<I>(.), tuple_element<I,.>
|
||||
//
|
||||
namespace detail {
|
||||
|
||||
@ -273,19 +274,7 @@ template <class T>
|
||||
struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {};
|
||||
|
||||
template <class T>
|
||||
constexpr bool is_tuple_v = cute::is_tuple<T>::value;
|
||||
|
||||
//
|
||||
// make_tuple (value-based implementation)
|
||||
//
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple<T...>
|
||||
make_tuple(T const&... t)
|
||||
{
|
||||
return {t...};
|
||||
}
|
||||
static constexpr bool is_tuple_v = cute::is_tuple<T>::value;
|
||||
|
||||
//
|
||||
// tuple_cat concatenates multiple cute::tuple into a single cute::tuple,
|
||||
|
||||
@ -31,6 +31,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -39,11 +40,35 @@ template <class... T>
|
||||
struct type_list {};
|
||||
|
||||
// get<I> for type_list<T...>
|
||||
// requires tuple_element_t<I,type_list<T...>> to have std::is_default_constructible
|
||||
// Get an instance of the Ith type in the pack T...
|
||||
// Requires tuple_element_t<I,type_list<T...>> to have std::is_default_constructible
|
||||
template <size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
CUTE_STL_NAMESPACE::tuple_element_t<I, type_list<T...>>
|
||||
get(type_list<T...> const& t) noexcept {
|
||||
get(type_list<T...> const&) noexcept {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Find the index of the first true in the pack B...
|
||||
template <bool... B>
|
||||
struct find_true {
|
||||
CUTE_HOST_DEVICE static constexpr size_t find() {
|
||||
size_t i = 0;
|
||||
(void) ((B ? true : (++i, false)) || ...);
|
||||
return i;
|
||||
}
|
||||
static constexpr size_t value = find();
|
||||
};
|
||||
|
||||
template <bool... B>
|
||||
static constexpr size_t find_true_v = find_true<B...>::value;
|
||||
|
||||
// find<X> for type_list<T...>
|
||||
// Finds the first position of type X (as a static integer) in the T... pack
|
||||
template <class X, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
CUTE_STL_NAMESPACE::integral_constant<size_t, find_true_v<cute::is_same_v<X,T>...>>
|
||||
find(type_list<T...> const&) noexcept {
|
||||
return {};
|
||||
}
|
||||
|
||||
@ -69,9 +94,8 @@ struct tuple_size<cute::type_list<T...>>
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, cute::type_list<T...>>
|
||||
{
|
||||
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
|
||||
};
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
|
||||
@ -94,9 +118,8 @@ struct tuple_size<cute::type_list<T...>>
|
||||
|
||||
template <size_t I, class... T>
|
||||
struct tuple_element<I, cute::type_list<T...>>
|
||||
{
|
||||
using type = typename CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>::type;
|
||||
};
|
||||
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|
||||
|
||||
@ -834,7 +834,7 @@ coalesce_x(Layout<Shape,Stride> const& layout)
|
||||
} else {
|
||||
return detail::bw_coalesce<R-2>(flat_shape, flat_stride, get<R-1>(flat_shape), get<R-1>(flat_stride));
|
||||
}
|
||||
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
@ -1030,7 +1030,7 @@ template <class LShape, class LStride,
|
||||
class RShape, class RStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
|
||||
composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_stride,
|
||||
RShape const& rhs_shape, RStride const& rhs_stride)
|
||||
{
|
||||
if constexpr (is_tuple<RShape>::value) { // Right-distributivity of Layout composition for RHS tuple
|
||||
@ -1067,7 +1067,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
|
||||
auto rest_stride = get<3>(init);
|
||||
|
||||
auto curr_shape = get<curr_i>(lhs_shape);
|
||||
auto curr_stride = get<curr_i>(lhs_stride);
|
||||
[[maybe_unused]] auto curr_stride = get<curr_i>(lhs_stride);
|
||||
|
||||
// Strong divisibility condition -- requires composition to be statically verifiable.
|
||||
//CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition");
|
||||
|
||||
@ -128,8 +128,6 @@ make_fragment_like(ComposedLayout<Swizzle<B,M,S>,Offset,Layout> const& layout)
|
||||
// Utilities
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Get just the Swizzle part of a composed layout.
|
||||
template <int B, int M, int S, class Offset, class LayoutB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
@ -167,8 +165,6 @@ get_nonswizzle_portion(Layout<Shape,Stride> const& slayout)
|
||||
return slayout;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
//
|
||||
// Slice a Swizzled ComposedLayout
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user