@ -32,7 +32,7 @@
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
namespace cute
|
||||
|
||||
@ -31,9 +31,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cute/tensor_impl.hpp>
|
||||
#include <cute/algorithm/fill.hpp>
|
||||
|
||||
namespace cute
|
||||
|
||||
@ -1,51 +1,117 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <uint32_t NumThreads,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE void
|
||||
naive_cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
auto N = size(src);
|
||||
if (tid < N) {
|
||||
uint32_t upper_bound = (N / NumThreads) * NumThreads;
|
||||
CUTE_UNROLL
|
||||
for (uint32_t i = 0; i < upper_bound; i += NumThreads) { // All in-bounds
|
||||
dst[tid + i] = src[tid + i];
|
||||
}
|
||||
if (N % NumThreads != 0) { // Likely static condition
|
||||
uint32_t final_idx = tid + upper_bound;
|
||||
if (final_idx < N) { // Final in-bounds
|
||||
dst[final_idx] = src[final_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <uint32_t NumThreads,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE void
|
||||
naive_cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return naive_cooperative_copy(tid, src, dst);
|
||||
}
|
||||
|
||||
// A heuristic to determine a "good" permutation of two tensors for later vectorization and thr-assignment
|
||||
template <class AEngine, class ALayout,
|
||||
class BEngine, class BLayout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
heuristic_permutation(Tensor<AEngine, ALayout> const& a,
|
||||
Tensor<BEngine, BLayout> const& b)
|
||||
{
|
||||
constexpr bool swizzleA = get_swizzle_t<AEngine>::num_bits != 0 or
|
||||
get_swizzle_t<ALayout>::num_bits != 0;
|
||||
constexpr bool swizzleB = get_swizzle_t<BEngine>::num_bits != 0 or
|
||||
get_swizzle_t<BLayout>::num_bits != 0;
|
||||
auto a_inv = right_inverse(get_nonswizzle_portion(a.layout()));
|
||||
auto b_inv = right_inverse(get_nonswizzle_portion(b.layout()));
|
||||
|
||||
constexpr uint8_t scoreA = (uint8_t(swizzleA) << 2) |
|
||||
(uint8_t(is_smem<AEngine>::value) << 1) |
|
||||
(uint8_t(size(a_inv) > size(b_inv)) << 0);
|
||||
|
||||
constexpr uint8_t scoreB = (uint8_t(swizzleB) << 2) |
|
||||
(uint8_t(is_smem<BEngine>::value) << 1) |
|
||||
(uint8_t(size(b_inv) > size(a_inv)) << 0);
|
||||
|
||||
if constexpr (scoreA >= scoreB) {
|
||||
return a_inv;
|
||||
} else {
|
||||
return b_inv;
|
||||
}
|
||||
}
|
||||
|
||||
// cooperative_copy<NumThreads, MaxVecBits>(thr_idx, src, dst)
|
||||
// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits.
|
||||
// Use NumThreads to copy Tensor src to Tensor dst with element-wise vectorization up to MaxVecBits.
|
||||
// @pre 0 <= @a tid < NumThreads
|
||||
// @pre Tensors @a src and @a dst are aligned up to MaxVecBits.
|
||||
// That is, pointers and dynamic strides are assumed to be aligned up to MaxVecBits.
|
||||
//
|
||||
template <uint32_t NumThreads, uint32_t MaxVecBits,
|
||||
class SrcEngine, class SrcLayout,
|
||||
@ -56,121 +122,171 @@ cooperative_copy(uint32_t const& tid,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
// Assumes the shapes are static, can generalize
|
||||
// Assumes the shapes are static, can generalize/fallback
|
||||
CUTE_STATIC_ASSERT_V(is_static<decltype(shape(src))>{} && is_static<decltype(shape(dst))>{});
|
||||
CUTE_STATIC_ASSERT_V(size(src) == size(dst));
|
||||
// Assumes the types are the same, can generalize
|
||||
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == sizeof_bits_v<typename DstEngine::value_type>);
|
||||
// Assumes the types are the same, can generalize/fallback
|
||||
static_assert(cute::is_same<typename SrcEngine::value_type, typename DstEngine::value_type>::value);
|
||||
static_assert(MaxVecBits == sizeof_bits_v<typename SrcEngine::value_type> ||
|
||||
MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128,
|
||||
"Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance.");
|
||||
// Check that the tensors are likely shared across threads: either gmem or smem
|
||||
static_assert((is_gmem<SrcEngine>::value || is_smem<SrcEngine>::value),
|
||||
"cooperative_copy expects shared gmem or smem source tensor.");
|
||||
"cooperative_copy expects shared gmem or smem source tensor.");
|
||||
static_assert((is_gmem<DstEngine>::value || is_smem<DstEngine>::value),
|
||||
"cooperative_copy expects shared gmem or smem destination tensor.");
|
||||
|
||||
"cooperative_copy expects shared gmem or smem destination tensor.");
|
||||
// Precondition on tid in DEBUG
|
||||
assert(tid < NumThreads);
|
||||
// Precondition on pointer alignment in DEBUG
|
||||
assert(is_byte_aligned<ceil_div(MaxVecBits,8u)>(raw_pointer_cast(src.data())));
|
||||
assert(is_byte_aligned<ceil_div(MaxVecBits,8u)>(raw_pointer_cast(dst.data())));
|
||||
|
||||
// Fallback - slow path, naive copy, vectorization disabled
|
||||
if constexpr(size(SrcLayout{}) % NumThreads != 0) {
|
||||
int index = static_cast<int>(tid);
|
||||
CUTE_UNROLL
|
||||
for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) {
|
||||
if(index < size(SrcLayout{})) {
|
||||
dst[index] = src[index];
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" "); print("cooperative_copy\n");
|
||||
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
|
||||
print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n");
|
||||
print(" "); print("src: "); print(src); print("\n");
|
||||
print(" "); print("dst: "); print(dst); print("\n");
|
||||
}
|
||||
index += NumThreads;
|
||||
#ifdef __CUDA_ARCH__
|
||||
__syncthreads();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// The common layout of the two tensors that can be vectorized over elements and threads
|
||||
// vidx -> coord
|
||||
auto common_layout = heuristic_permutation(src, dst);
|
||||
|
||||
// Apply
|
||||
// (V, rest)
|
||||
Tensor src_a = coalesce(logical_divide(src, common_layout), Shape<_1,_1>{});
|
||||
Tensor dst_a = coalesce(logical_divide(dst, common_layout), Shape<_1,_1>{});
|
||||
|
||||
//
|
||||
// Determine vectorization of elems and thrs based on src/dst size and number of threads
|
||||
// NOTE: This heuristic promotes parallelization over vectorization
|
||||
//
|
||||
|
||||
// The number of elements and number of bits
|
||||
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
|
||||
constexpr int total_elem = size(SrcLayout{});
|
||||
|
||||
// The number of elements that can be vectorized in values
|
||||
constexpr int common_elem = decltype(max_common_vector(src_a, dst_a))::value;
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" "); print("common_layout: "); print(common_layout); print("\n");
|
||||
print(" "); print("src_a: "); print(src_a); print("\n");
|
||||
print(" "); print("dst_a: "); print(dst_a); print("\n");
|
||||
}
|
||||
#ifdef __CUDA_ARCH__
|
||||
__syncthreads();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
//
|
||||
if constexpr (total_elem % NumThreads != 0) {
|
||||
// Not attempting to find a partitioning pattern, fallback to dynamically indexed slowpath
|
||||
|
||||
if constexpr (common_elem > 1 && MaxVecBits > elem_bits) {
|
||||
// If the vectorization is non-trivial and divides the maximum vectorizations, then vectorize
|
||||
constexpr auto max_align_src = elem_bits * decltype(max_alignment(src_a.layout()))::value;
|
||||
constexpr auto max_align_dst = elem_bits * decltype(max_alignment(dst_a.layout()))::value;
|
||||
constexpr auto vec_bits = gcd(max_align_src, max_align_dst, MaxVecBits);
|
||||
using VecType = uint_bit_t<vec_bits>;
|
||||
|
||||
static_assert(vec_bits % elem_bits == 0, "Expected divisibility");
|
||||
static_assert((vec_bits >= 8), "No support for subbyte copying");
|
||||
|
||||
Tensor src_v = recast<VecType const>(src_a);
|
||||
Tensor dst_v = recast<VecType >(dst_a);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" "); print("cooperative_copy -- naive\n");
|
||||
print(" "); print("src_v: "); print(src_v); print("\n");
|
||||
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
||||
}
|
||||
#ifdef __CUDA_ARCH__
|
||||
__syncthreads();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
naive_cooperative_copy<NumThreads>(tid, src_v, dst_v);
|
||||
} else {
|
||||
naive_cooperative_copy<NumThreads>(tid, src_a, dst_a);
|
||||
}
|
||||
} else {
|
||||
// Fast path with vectorization
|
||||
// If the tensors can be equally partitioned by the threads,
|
||||
// compute vectorization widths in elements and threads.
|
||||
|
||||
// Precondition on pointer alignment in DEBUG
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
|
||||
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));
|
||||
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
|
||||
|
||||
//
|
||||
// Determine val+thr vectorization based on src/dst size and number of threads
|
||||
// NOTE: This heuristic promotes parallelization over vectorization
|
||||
//
|
||||
|
||||
// The number of elements that can be vectorized in values
|
||||
constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
|
||||
constexpr int common_bits = common_elem * elem_bits;
|
||||
constexpr int total_elem = decltype(size(src))::value;
|
||||
// If there are too many threads to allow a full vectorized copy, trunc the vectorization
|
||||
constexpr int total_bits = total_elem * elem_bits;
|
||||
static_assert(total_bits % NumThreads == 0);
|
||||
constexpr int total_bits_per_thr = total_bits / NumThreads;
|
||||
// If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
|
||||
constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);
|
||||
|
||||
// Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
|
||||
constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
|
||||
// Convert back to number of elements, safe_div
|
||||
static_assert((vec_bits % elem_bits) == 0);
|
||||
constexpr int vec_elem = vec_bits / elem_bits;
|
||||
|
||||
// Use only part of threads if there's not enough work for all threads
|
||||
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
|
||||
? NumThreads
|
||||
: (total_elem / vec_elem);
|
||||
static_assert(vec_thrs <= NumThreads);
|
||||
|
||||
// The common layout of the two tensors that can be vectorized over threads
|
||||
// vidx -> coord
|
||||
auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
|
||||
get_nonswizzle_portion(dst.layout()));
|
||||
|
||||
// Scale up the common_layout to cover the entire tensors
|
||||
// vidx -> coord
|
||||
auto full_perm = tile_to_shape(make_layout(common_layout), size(src));
|
||||
|
||||
// Create the Tiler
|
||||
// ((vid,tid),iter)
|
||||
auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});
|
||||
|
||||
// Apply and slice
|
||||
Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
|
||||
Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
|
||||
constexpr int max_bits_per_thr = total_bits / NumThreads;
|
||||
// At least elem_bits, at most common_bits
|
||||
constexpr int common_bits = common_elem * elem_bits;
|
||||
constexpr int vec_bits = cute::max(elem_bits, cute::gcd(common_bits, int(MaxVecBits), max_bits_per_thr));
|
||||
|
||||
// Should account for vec_bits < 8 and/or vec_elem <= 1
|
||||
// And also account for subbyte types, which could cause race conditions
|
||||
// Want to ENFORCE sufficient vectorization in those cases
|
||||
static_assert((vec_bits >= 8), "No support for subbyte copying");
|
||||
static_assert(vec_bits % elem_bits == 0, "Expected divisibility");
|
||||
static_assert(vec_bits >= 8, "No support for subbyte copying");
|
||||
|
||||
using VecType = uint_bit_t<vec_bits>;
|
||||
constexpr int vec_elem = vec_bits / elem_bits;
|
||||
|
||||
constexpr int vec_thrs = cute::min(int(NumThreads), total_elem / vec_elem);
|
||||
|
||||
//
|
||||
// Determine the partitioning patterns for the vec_elems and vec_thrs
|
||||
//
|
||||
|
||||
// Distribute the rest of the V*T to some consistent portion outside of the common_layout, if needed
|
||||
auto common_domain_src = domain_distribute(shape(src_a), Int<vec_elem*vec_thrs>{});
|
||||
auto common_domain_dst = domain_distribute(shape(dst_a), Int<vec_elem*vec_thrs>{});
|
||||
|
||||
// Make sure for now, could fall back here instead
|
||||
CUTE_STATIC_ASSERT_V(size(common_domain_src) == Int<vec_elem*vec_thrs>{});
|
||||
CUTE_STATIC_ASSERT_V(compatible(common_domain_src, common_domain_dst) ||
|
||||
compatible(common_domain_dst, common_domain_src));
|
||||
// Use the "more specific" domain for the extra elements of V*T
|
||||
auto common_domain = conditional_return(compatible(common_domain_src, common_domain_dst),
|
||||
common_domain_dst, common_domain_src);
|
||||
|
||||
// Construct the tiler
|
||||
auto tiler_vt = common_domain.with_shape(Int<vec_elem>{}, Int<vec_thrs>{});
|
||||
|
||||
// Apply and slice
|
||||
Tensor src_v = logical_divide(src_a, tiler_vt)(make_coord(_,tid),_);
|
||||
Tensor dst_v = logical_divide(dst_a, tiler_vt)(make_coord(_,tid),_);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" "); print("cooperative_copy -- vec\n");
|
||||
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
|
||||
print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n");
|
||||
print(" "); print("src: "); print(src); print("\n");
|
||||
print(" "); print("dst: "); print(dst); print("\n");
|
||||
print(" "); print("common_layout: "); print(common_layout); print("\n");
|
||||
print(" "); print("full_perm: "); print(full_perm); print("\n");
|
||||
print(" "); print("Used vector: "); print(vec_elem); print("\n");
|
||||
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
|
||||
print(" "); print("layout_vt: "); print(layout_vt); print("\n");
|
||||
print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
|
||||
print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
|
||||
print(" "); print("src_v: "); print(src_v); print("\n");
|
||||
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
||||
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
|
||||
print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
|
||||
}
|
||||
if (thread0()) {
|
||||
print(" "); print("cooperative_copy -- vec\n");
|
||||
print(" "); print("Used vector: "); print(vec_elem); print("\n");
|
||||
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
|
||||
print(" "); print("tiler_vt: "); print(tiler_vt); print("\n");
|
||||
print(" "); print("src_v: "); print(src_v); print("\n");
|
||||
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
||||
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
|
||||
print(" "); print("recast<VecType >(dst_v): "); print(recast<VecType >(dst_v)); print("\n");
|
||||
}
|
||||
#ifdef __CUDA_ARCH__
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// If we're using all threads (static) or the tid is in in-range (dynamic)
|
||||
if (vec_thrs >= NumThreads or tid < vec_thrs) {
|
||||
// If we're using all threads (static) or the tid is in-range (dynamic)
|
||||
if (vec_thrs == NumThreads or tid < vec_thrs) {
|
||||
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default max-vectorization size to value_type size
|
||||
template <uint32_t NumThreads,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
@ -184,7 +300,10 @@ cooperative_copy(uint32_t const& tid,
|
||||
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
|
||||
}
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
|
||||
template <uint32_t NumThreads,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
@ -197,9 +316,7 @@ cooperative_copy(uint32_t const& tid,
|
||||
return cooperative_copy<NumThreads>(tid, src, dst);
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <uint32_t NumThreads,
|
||||
uint32_t MaxVecBits,
|
||||
template <uint32_t NumThreads, uint32_t MaxVecBits,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
|
||||
@ -39,7 +39,7 @@
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -76,29 +76,15 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
// Original, static size of the problem
|
||||
auto M = size<0>(sC);
|
||||
auto N = size<1>(sC);
|
||||
auto K = size<1>(sA);
|
||||
|
||||
// Block size of the compute tile
|
||||
auto BLK_M = tile_size<0>(thr_mma);
|
||||
auto BLK_N = tile_size<1>(thr_mma);
|
||||
auto BLK_K = tile_size<2>(thr_mma);
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
// Round the layout extents up to BLK_X to satisfy MMA partitioning safety
|
||||
Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K)));
|
||||
Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K)));
|
||||
Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N)));
|
||||
// Partition the sA, sB, and sC tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Partition the sA and sB tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
|
||||
@ -109,9 +95,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
print(" sA: "); print( sA); print("\n");
|
||||
print(" sB: "); print( sB); print("\n");
|
||||
print(" sC: "); print( sC); print("\n");
|
||||
print("r_sA: "); print(rounded_sA); print("\n");
|
||||
print("r_sB: "); print(rounded_sB); print("\n");
|
||||
print("r_sC: "); print(rounded_sC); print("\n");
|
||||
print(thr_mma);
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
@ -127,8 +110,8 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
//
|
||||
|
||||
// Create coordinate tensors for the problem
|
||||
Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k)
|
||||
Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k)
|
||||
Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k)
|
||||
Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k)
|
||||
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k)
|
||||
@ -222,7 +205,7 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
//
|
||||
|
||||
// Create coordinate tensors for the problem
|
||||
Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n)
|
||||
Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n)
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@
|
||||
|
||||
#include <cute/container/alignment.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
@ -199,14 +199,14 @@ copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
{
|
||||
static_assert(sizeof_bits_v<VecType> >= 8 && sizeof_bits_v<VecType> % 8 == 0,
|
||||
"Expected a vectorization type of at least a byte.");
|
||||
using SrcType = typename SrcEngine::element_type;
|
||||
using DstType = typename DstEngine::element_type;
|
||||
if constexpr (sizeof_bits_v<SrcType> == sizeof_bits_v<DstType> &&
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
if constexpr (cute::is_same<SrcType, DstType>::value &&
|
||||
sizeof_bits_v<VecType> > sizeof_bits_v<DstType>)
|
||||
{
|
||||
// Preserve volatility of Src/Dst types.
|
||||
using SrcVecType = conditional_t<is_volatile_v<SrcType>, VecType const volatile, VecType const>;
|
||||
using DstVecType = conditional_t<is_volatile_v<DstType>, VecType volatile, VecType >;
|
||||
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
|
||||
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
|
||||
Tensor src_v = recast<SrcVecType>(src);
|
||||
Tensor dst_v = recast<DstVecType>(dst);
|
||||
|
||||
@ -264,22 +264,22 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
|
||||
{
|
||||
constexpr int vec_elem = decltype(max_common_vector(src, dst))::value;
|
||||
|
||||
constexpr int src_bits = sizeof_bits<typename SrcEngine::value_type>::value;
|
||||
// When layouts are static, accept vec_bits up to 128
|
||||
// When layouts are dynamic, accept vec_bits up to MaxVecBits
|
||||
constexpr int vec_bits = (is_static<SrcLayout>::value && is_static<DstLayout>::value) ?
|
||||
cute::min(vec_elem * src_bits, 128) :
|
||||
cute::min(vec_elem * src_bits, MaxVecBits);
|
||||
constexpr int max_align_src = decltype(max_alignment(src.layout()))::value;
|
||||
constexpr int max_align_dst = decltype(max_alignment(dst.layout()))::value;
|
||||
constexpr int max_align = gcd(vec_elem, max_align_src, max_align_dst);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits);
|
||||
print(" "); print(src); print("\n");
|
||||
print(" "); print(dst); print("\n");
|
||||
}
|
||||
#endif
|
||||
constexpr int src_bits = sizeof_bits<typename SrcEngine::value_type>::value;
|
||||
constexpr int vec_bits = gcd(src_bits * max_align, MaxVecBits);
|
||||
|
||||
if constexpr (vec_elem > 1 && vec_bits >= 8) {
|
||||
// If more than one element vectorizes to 8bits or more, then copy_vec
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits);
|
||||
print(" "); print(src); print("\n");
|
||||
print(" "); print(dst); print("\n");
|
||||
}
|
||||
#endif
|
||||
return copy_vec<uint_bit_t<vec_bits>>(src, dst);
|
||||
} else {
|
||||
return copy_if(TrivialPredTensor{}, src, dst);
|
||||
@ -294,10 +294,16 @@ void
|
||||
copy(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
return copy(AutoVectorizingCopy{}, src, dst);
|
||||
if constexpr (is_static<SrcLayout>::value && is_static<DstLayout>::value) {
|
||||
// Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned
|
||||
return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
|
||||
} else {
|
||||
// Do not assume that dynamic layouts are aligned.
|
||||
return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst);
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-vectorizing copy with assumed alignment of dynamic layout strides up to 128bit.
|
||||
// Auto-vectorizing copy with assumed alignment up to 128bit.
|
||||
template <class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
@ -308,19 +314,6 @@ copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
|
||||
}
|
||||
|
||||
// Specializaton for Atom AutoVectorizingCopy
|
||||
template <class... Args,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<AutoVectorizingCopy, Args...> const&,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
return copy(AutoVectorizingCopy{}, src, dst);
|
||||
}
|
||||
|
||||
// Specialization for Atom AutoVectorizingCopyAssumedAlignment
|
||||
template <int MaxVecBits, class... Args,
|
||||
class SrcEngine, class SrcLayout,
|
||||
@ -346,7 +339,7 @@ copy(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom, // Copy_Traits m
|
||||
{
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
static_assert(sizeof_bits<SrcType>::value == sizeof_bits<DstType>::value);
|
||||
static_assert(cute::is_same<SrcType, DstType>::value);
|
||||
static_assert((is_gmem<SrcEngine>::value && is_smem<DstEngine>::value) ||
|
||||
(is_smem<SrcEngine>::value && is_gmem<DstEngine>::value),
|
||||
"Bulk Copy only supports gmem -> smem or smem -> gmem movement.");
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
#include <cute/algorithm/prefer.hpp>
|
||||
|
||||
namespace cute
|
||||
|
||||
@ -35,7 +35,7 @@
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
@ -90,12 +90,6 @@ constexpr bool has_prefetch = false;
|
||||
template <class CopyOp>
|
||||
constexpr bool has_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = true;
|
||||
|
||||
template <class CopyOp, class = void>
|
||||
constexpr bool is_prefetch = false;
|
||||
|
||||
template <class CopyOp>
|
||||
constexpr bool is_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = is_same_v<CopyOp, typename CopyOp::PREFETCH>;
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class CopyOp, class... CT_Args, class... CA_Args,
|
||||
|
||||
@ -33,8 +33,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -100,13 +99,13 @@ transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
|
||||
}
|
||||
|
||||
// Similar to std::transform: transforms one tensor and assigns the result to another
|
||||
template <class EngineIn, class LayoutIn,
|
||||
class EngineOut, class LayoutOut,
|
||||
template <class EngineIn, class LayoutIn,
|
||||
class EngineOut, class LayoutOut,
|
||||
class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
Tensor<EngineOut,LayoutOut> & tensor_out,
|
||||
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
Tensor<EngineOut,LayoutOut> & tensor_out,
|
||||
UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
@ -117,30 +116,30 @@ transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class EngineIn, class LayoutIn,
|
||||
class EngineOut, class LayoutOut,
|
||||
class EngineOut, class LayoutOut,
|
||||
class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
Tensor<EngineOut,LayoutOut> && tensor_out,
|
||||
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
|
||||
Tensor<EngineOut,LayoutOut> && tensor_out,
|
||||
UnaryOp&& op)
|
||||
{
|
||||
return transform(tensor_in, tensor_out, op);
|
||||
}
|
||||
|
||||
// Similar to std::transform with a binary operation
|
||||
// Takes two tensors as input and one tensor as output.
|
||||
// Takes two tensors as input and one tensor as output.
|
||||
// Applies the binary_op to tensor_in1 and tensor_in2 and
|
||||
// assigns it to tensor_out
|
||||
template <class EngineIn1, class LayoutIn1,
|
||||
class EngineIn2, class LayoutIn2,
|
||||
class EngineOut, class LayoutOut,
|
||||
class EngineOut, class LayoutOut,
|
||||
class BinaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut> & tensor_out,
|
||||
Tensor<EngineOut,LayoutOut> & tensor_out,
|
||||
BinaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
@ -152,11 +151,11 @@ transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
||||
// Accept mutable temporaries
|
||||
template <class EngineIn1, class LayoutIn1,
|
||||
class EngineIn2, class LayoutIn2,
|
||||
class EngineOut, class LayoutOut,
|
||||
class EngineOut, class LayoutOut,
|
||||
class BinaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
||||
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
||||
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
|
||||
Tensor<EngineOut,LayoutOut> && tensor_out,
|
||||
BinaryOp&& op)
|
||||
|
||||
@ -404,29 +404,54 @@ namespace detail {
|
||||
// This impl compiles much faster than cute::apply and variadic args
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold(T&& t, V&& v, F&& f, seq<>)
|
||||
auto
|
||||
fold(T&&, V&& v, F&&, seq<>)
|
||||
{
|
||||
return static_cast<V&&>(v);
|
||||
return v;
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I, int... Is>
|
||||
template <class T, class V, class F, int I0>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold(T&& t, V&& v, F&& f, seq<I,Is...>)
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0>)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t)));
|
||||
} else {
|
||||
return fold(static_cast<T&&>(t),
|
||||
f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t))),
|
||||
f,
|
||||
seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
return f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
// Left fold over exactly two indexed elements of t, seeded with v:
//   f(f(v, get<I0>(t)), get<I1>(t))
// Written as a single nested expression; there is no recursive fold call here.
// The return type is plain `auto`, so the result of the outer f is returned by value.
template <class T, class V, class F, int I0, int I1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1>)
|
||||
{
|
||||
// v is forwarded into the innermost call only; t is re-cast for each get<>.
return f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
// Left fold over exactly three indexed elements of t, seeded with v:
//   f(f(f(v, get<I0>(t)), get<I1>(t)), get<I2>(t))
// Written as a single nested expression; there is no recursive fold call here.
// The return type is plain `auto`, so the result of the outer f is returned by value.
template <class T, class V, class F, int I0, int I1, int I2>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2>)
|
||||
{
|
||||
// v is forwarded into the innermost call only; t is re-cast for each get<>.
return f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
// Left fold over exactly four indexed elements of t, seeded with v:
//   f(f(f(f(v, get<I0>(t)), get<I1>(t)), get<I2>(t)), get<I3>(t))
// Written as a single nested expression; there is no recursive fold call here.
// The return type is plain `auto`, so the result of the outer f is returned by value.
template <class T, class V, class F, int I0, int I1, int I2, int I3>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3>)
|
||||
{
|
||||
// v is forwarded into the innermost call only; t is re-cast for each get<>.
return f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t)));
|
||||
}
|
||||
|
||||
// Left fold for index sequences with four leading indices followed by a pack Is...
// Folds the first four indexed elements of t into v in one nested expression,
//   f(f(f(f(v, get<I0>(t)), get<I1>(t)), get<I2>(t)), get<I3>(t)),
// then calls fold again with that partial result as the new seed and seq<Is...> as
// the remaining indices, so each recursive step consumes four indices.
// The return type is plain `auto`, so the final result is returned by value.
template <class T, class V, class F, int I0, int I1, int I2, int I3, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f, seq<I0,I1,I2,I3,Is...>)
|
||||
{
|
||||
return fold(static_cast<T&&>(t),
|
||||
// Partial result over I0..I3 becomes the seed value of the recursive call.
f(f(f(f(static_cast<V&&>(v), get<I0>(static_cast<T&&>(t))), get<I1>(static_cast<T&&>(t))), get<I2>(static_cast<T&&>(t))), get<I3>(static_cast<T&&>(t))),
|
||||
// f is passed on as an lvalue (not re-forwarded).
f,
|
||||
seq<Is...>{});
|
||||
}
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class V, class F>
|
||||
@ -448,7 +473,7 @@ fold(T&& t, V&& v, F&& f)
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
auto
|
||||
fold_first(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
@ -457,7 +482,7 @@ fold_first(T&& t, F&& f)
|
||||
f,
|
||||
make_range<1,tuple_size<remove_cvref_t<T>>::value>{});
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
return t;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
@ -701,7 +726,14 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
replace(T const& t, X const& x)
|
||||
{
|
||||
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
|
||||
} else {
|
||||
static_assert(N == 0);
|
||||
return x;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Replace the first element of the tuple with x
|
||||
@ -1077,9 +1109,9 @@ zip2_by(T const& t, TG const& guide)
|
||||
|
||||
/// @return A tuple of the elements of @c t in reverse order.
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
reverse(T const& t)
|
||||
reverse(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});
|
||||
|
||||
Reference in New Issue
Block a user