CUTLASS 3.0.0
This commit is contained in:
79
include/cute/algorithm/axpby.hpp
Normal file
79
include/cute/algorithm/axpby.hpp
Normal file
@ -0,0 +1,79 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
template <class Alpha,
|
||||
class XEngine, class XLayout,
|
||||
class Beta,
|
||||
class YEngine, class YLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
axpby(Alpha const& alpha,
|
||||
Tensor<XEngine, XLayout> const& x,
|
||||
Beta const& beta,
|
||||
Tensor<YEngine, YLayout> && y)
|
||||
{
|
||||
return axpby(alpha, x, beta, y);
|
||||
}
|
||||
|
||||
//
|
||||
// AXPBY
|
||||
//
|
||||
template <class Alpha,
|
||||
class XEngine, class XLayout,
|
||||
class Beta,
|
||||
class YEngine, class YLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
axpby(Alpha const& alpha,
|
||||
Tensor<XEngine, XLayout> const& x,
|
||||
Beta const& beta,
|
||||
Tensor<YEngine, YLayout> & y)
|
||||
{
|
||||
auto isBetaZero = (beta == Int<0>{});
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(x); ++i) {
|
||||
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
66
include/cute/algorithm/clear.hpp
Normal file
66
include/cute/algorithm/clear.hpp
Normal file
@ -0,0 +1,66 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cute/algorithm/fill.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
clear(Tensor<Engine, Layout>&& tensor)
|
||||
{
|
||||
return clear(tensor);
|
||||
}
|
||||
|
||||
//
|
||||
// Set elements to zero
|
||||
//
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
clear(Tensor<Engine, Layout>& tensor)
|
||||
{
|
||||
using T = typename Tensor<Engine,Layout>::value_type;
|
||||
|
||||
fill(tensor, T{});
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
262
include/cute/algorithm/copy.hpp
Normal file
262
include/cute/algorithm/copy.hpp
Normal file
@ -0,0 +1,262 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
|
||||
template <class PrdTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(PrdTensor const& pred,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy_if(pred, src, dst);
|
||||
}
|
||||
|
||||
template <class... CopyArgs,
|
||||
class PrdTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
PrdTensor const& pred,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy_if(copy_atom, pred, src, dst);
|
||||
}
|
||||
|
||||
template <class VecType,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy_vec<VecType>(src, dst);
|
||||
}
|
||||
|
||||
template <class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy(src, dst);
|
||||
}
|
||||
|
||||
template <class... CopyArgs,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy(copy_atom, src, dst);
|
||||
}
|
||||
|
||||
//
|
||||
// copy_if -- Predicated Copy
|
||||
//
|
||||
|
||||
template <class PrdTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(PrdTensor const& pred,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
auto copy_op = select_elementwise_copy(src, dst);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(src); ++i) {
|
||||
if (pred(i)) {
|
||||
copy_op.copy(src(i), dst(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// copy_if -- Predicated CopyAtom
|
||||
//
|
||||
|
||||
template <class... CopyArgs,
|
||||
class PredTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
PredTensor const& pred, // (Rest...)
|
||||
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
|
||||
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
|
||||
{
|
||||
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
|
||||
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
|
||||
copy_atom.call(src, dst);
|
||||
} else { // Loop over all but the first mode
|
||||
constexpr int R = SrcLayout::rank;
|
||||
auto src_v = group_modes<1,R>(src);
|
||||
auto dst_v = group_modes<1,R>(dst);
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<1>(src_v); ++i) {
|
||||
if (pred(i)) {
|
||||
copy_atom.call(src_v(_,i), dst_v(_,i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// copy_vec -- attempt vectorized copy with VecType
|
||||
//
|
||||
|
||||
template <class VecType,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
|
||||
{
|
||||
/* @pre is_aligned<N>(src.data()) &&
|
||||
* is_aligned<N>(dst.data())
|
||||
*/
|
||||
auto src_v = recast<VecType const>(src);
|
||||
auto dst_v = recast<VecType >(dst);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy_vec -- vectorizing copy from %3db to %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(VecType)));
|
||||
print(" "); print(layout(src)); print(" => "); print(layout(src_v)); print("\n");
|
||||
print(" "); print(layout(dst)); print(" => "); print(layout(dst_v)); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy_if(TrivialPredTensor{}, src_v, dst_v);
|
||||
} else {
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy_vec -- not vectorizing, copy with %3db and %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(DstType)));
|
||||
print(" "); print(layout(src)); print("\n");
|
||||
print(" "); print(layout(dst)); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy_if(TrivialPredTensor{}, src, dst);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// copy -- auto-vectorizing copy
|
||||
//
|
||||
|
||||
template <class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
constexpr int N = decltype(max_common_vector(src, dst))::value;
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy -- found a max_common_vector of %d\n", N);
|
||||
print(" "); print(src.data()); print(" o "); print(layout(src)); print("\n");
|
||||
print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr (N <= 1) {
|
||||
return copy_if(TrivialPredTensor{}, src, dst);
|
||||
} else {
|
||||
constexpr int vec_bits = N * sizeof_bits<typename SrcEngine::value_type>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
return copy_vec<VecType>(src, dst);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// copy -- CopyAtom
|
||||
//
|
||||
|
||||
template <class... CopyArgs,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
return copy_if(copy_atom, TrivialPredTensor{}, src, dst);
|
||||
}
|
||||
|
||||
template <class... CopyArgs,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<DefaultCopy, CopyArgs...> const&,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
return copy(src, dst);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
87
include/cute/algorithm/fill.hpp
Normal file
87
include/cute/algorithm/fill.hpp
Normal file
@ -0,0 +1,87 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/algorithm/prefer.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
template <class Engine, class Layout, class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
fill(Tensor<Engine, Layout>&& tensor, T const& value)
|
||||
{
|
||||
return fill(tensor, value);
|
||||
}
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// Prefer fill(tensor.data(), value), if possible
|
||||
template <class Engine, class Layout, class T>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
fill(Tensor<Engine, Layout>& tensor, T const& value, prefer<1>)
|
||||
-> decltype(fill(tensor.data(), value))
|
||||
{
|
||||
fill(tensor.data(), value);
|
||||
}
|
||||
|
||||
// Default implementation
|
||||
template <class Engine, class Layout, class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
fill(Tensor<Engine, Layout>& tensor, T const& value, prefer<0>)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
tensor(i) = value;
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Engine, class Layout, class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
fill(Tensor<Engine, Layout>& tensor, T const& value)
|
||||
{
|
||||
return detail::fill(tensor, value, prefer<1>{});
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
198
include/cute/algorithm/functional.hpp
Normal file
198
include/cute/algorithm/functional.hpp
Normal file
@ -0,0 +1,198 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
/** C++14 <functional> extensions */
|
||||
|
||||
namespace cute {
|
||||
|
||||
/**************/
|
||||
/** Identity **/
|
||||
/**************/
|
||||
|
||||
struct identity {
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&& arg) const {
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
template <class R>
|
||||
struct constant_fn {
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&&...) const {
|
||||
return r_;
|
||||
}
|
||||
R r_;
|
||||
};
|
||||
|
||||
/***********/
|
||||
/** Unary **/
|
||||
/***********/
|
||||
|
||||
#define CUTE_LEFT_UNARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return OP std::forward<T>(arg); \
|
||||
} \
|
||||
}
|
||||
#define CUTE_RIGHT_UNARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return std::forward<T>(arg) OP ; \
|
||||
} \
|
||||
}
|
||||
#define CUTE_NAMED_UNARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return OP (std::forward<T>(arg)); \
|
||||
} \
|
||||
}
|
||||
|
||||
CUTE_LEFT_UNARY_OP(unary_plus, +);
|
||||
CUTE_LEFT_UNARY_OP(negate, -);
|
||||
CUTE_LEFT_UNARY_OP(bit_not, ~);
|
||||
CUTE_LEFT_UNARY_OP(logical_not, !);
|
||||
CUTE_LEFT_UNARY_OP(dereference, *);
|
||||
CUTE_LEFT_UNARY_OP(address_of, &);
|
||||
CUTE_LEFT_UNARY_OP(pre_increment, ++);
|
||||
CUTE_LEFT_UNARY_OP(pre_decrement, --);
|
||||
|
||||
CUTE_RIGHT_UNARY_OP(post_increment, ++);
|
||||
CUTE_RIGHT_UNARY_OP(post_decrement, --);
|
||||
|
||||
CUTE_NAMED_UNARY_OP(abs_fn, abs);
|
||||
CUTE_NAMED_UNARY_OP(conjugate, cute::conj);
|
||||
|
||||
#undef CUTE_LEFT_UNARY_OP
|
||||
#undef CUTE_RIGHT_UNARY_OP
|
||||
#undef CUTE_NAMED_UNARY_OP
|
||||
|
||||
/************/
|
||||
/** Binary **/
|
||||
/************/
|
||||
|
||||
#define CUTE_BINARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T, class U> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
|
||||
return std::forward<T>(lhs) OP std::forward<U>(rhs); \
|
||||
} \
|
||||
}
|
||||
#define CUTE_NAMED_BINARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T, class U> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
|
||||
return OP (std::forward<T>(lhs), std::forward<U>(rhs)); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
CUTE_BINARY_OP(plus, +);
|
||||
CUTE_BINARY_OP(minus, -);
|
||||
CUTE_BINARY_OP(multiplies, *);
|
||||
CUTE_BINARY_OP(divides, /);
|
||||
CUTE_BINARY_OP(modulus, %);
|
||||
|
||||
CUTE_BINARY_OP(plus_assign, +=);
|
||||
CUTE_BINARY_OP(minus_assign, -=);
|
||||
CUTE_BINARY_OP(multiplies_assign, *=);
|
||||
CUTE_BINARY_OP(divides_assign, /=);
|
||||
CUTE_BINARY_OP(modulus_assign, %=);
|
||||
|
||||
CUTE_BINARY_OP(bit_and, &);
|
||||
CUTE_BINARY_OP(bit_or, |);
|
||||
CUTE_BINARY_OP(bit_xor, ^);
|
||||
CUTE_BINARY_OP(left_shift, <<);
|
||||
CUTE_BINARY_OP(right_shift, >>);
|
||||
|
||||
CUTE_BINARY_OP(bit_and_assign, &=);
|
||||
CUTE_BINARY_OP(bit_or_assign, |=);
|
||||
CUTE_BINARY_OP(bit_xor_assign, ^=);
|
||||
CUTE_BINARY_OP(left_shift_assign, <<=);
|
||||
CUTE_BINARY_OP(right_shift_assign, >>=);
|
||||
|
||||
CUTE_BINARY_OP(logical_and, &&);
|
||||
CUTE_BINARY_OP(logical_or, ||);
|
||||
|
||||
CUTE_BINARY_OP(equal_to, ==);
|
||||
CUTE_BINARY_OP(not_equal_to, !=);
|
||||
CUTE_BINARY_OP(greater, >);
|
||||
CUTE_BINARY_OP(less, <);
|
||||
CUTE_BINARY_OP(greater_equal, >=);
|
||||
CUTE_BINARY_OP(less_equal, <=);
|
||||
|
||||
CUTE_NAMED_BINARY_OP(max_fn, cute::max);
|
||||
CUTE_NAMED_BINARY_OP(min_fn, cute::min);
|
||||
|
||||
#undef CUTE_BINARY_OP
|
||||
#undef CUTE_NAMED_BINARY_OP
|
||||
|
||||
/**********/
|
||||
/** Meta **/
|
||||
/**********/
|
||||
|
||||
template <class Fn, class Arg>
|
||||
struct bound_fn {
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(T&& arg) {
|
||||
return fn_(arg_, std::forward<T>(arg));
|
||||
}
|
||||
|
||||
Fn fn_;
|
||||
Arg arg_;
|
||||
};
|
||||
|
||||
template <class Fn, class Arg>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
bind(Fn const& fn, Arg const& arg) {
|
||||
return bound_fn<Fn,Arg>{fn, arg};
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
718
include/cute/algorithm/gemm.hpp
Normal file
718
include/cute/algorithm/gemm.hpp
Normal file
@ -0,0 +1,718 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
/** The gemm algorithm takes four (or three) tensors and computes
|
||||
* D += A * B + C
|
||||
* It dispatches based on the number of modes each tensor has:
|
||||
*
|
||||
* 1. `(V) x (V) => (V)`.
|
||||
* The element-wise product of vectors. Dispatches to FMA or MMA.
|
||||
* 2. `(M) x (N) => (M,N)`.
|
||||
* The outer product of vectors. Dispatches to [3] with new mode K=(1).
|
||||
* 3. `(M,K) x (N,K) => (M,N)`.
|
||||
* The product of matrices. Dispatches to [5] with MMA vector-mode V.
|
||||
* 4. `(V,M) x (V,N) => (V,M,N)`.
|
||||
* The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n).
|
||||
* 5. `(V,M,K) x (V,N,K) => (V,M,N)`.
|
||||
* The batched product of matrices. Dispatches to [4] for each (k).
|
||||
*/
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Three arguments to four
|
||||
//
|
||||
|
||||
template <class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> & C)
|
||||
{
|
||||
return gemm(C, A, B, C);
|
||||
}
|
||||
|
||||
template <class MMA,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> & C)
|
||||
{
|
||||
return gemm(mma, C, A, B, C);
|
||||
}
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
|
||||
template <class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> && C)
|
||||
{
|
||||
return gemm(C, A, B, C);
|
||||
}
|
||||
|
||||
template <class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TD, DLayout> && D,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> const& C)
|
||||
{
|
||||
return gemm(D, A, B, C);
|
||||
}
|
||||
|
||||
template <class MMA,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> && C)
|
||||
{
|
||||
return gemm(mma, C, A, B, C);
|
||||
}
|
||||
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> && D,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> const& C)
|
||||
{
|
||||
return gemm(mma, D, A, B, C);
|
||||
}
|
||||
|
||||
//
|
||||
// Default MMA is UniversalFMA
|
||||
//
|
||||
|
||||
template <class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TD, DLayout> & D,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> const& C)
|
||||
{
|
||||
using MMA = MMA_Atom<UniversalFMA<typename Tensor<TD,DLayout>::value_type,
|
||||
typename Tensor<TA,ALayout>::value_type,
|
||||
typename Tensor<TB,BLayout>::value_type,
|
||||
typename Tensor<TC,CLayout>::value_type>>;
|
||||
|
||||
return gemm(MMA{}, D, A, B, C);
|
||||
}
|
||||
|
||||
//
|
||||
// Thread-Local Register-Memory GEMMs
|
||||
//
|
||||
|
||||
// Dispatch [1]: (V) x (V) => (V)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 1 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 1 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 1 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 1 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (V) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (V) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (V) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (V) Logical data
|
||||
{
|
||||
// No static assertions on (V), MMA checks compatibility
|
||||
mma.call(D, A, B, C);
|
||||
}
|
||||
|
||||
// Dispatch [2]: (M) x (N) => (M,N)
// Rank-1 A and B form an outer product into rank-2 C/D. Appends a unit K-mode
// to A and B and forwards to Dispatch [3].
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
                          ALayout::rank == 1 && is_rmem<TA>::value &&
                          BLayout::rank == 1 && is_rmem<TB>::value &&
                          CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
     Tensor<TD, DLayout>      & D,  // (M,N) Logical data
     Tensor<TA, ALayout> const& A,  // (M) Logical data
     Tensor<TB, BLayout> const& B,  // (N) Logical data
     Tensor<TC, CLayout> const& C)  // (M,N) Logical data
{
  CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));

  gemm(mma,
       D,                                            // (M,N)
       make_tensor(A.data(), append<2>(A.layout())), // (M,1)
       make_tensor(B.data(), append<2>(B.layout())), // (N,1)
       C);                                           // (M,N)
}
|
||||
|
||||
// Dispatch [3]: (M,K) x (N,K) => (M,N)
// Rank-2 operands with no explicit value mode. Only valid for a 1-value MMA;
// prepends a unit V-mode to all four tensors and forwards to Dispatch [5].
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
                          ALayout::rank == 2 && is_rmem<TA>::value &&
                          BLayout::rank == 2 && is_rmem<TB>::value &&
                          CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
     Tensor<TD, DLayout>      & D,  // (M,N) Logical data
     Tensor<TA, ALayout> const& A,  // (M,K) Logical data
     Tensor<TB, BLayout> const& B,  // (N,K) Logical data
     Tensor<TC, CLayout> const& C)  // (M,N) Logical data
{
  CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B));  // AK == BK
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));

  // Assert this is a 1-value MMA
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});

  gemm(mma,
       make_tensor(D.data(), prepend<3>(D.layout())),  // (1,M,N)
       make_tensor(A.data(), prepend<3>(A.layout())),  // (1,M,K)
       make_tensor(B.data(), prepend<3>(B.layout())),  // (1,N,K)
       make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N)
}
|
||||
|
||||
// Dispatch [4]: (V,M) x (V,N) => (V,M,N)
// Rank-2 A/B (value mode + one free mode) into rank-3 C/D. Iterates the (m,n)
// tile space, calling Dispatch [1] per tile. The traversal order is chosen by
// the byte width of one A/B value column to encourage register reuse
// (serpentine / kinked-serpentine paths).
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
                          ALayout::rank == 2 && is_rmem<TA>::value &&
                          BLayout::rank == 2 && is_rmem<TB>::value &&
                          CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
     Tensor<TD, DLayout>      & D,  // (V,M,N) Logical data
     Tensor<TA, ALayout> const& A,  // (V,M) Logical data
     Tensor<TB, BLayout> const& B,  // (V,N) Logical data
     Tensor<TC, CLayout> const& C)  // (V,M,N) Logical data
{
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));

  // REGISTER .reuse OPTIMIZATIONS

  auto M = size<1>(A);
  auto N = size<1>(B);

  // 64-bit traversal specialization -- serpentine path
  if (size<0>(A) * sizeof(typename Tensor<TA,ALayout>::value_type) == 8 &&
      size<0>(B) * sizeof(typename Tensor<TB,BLayout>::value_type) == 8)
  {
#if 1 // NOTE: Must depend on the C-matrix order... (which we can test)
    // Row-major iteration: n reverses direction on every other m-row so the
    // B(_,ns) operand carries over between consecutive iterations.
    CUTE_UNROLL
    for (int m = 0; m < M; ++m) {
      CUTE_UNROLL
      for (int n = 0; n < N; ++n) {
        int ns = (m & 1) ? N-1-n : n;  // Serpentine coordinate
        gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns));
      }
    }
#else
    // Col-major iteration
    CUTE_UNROLL
    for (int n = 0; n < N; ++n) {
      CUTE_UNROLL
      for (int m = 0; m < M; ++m) {
        int ms = (n & 1) ? M-1-m : m;  // Serpentine coordinate
        gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
      }
    }
#endif
  } else

  // 32-bit traversal specialization -- kinked serpentine path
  if (size<0>(A) * sizeof(typename Tensor<TA,ALayout>::value_type) == 4 &&
      size<0>(B) * sizeof(typename Tensor<TB,BLayout>::value_type) == 4)
  {
#if 1 // NOTE: Must depend on the C-matrix order... (which we can test)
    // Row-major iteration: rows are processed in pairs (m, m+1) and the
    // serpentine direction flips every two rows ((m & 2) instead of (m & 1)).
    CUTE_UNROLL
    for (int m = 0; m < M; m += 2) {
      CUTE_UNROLL
      for (int n = 0; n < N; ++n) {
        int ns = (m & 2) ? N-1-n : n;
        gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns));

        if (m+1 < M) {  // Guard the second row of the pair when M is odd
          gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns));
        }
      }
    }
#else
    // Col-major iteration
    CUTE_UNROLL
    for (int n = 0; n < N; n += 2) {
      CUTE_UNROLL
      for (int m = 0; m < M; ++m) {
        // Kinked serpentine traversal for maximum register reuse
        int ms = (n & 2) ? M-1-m : m;
        gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0));

        if (n+1 < N) {  // Guard the second column of the pair when N is odd
          gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1));
        }
      }
    }
#endif
  } else {
    // Fallback to serpentine loop
    // Col-major iteration
    CUTE_UNROLL
    for (int n = 0; n < N; ++n) {
      CUTE_UNROLL
      for (int m = 0; m < M; ++m) {
        int ms = (n & 1) ? M-1-m : m;  // Serpentine coordinate
        gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
      }
    }
  }
}
|
||||
|
||||
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 3 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 3 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 3 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (V,M,K) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (V,N,K) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
|
||||
|
||||
auto K = size<2>(A);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k = 0; k < K; ++k) {
|
||||
gemm(mma, D, A(_,_,k), B(_,_,k), C);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Thread-Local Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
// Dispatch [1]: (V) x (V) => (V)
|
||||
// Dispatch [2]: (M) x (N) => (M,N)
|
||||
// Dispatch [3]: (M,K) x (N,K) => (M,N)
|
||||
// Dispatch [4]: (V,M) x (V,N) => (V,M,N)
|
||||
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
|
||||
// Dispatch [3]: (M,K) x (N,K) => (M,N)
// Shared-memory A/B variant of the rank-2 dispatch. Only valid for a 1-value
// MMA; prepends a unit V-mode and forwards to the shared-memory Dispatch [5].
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
                          ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
     Tensor<TD, DLayout>      & D,  // (M,N) Logical data
     Tensor<TA, ALayout> const& A,  // (M,K) Logical data
     Tensor<TB, BLayout> const& B,  // (N,K) Logical data
     Tensor<TC, CLayout> const& C)  // (M,N) Logical data
{
  CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B));  // AK == BK
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));

  // Assert this is a 1-value MMA
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});

  gemm(mma,
       make_tensor(D.data(), prepend<3>(D.layout())),  // (1,M,N)
       make_tensor(A.data(), prepend<3>(A.layout())),  // (1,M,K)
       make_tensor(B.data(), prepend<3>(B.layout())),  // (1,N,K)
       make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N)
}
|
||||
|
||||
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
// Shared-memory A/B variant: stages each K-slice of A and B through register
// fragments before invoking the register-memory gemm on that slice.
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
                          ALayout::rank == 3 && is_smem<TA>::value &&
                          BLayout::rank == 3 && is_smem<TB>::value &&
                          CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA> const& mma,
     Tensor<TD, DLayout>      & D,  // (V,M,N) Logical data
     Tensor<TA, ALayout> const& A,  // (V,M,K) Logical data
     Tensor<TB, BLayout> const& B,  // (V,N,K) Logical data
     Tensor<TC, CLayout> const& C)  // (V,M,N) Logical data
{
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B));  // AK == BK
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));

  // Register fragments shaped to match A and B
  auto rA = MMA_Atom<MMA>::make_fragment_A(A);
  auto rB = MMA_Atom<MMA>::make_fragment_B(B);

  auto K = size<2>(A);

  CUTE_UNROLL
  for (int k = 0; k < K; ++k)
  {
    // Stage this K-slice from shared memory into registers
    copy(A(_,_,k), rA(_,_,k));
    copy(B(_,_,k), rB(_,_,k));
    // Thread-level register gemm for k
    gemm(mma, D, rA(_,_,k), rB(_,_,k), C);
  }
}
|
||||
|
||||
//
|
||||
// Collective Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
// Collective shared-memory GEMM with epilogue:
//   sC = alpha * (sA_load_op(sA) * sB_load_op(sB)) + beta * sC
// All three tensors live in shared memory; the ThrMMA partitions them across
// the participating threads. Handles ragged M/N/K edges via residue
// predication. sA/sB are taken by value because their data pointers are
// shifted below.
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta, class TC, class CLayout,
          class ALoadTransformOp, class BLoadTransformOp,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
     Alpha const& alpha,
     Tensor<TA, ALayout> sA,
     Tensor<TB, BLayout> sB,
     Beta const& beta,
     Tensor<TC, CLayout> sC,
     ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
     BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
  CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));  // AK == BK

  using TypeA = typename TA::value_type;
  using TypeB = typename TB::value_type;
  using TypeC = typename TC::value_type;

  // The load transforms must be type-preserving
  static_assert(std::is_same_v<std::decay_t<std::invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
                "ALoadTransformOp functor must accept and return value of type TA::value_type");
  static_assert(std::is_same_v<std::decay_t<std::invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
                "BLoadTransformOp functor must accept and return value of type TB::value_type");

  // Original, static size of the problem
  auto M = size<0>(sC);
  auto N = size<1>(sC);
  auto K = size<1>(sA);

  // Block size of the compute tile
  auto BLK_M = tile_size<0>(thr_mma);
  auto BLK_N = tile_size<1>(thr_mma);
  auto BLK_K = tile_size<2>(thr_mma);

  // Compute the "residues" -- the valid extents of the final (partial) tiles
  auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{});   // (0,BLK_M]
  auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{});   // (0,BLK_N]
  auto k_residue = K - BLK_K * (ceil_div(K, BLK_K)             ); // (-BLK_K,0]

  // Shift the origin so k_residue is zeroth tile
  // (the first K-tile is then the only one needing k-predication)
  sA.data() = &sA(0,k_residue);
  sB.data() = &sB(0,k_residue);

#if 0
  if (thread0()) {
    printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
    printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
    printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
  }
#endif

  //
  // MMA Partitioning
  //

  // Round the layout extents up to BLK_X
  Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
  Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
  Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));

#if 0
  if (thread0()) {
    print(rounded_sA.layout()); print("\n");
    print(rounded_sB.layout()); print("\n");
    print(rounded_sC.layout()); print("\n");
  }
#endif

  // Partition the sA and sB tiles across the threads for the MMA
  Tensor tCsA = thr_mma.partition_A(rounded_sA);                    // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(rounded_sB);                    // (MMA,MMA_N,MMA_K)
  Tensor tCsC = thr_mma.partition_C(rounded_sC);                    // (MMA,MMA_M,MMA_N)
  // Create register tensors for the MMA to operate on
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);                      // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);                      // (MMA,MMA_N,MMA_K)
  Tensor tCrC = thr_mma.make_fragment_C(tCsC);                      // (MMA,MMA_M,MMA_N)

#if 0
  if (thread0()) {
    print(tCsA.layout()); print("\n");
    print(tCsB.layout()); print("\n");
    print(tCsC.layout()); print("\n");
    print(tCrA.layout()); print("\n");
    print(tCrB.layout()); print("\n");
    print(tCrC.layout()); print("\n");
  }
#endif

  //
  // PREDICATION
  //

  // Allocate the preds for only the MMA-mode of tCsA and tCsB
  Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
  Tensor tCpB = make_tensor<bool>(size<0>(tCsB));

  // Create coordinate tensors on a single compute block for predication
  Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
  Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K));  // (BLK_N,BLK_K) -> (blk_n,blk_k)

  // Repeat partitioning with thr_mma
  Tensor tCcA = thr_mma.partition_A(cA);                       // (MMA,1,1) -> (blk_m,blk_k)
  Tensor tCcB = thr_mma.partition_B(cB);                       // (MMA,1,1) -> (blk_n,blk_k)

  // Populate the m and n predicates
  CUTE_UNROLL
  for (int i = 0; i < size(tCpA); ++i) {
    tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
  }
  CUTE_UNROLL
  for (int i = 0; i < size(tCpB); ++i) {
    tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
  }

#if 0
  printf("Thr %d: A(%d,%d):%d  B(%d,%d):%d\n",
         threadIdx.x,
         int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
         int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
#endif

  //
  // PREFETCH k_block = 0 (with k-predication)
  //

  CUTE_UNROLL
  for (int i = 0; i < size<0>(tCsA); ++i) {                // Copy MMA_I
    if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
      CUTE_UNROLL
      for (int m = 0; m < size<1>(tCsA); ++m) {            // Copy MMA_M, predicated on m
        tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
      }
    }
  }

  CUTE_UNROLL
  for (int i = 0; i < size<0>(tCsB); ++i) {                // Copy MMA_I
    if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
      CUTE_UNROLL
      for (int n = 0; n < size<1>(tCsB); ++n) {            // Copy MMA_N, predicated on n
        tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
      }
    }
  }
  //
  // MAINLOOP
  //

  // Clear accumulators
  clear(tCrC);

  constexpr int K_BLOCK_MAX = size<2>(tCrA);

  CUTE_UNROLL
  for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
  {
    // static-if load the next k_block. No k-predication required on these loads.
    if (k_block < K_BLOCK_MAX-1)
    {
      // Load the next k_block
      int k_next = k_block + 1;

      CUTE_UNROLL
      for (int m = 0; m < size<1>(tCsA); ++m) {        // Copy MMA_M
        CUTE_UNROLL
        for (int i = 0; i < size<0>(tCsA); ++i) {      // Copy_if MMA_I predicated on m
          tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
        }
      }

      CUTE_UNROLL
      for (int n = 0; n < size<1>(tCsB); ++n) {        // Copy MMA_N
        CUTE_UNROLL
        for (int i = 0; i < size<0>(tCsB); ++i) {      // Copy MMA_I predicated on n
          tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
        }
      }
    }

    // GEMM on k_block in registers
    gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
  }

  //
  // Epilogue
  //

  Tensor cC   = make_identity_tensor(make_shape(BLK_M, BLK_N));    // (BLK_M,BLK_N) -> (blk_m,blk_n)
  Tensor tCcC = thr_mma.partition_C(cC);                           // (MMA, 1, 1) -> (blk_m,blk_n)

  // Skip the beta * sC read entirely when beta is (value-)zero
  const bool isBetaZero = (beta == Beta{});

  // Custom axpby_if for now: sC = alpha*acc [+ beta*sC], predicated on the
  // m/n residues of the last tile.
  CUTE_UNROLL
  for (int m = 0; m < size<1>(tCsC); ++m)
  {
    CUTE_UNROLL
    for (int n = 0; n < size<2>(tCsC); ++n)
    {
      CUTE_UNROLL
      for (int i = 0; i < size<0>(tCsC); ++i)
      {
        if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
            (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
        {
          tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n);
        }
      }
    }
  }
}
|
||||
|
||||
// Convenience overload without load-transform functors:
// forwards with identity transforms for both A and B.
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta, class TC, class CLayout,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
     Alpha const& alpha,
     Tensor<TA, ALayout> sA,
     Tensor<TB, BLayout> sB,
     Beta const& beta,
     Tensor<TC, CLayout> sC)
{
  gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
}
|
||||
|
||||
} // end namespace cute
|
||||
46
include/cute/algorithm/prefer.hpp
Normal file
46
include/cute/algorithm/prefer.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// Infinite family of tag types where prefer<N> inherits from prefer<N-1>.
// Passing prefer<Max>{} to an overload set makes overloads taking prefer<N>
// with higher N better matches (shorter derived-to-base conversion), so
// higher N means higher overload priority.
// NOTE(review): uses std::size_t with no visible #include <cstddef> in this
// header -- presumably provided transitively; confirm.
template <std::size_t N>
struct prefer : prefer<N-1> {};

// Base case terminates the inheritance chain.
template <>
struct prefer<0> {};

// Can be used to preferencially overload implementations
// Higher N in prefer<N> have higher priority.
|
||||
|
||||
} // end namespace cute
|
||||
102
include/cute/algorithm/tensor_algorithms.hpp
Normal file
102
include/cute/algorithm/tensor_algorithms.hpp
Normal file
@ -0,0 +1,102 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/** Common algorithms on (hierarchical) tensors */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// for_each
|
||||
//
|
||||
|
||||
// Apply op to every element of a const tensor (1-D linear traversal).
// op is invoked for its side effects; elements cannot be modified here.
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout> const& tensor, UnaryOp&& op)
{
  CUTE_UNROLL
  for (int i = 0; i < size(tensor); ++i) {
    // static_cast<UnaryOp&&> forwards op's value category on each call
    static_cast<UnaryOp&&>(op)(tensor(i));
  }
}
|
||||
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries (e.g. views returned by slicing expressions);
// delegates to the lvalue overload.
template <class Engine, class Layout, class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
  return for_each(tensor, static_cast<UnaryOp&&>(op));
}
|
||||
|
||||
//
|
||||
// transform
|
||||
//
|
||||
|
||||
// Similar to std::transform but does not return number of elements affected
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
tensor(i) = static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
|
||||
{
|
||||
return transform(tensor, std::forward<UnaryOp>(op));
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
846
include/cute/algorithm/tuple_algorithms.hpp
Normal file
846
include/cute/algorithm/tuple_algorithms.hpp
Normal file
@ -0,0 +1,846 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
/** Common algorithms on (hierarchical) tuples */
|
||||
/** Style choice:
|
||||
* Forward params [using static_cast<T&&>(.)] for const/non-const/ref/non-ref args
|
||||
* but don't bother forwarding functions as ref-qualified member fns are extremely rare
|
||||
*/
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Apply (Unpack)
|
||||
// (t, f) => f(t_0,t_1,...,t_n)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Unpack the elements of t selected by the index sequence and invoke
// f(get<I0>(t), get<I1>(t), ...). The seq<I...> parameter supplies the pack.
template <class T, class F, int... I>
CUTE_HOST_DEVICE constexpr
auto
apply(T&& t, F&& f, seq<I...>)
{
  // t is forwarded per-element; f is called as an lvalue (see file style note)
  return f(get<I>(static_cast<T&&>(t))...);
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Apply (Unpack): (t, f) => f(t_0,t_1,...,t_n)
// Public entry point; generates the index sequence from t's tuple size.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
apply(T&& t, F&& f)
{
  return detail::apply(static_cast<T&&>(t), f, tuple_seq<T>{});
}
|
||||
|
||||
//
|
||||
// Transform Apply
|
||||
// (t, f, g) => g(f(t_0),f(t_1),...)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// tapply: transform-then-apply over index sequence I...
// Unary:   g(f(t_I)...)
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr
auto
tapply(T&& t, F&& f, G&& g, seq<I...>)
{
  return g(f(get<I>(static_cast<T&&>(t)))...);
}

// Binary:  g(f(t0_I, t1_I)...) -- elementwise zip of two tuples through f
template <class T0, class T1, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr
auto
tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq<I...>)
{
  return g(f(get<I>(static_cast<T0&&>(t0)),
             get<I>(static_cast<T1&&>(t1)))...);
}

// Ternary: g(f(t0_I, t1_I, t2_I)...) -- elementwise zip of three tuples
template <class T0, class T1, class T2, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr
auto
tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq<I...>)
{
  return g(f(get<I>(static_cast<T0&&>(t0)),
             get<I>(static_cast<T1&&>(t1)),
             get<I>(static_cast<T2&&>(t2)))...);
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// (t, f, g) => g(f(t_0), f(t_1), ..., f(t_n))
template <class T, class F, class G>
CUTE_HOST_DEVICE constexpr
auto
transform_apply(T&& t, F&& f, G&& g)
{
  return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
}

// Binary variant: g(f(t0_0, t1_0), ..., f(t0_n, t1_n)).
// Iteration count is taken from T0's tuple_seq.
template <class T0, class T1, class F, class G>
CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
{
  return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
}

// Ternary variant: g(f(t0_0, t1_0, t2_0), ..., f(t0_n, t1_n, t2_n)).
template <class T0, class T1, class T2, class F, class G>
CUTE_HOST_DEVICE constexpr
auto
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
{
  return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
}
|
||||
|
||||
//
|
||||
// For Each
|
||||
// (t, f) => f(t_0),f(t_1),...,f(t_n)
|
||||
//
|
||||
|
||||
// Invoke f on each element of tuple t in index order; results are discarded.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_each(T&& t, F&& f)
{
  // Comma fold-expression guarantees left-to-right evaluation over elements.
  detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
for_each_leaf(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
return detail::apply(static_cast<T&&>(t), [&](auto&&... a){ return (for_each_leaf(static_cast<decltype(a)&&>(a), f), ...); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<T&&>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Transform
|
||||
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
|
||||
//
|
||||
|
||||
// Elementwise transform: returns the tuple (f(t_0), f(t_1), ..., f(t_n)).
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
transform(T const& t, F&& f)
{
  return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
}

// Binary elementwise transform: (f(t0_0,t1_0), ..., f(t0_n,t1_n)).
// Both tuples must have the same rank.
template <class T0, class T1, class F>
CUTE_HOST_DEVICE constexpr
auto
transform(T0 const& t0, T1 const& t1, F&& f)
{
  static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
  return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
}

// Ternary elementwise transform. All three tuples must have the same rank.
template <class T0, class T1, class T2, class F>
CUTE_HOST_DEVICE constexpr
auto
transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
{
  static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
  static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
  return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
}
|
||||
|
||||
// Apply f to every leaf (non-tuple element) of t, preserving the
// hierarchical structure of the result.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
transform_leaf(T const& t, F&& f)
{
  if constexpr (is_tuple<T>::value) {
    // Recurse into each element, rebuilding the same tuple shape.
    return transform(t, [&](auto const& a) { return transform_leaf(a, f); });
  } else {
    return f(t);
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
//
|
||||
// find and find_if
|
||||
//
|
||||
|
||||
namespace detail {

// Base case: no indices left to test. "Not found" is encoded as the
// tuple size, i.e. one past the last valid index.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<>)
{
  return cute::integral_constant<int, tuple_size<T>::value>{};
}

// Recursive case: test element I first. f must return a type whose
// static ::value is usable in if constexpr (a static true/false).
template <class T, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f, seq<I,Is...>)
{
  if constexpr (decltype(f(get<I>(t)))::value) {
    return cute::integral_constant<int, I>{};
  } else {
    return find_if(t, f, seq<Is...>{});
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
|
||||
|
||||
// Static search: returns cute::integral_constant<int, i> where i is the
// index of the first element for which f(element) is statically true, or
// tuple_size<T> if no element matches. A non-tuple t is treated as a
// rank-1 entity: the result is 0 on match and 1 ("not found") otherwise.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
find_if(T const& t, F&& f)
{
  if constexpr (is_tuple<T>::value) {
    return detail::find_if(t, f, tuple_seq<T>{});
  } else {
    return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Static index of the first element of t equal to x (tuple_size<T> if none).
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
find(T const& t, X const& x)
{
  return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false
}
|
||||
|
||||
template <class T, class F>
|
||||
auto
|
||||
none_of(T const& t, F&& f)
|
||||
{
|
||||
return cute::integral_constant<bool, decltype(find_if(t, f))::value == std::tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
auto not_f = [&](auto const& a) { return !f(a); };
|
||||
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == std::tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
auto
|
||||
any_of(T const& t, F&& f)
|
||||
{
|
||||
return cute::integral_constant<bool, !decltype(none_of(t, f))::value>{};
|
||||
}
|
||||
|
||||
//
|
||||
// Filter
|
||||
// (t, f) => <f(t_0),f(t_1),...,f(t_n)>
|
||||
//
|
||||
|
||||
// Apply f to each element (f returns a tuple, possibly empty) and
// concatenate all results into a single tuple. Elements are "filtered"
// by having f return an empty tuple for them.
template <class T, class F>
CUTE_HOST_DEVICE constexpr
auto
filter_tuple(T const& t, F&& f)
{
  return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}

// Binary variant: f receives corresponding element pairs from t0 and t1.
template <class T0, class T1, class F>
CUTE_HOST_DEVICE constexpr
auto
filter_tuple(T0 const& t0, T1 const& t1, F&& f)
{
  return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
}
|
||||
|
||||
//
|
||||
// Fold (Reduce, Accumulate)
|
||||
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
|
||||
//
|
||||
|
||||
namespace detail {

// This impl compiles much faster than cute::apply and variadic args

// Base case: empty index sequence -- the accumulator is the result.
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold(T&& t, V&& v, F&& f, seq<>)
{
  return static_cast<V&&>(v);
}

// Recursive case: fold element I into the accumulator, then recurse on
// the remaining indices. The sizeof...(Is) == 0 special case returns the
// last f(...) result directly, preserving its value category
// (decltype(auto) return).
template <class T, class V, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
decltype(auto)
fold(T&& t, V&& v, F&& f, seq<I,Is...>)
{
  if constexpr (sizeof...(Is) == 0) {
    return f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t)));
  } else {
    return fold(static_cast<T&&>(t),
                f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t))),
                f,
                seq<Is...>{});
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
|
||||
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
static_cast<V&&>(v),
|
||||
f,
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<V&&>(v), static_cast<T&&>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold_first(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
get<0>(static_cast<T&&>(t)),
|
||||
f,
|
||||
make_range<1,std::tuple_size<std::remove_reference_t<T>>::value>{});
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// front, back, take, unwrap
|
||||
//
|
||||
|
||||
// Get the first non-tuple element in a hierarchical tuple,
// i.e. recurse through get<0> until a non-tuple leaf is reached.
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
front(T&& t)
{
  if constexpr (is_tuple<remove_cvref_t<T>>::value) {
    return front(get<0>(static_cast<T&&>(t)));
  } else {
    // decltype(auto) preserves the reference/value category of the leaf.
    return static_cast<T&&>(t);
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Get the last non-tuple element in a hierarchical tuple,
// i.e. recurse through get<N-1> until a non-tuple leaf is reached.
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
back(T&& t)
{
  if constexpr (is_tuple<remove_cvref_t<T>>::value) {
    constexpr int N = tuple_size<remove_cvref_t<T>>::value;
    return back(get<N-1>(static_cast<T&&>(t)));
  } else {
    // decltype(auto) preserves the reference/value category of the leaf.
    return static_cast<T&&>(t);
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Takes the elements in the range [B,E) and returns them as a new tuple.
template <int B, int E, class T>
CUTE_HOST_DEVICE constexpr
auto
take(T const& t)
{
  return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
}
|
||||
|
||||
// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple
// e.g. unwrap(((a,b))) => (a,b), unwrap(((a))) => a
template <class T>
CUTE_HOST_DEVICE constexpr
auto
unwrap(T const& t)
{
  if constexpr (is_tuple<T>::value) {
    if constexpr (tuple_size<T>::value == 1) {
      // Rank-1: strip one level and keep unwrapping.
      return unwrap(get<0>(t));
    } else {
      return t;
    }
  } else {
    return t;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
//
|
||||
// Flatten a hierarchical tuple to a tuple of depth one.
|
||||
//
|
||||
|
||||
// Flatten the hierarchical tuple t to a depth-one tuple of its leaves.
// Always returns a tuple: a non-tuple t becomes the rank-1 tuple (t).
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten_to_tuple(T const& t)
{
  if constexpr (is_tuple<T>::value) {
    // Each element flattens to a tuple; filter_tuple concatenates them.
    return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
  } else {
    return cute::make_tuple(t);
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Flatten the hierarchical tuple t to a depth-one tuple of its leaves.
// Unlike flatten_to_tuple, a non-tuple t is returned as-is (not wrapped).
template <class T>
CUTE_HOST_DEVICE constexpr
auto
flatten(T const& t)
{
  if constexpr (is_tuple<T>::value) {
    return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
  } else {
    return t;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
//
|
||||
// insert and remove and replace
|
||||
//
|
||||
|
||||
namespace detail {

// Shortcut around tuple_cat for common insert/remove/repeat cases.
// Builds make_tuple(t_I..., x, x, ..., t_K...): elements of t selected by
// I..., then one copy of x per index in J..., then elements selected by K...
template <class T, class X, int... I, int... J, int... K>
CUTE_HOST_DEVICE constexpr
auto
construct(T const& t, X const& x, seq<I...>, seq<J...>, seq<K...>)
{
  // (void(J), x) evaluates to x; the J values only drive the pack expansion.
  return cute::make_tuple(get<I>(t)..., (void(J),x)..., get<K>(t)...);
}

} // end namespace detail
|
||||
|
||||
// Insert x into the Nth position of the tuple: elements [0,N), then x,
// then elements [N,size).
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
insert(T const& t, X const& x)
{
  return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N,tuple_size<T>::value>{});
}
|
||||
|
||||
// Remove the Nth element of the tuple: elements [0,N) then [N+1,size).
// The 0 and empty middle seq mean no copies of x are inserted.
template <int N, class T>
CUTE_HOST_DEVICE constexpr
auto
remove(T const& t)
{
  return detail::construct(t, 0, make_seq<N>{}, seq<>{}, make_range<N+1,tuple_size<T>::value>{});
}
|
||||
|
||||
// Replace the Nth element of the tuple with x: elements [0,N), then x,
// then elements [N+1,size).
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
replace(T const& t, X const& x)
{
  return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
}
|
||||
|
||||
// Replace the first element of the tuple with x.
// For a non-tuple t, simply returns x.
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
replace_front(T const& t, X const& x)
{
  if constexpr (is_tuple<T>::value) {
    return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size<T>::value>{});
  } else {
    return x;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Replace the last element of the tuple with x.
// For a non-tuple t, simply returns x.
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
replace_back(T const& t, X const& x)
{
  if constexpr (is_tuple<T>::value) {
    return detail::construct(t, x, make_seq<tuple_size<T>::value-1>{}, seq<0>{}, seq<>{});
  } else {
    return x;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
//
|
||||
// Make a tuple of Xs of tuple_size N
|
||||
//
|
||||
|
||||
// Make a tuple of N copies of x. The first argument to construct (0) is a
// dummy tuple; only the middle make_seq<N> pack replicates x.
template <int N, class X>
CUTE_HOST_DEVICE constexpr
auto
repeat(X const& x)
{
  return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
}
|
||||
|
||||
//
|
||||
// Make a tuple of Xs the same profile as tuple
|
||||
//
|
||||
|
||||
// Make a hierarchical tuple with the same profile (nesting structure) as t,
// where every leaf is replaced by a copy of x.
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
repeat_like(T const& t, X const& x)
{
  if constexpr (is_tuple<T>::value) {
    return transform(t, [&](auto const& a) { return repeat_like(a,x); });
  } else {
    return x;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Group the elements [B,E) of a T into a single element
// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{})
//   => T<_1,_2,T<_3,_4>,_5,_6>{}
// Implemented as: elements [0,B), then the sub-tuple take<B,E>(t), then [E,size).
template <int B, int E, class T>
CUTE_HOST_DEVICE constexpr
auto
group(T const& t)
{
  return detail::construct(t, take<B,E>(t), make_seq<B>{}, seq<0>{}, make_range<E,tuple_size<T>::value>{});
}
|
||||
|
||||
//
|
||||
// Extend a T to rank N by appending/prepending an element
|
||||
//
|
||||
|
||||
template <int N, class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
append(T const& a, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (N == tuple_size<T>::value) {
|
||||
return a;
|
||||
} else {
|
||||
static_assert(N > tuple_size<T>::value);
|
||||
return detail::construct(a, x, make_seq<tuple_size<T>::value>{}, make_seq<N-tuple_size<T>::value>{}, seq<>{});
|
||||
}
|
||||
} else {
|
||||
if constexpr (N == 1) {
|
||||
return a;
|
||||
} else {
|
||||
return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq<N-1>{}, seq<>{});
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
append(T const& a, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::construct(a, x, make_seq<tuple_size<T>::value>{}, seq<0>{}, seq<>{});
|
||||
} else {
|
||||
return cute::make_tuple(a, x);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Extend tuple a to tuple_size N by prepending copies of x.
// If a already has rank N, it is returned unchanged; otherwise N must be
// strictly greater than the current rank. A non-tuple a is treated as
// rank-1 and wrapped first.
template <int N, class T, class X>
CUTE_HOST_DEVICE constexpr
auto
prepend(T const& a, X const& x)
{
  if constexpr (is_tuple<T>::value) {
    if constexpr (N == tuple_size<T>::value) {
      return a;
    } else {
      static_assert(N > tuple_size<T>::value);
      return detail::construct(a, x, seq<>{}, make_seq<N-tuple_size<T>::value>{}, make_seq<tuple_size<T>::value>{});
    }
  } else {
    if constexpr (N == 1) {
      return a;
    } else {
      static_assert(N > 1);
      return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq<N-1>{}, seq<0>{});
    }
  }

  CUTE_GCC_UNREACHABLE;
}

// Prepend a single element x to a: (x,a_0,...,a_n), or (x,a) if a is not
// a tuple.
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
prepend(T const& a, X const& x)
{
  if constexpr (is_tuple<T>::value) {
    return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq<tuple_size<T>::value>{});
  } else {
    return cute::make_tuple(x, a);
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
//
|
||||
// Inclusive scan (prefix sum)
|
||||
//
|
||||
|
||||
namespace detail {

// Inclusive scan step: at index I the element is replaced by
// f(accumulator, t_I), and that value becomes the accumulator for the
// next index. Returns the fully scanned tuple.
template <class T, class V, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
iscan(T const& t, V const& v, F&& f, seq<I,Is...>)
{
  // Apply the function to v and the element at I
  auto v_next = f(v, get<I>(t));
  // Replace I with v_next
  auto t_next = replace<I>(t, v_next);

#if 0
  std::cout << "ISCAN i" << I << std::endl;
  std::cout << "  t      " << t << std::endl;
  std::cout << "  i      " << v << std::endl;
  std::cout << "  f(i,t) " << v_next << std::endl;
  std::cout << "  t_n    " << t_next << std::endl;
#endif

  if constexpr (sizeof...(Is) == 0) {
    return t_next;
  } else {
    return iscan(t_next, v_next, f, seq<Is...>{});
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
|
||||
|
||||
// Inclusive scan (prefix sum) of tuple t with initial value v and binary
// op f: result_i = f(...f(f(v, t_0), t_1)..., t_i).
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
auto
iscan(T const& t, V const& v, F&& f)
{
  return detail::iscan(t, v, f, tuple_seq<T>{});
}
|
||||
|
||||
//
|
||||
// Exclusive scan (prefix sum)
|
||||
//
|
||||
|
||||
namespace detail {

// Exclusive scan step: at index I the element is replaced by the
// accumulator BEFORE folding in t_I; the folded value feeds the next index.
template <class T, class V, class F, int I, int... Is>
CUTE_HOST_DEVICE constexpr
auto
escan(T const& t, V const& v, F&& f, seq<I,Is...>)
{
  if constexpr (sizeof...(Is) == 0) {
    // Replace I with v
    return replace<I>(t, v);
  } else {
    // Apply the function to v and the element at I
    auto v_next = f(v, get<I>(t));
    // Replace I with v
    auto t_next = replace<I>(t, v);

#if 0
    std::cout << "ESCAN i" << I << std::endl;
    std::cout << "  t      " << t << std::endl;
    std::cout << "  i      " << v << std::endl;
    std::cout << "  f(i,t) " << v_next << std::endl;
    std::cout << "  t_n    " << t_next << std::endl;
#endif

    // Recurse
    return escan(t_next, v_next, f, seq<Is...>{});
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
|
||||
|
||||
// Exclusive scan (prefix sum) of tuple t with initial value v and binary
// op f: result_0 = v, result_i = f(...f(v, t_0)..., t_{i-1}).
template <class T, class V, class F>
CUTE_HOST_DEVICE constexpr
auto
escan(T const& t, V const& v, F&& f)
{
  return detail::escan(t, v, f, tuple_seq<T>{});
}
|
||||
|
||||
//
|
||||
// Zip (Transpose)
|
||||
//
|
||||
|
||||
// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input
|
||||
// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output
|
||||
|
||||
namespace detail {

// Column extractor: collect element J from each of the sub-tuples of t
// selected by Is... into one tuple.
template <int J, class T, int... Is>
CUTE_HOST_DEVICE constexpr
auto
zip_(T const& t, seq<Is...>)
{
  return cute::make_tuple(get<J>(get<Is>(t))...);
}

// Transpose: Is... indexes the rows (outer tuple), Js... indexes the
// columns (inner tuples). Every row must have the same rank as row 0.
template <class T, int... Is, int... Js>
CUTE_HOST_DEVICE constexpr
auto
zip(T const& t, seq<Is...>, seq<Js...>)
{
  static_assert(conjunction<bool_constant<tuple_size<tuple_element_t<0,T>>::value == tuple_size<tuple_element_t<Is,T>>::value>...>::value, "Mismatched Ranks");
  return cute::make_tuple(detail::zip_<Js>(t, seq<Is...>{})...);
}

} // end namespace detail
|
||||
|
||||
// Transpose a tuple-of-tuples:
// ((a,b,c,...),(x,y,z,...),...) => ((a,x,...),(b,y,...),(c,z,...),...)
// A tuple of non-tuples is wrapped one level; a non-tuple is returned as-is.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
zip(T const& t)
{
  if constexpr (is_tuple<T>::value) {
    if constexpr (is_tuple<tuple_element_t<0,T>>::value) {
      // The column count is taken from the first row's rank.
      return detail::zip(t, tuple_seq<T>{}, tuple_seq<tuple_element_t<0,T>>{});
    } else {
      return cute::make_tuple(t);
    }
  } else {
    return t;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
// Convenience overload: zip separately-passed tuples by wrapping them in
// one outer tuple and transposing.
template <class T0, class T1, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
zip(T0 const& t0, T1 const& t1, Ts const&... ts)
{
  return zip(cute::make_tuple(t0, t1, ts...));
}
|
||||
|
||||
//
|
||||
// zip2_by -- A guided zip for rank-2 tuples
|
||||
// Take a tuple like ((A,a),((B,b),(C,c)),d)
|
||||
// and produce a tuple ((A,(B,C)),(a,(b,c),d))
|
||||
// where the rank-2 modes are selected by the terminals of the guide (X,(X,X))
|
||||
//
|
||||
|
||||
namespace detail {

// Guided rank-2 zip over the elements selected by Is...; Js... selects the
// trailing elements of t not covered by the guide, which are appended to
// the second mode of the result.
template <class T, class TG, int... Is, int... Js>
CUTE_HOST_DEVICE constexpr
auto
zip2_by(T const& t, TG const& guide, seq<Is...>, seq<Js...>)
{
  // zip2_by produces the modes like ((A,a),(B,b),...)
  auto split = cute::make_tuple(zip2_by(get<Is>(t), get<Is>(guide))...);

  // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y))
  return cute::make_tuple(cute::make_tuple(get<Is,0>(split)...),
                          cute::make_tuple(get<Is,1>(split)..., get<Js>(t)...));
}

} // end namespace detail
|
||||
|
||||
// zip2_by -- A guided zip for rank-2 tuples.
// Take a tuple like ((A,a),((B,b),(C,c)),d) and a guide (X,(X,X)) and
// produce ((A,(B,C)),(a,(b,c),d)): the rank-2 modes are selected by the
// terminals of the guide; elements of t beyond the guide's rank are
// appended to the second mode.
template <class T, class TG>
CUTE_HOST_DEVICE constexpr
auto
zip2_by(T const& t, TG const& guide)
{
  if constexpr (is_tuple<TG>::value) {
    constexpr int TR = tuple_size<T>::value;
    constexpr int GR = tuple_size<TG>::value;
    static_assert(TR >= GR, "Mismatched ranks");
    // [0,GR) is zipped per the guide; [GR,TR) are the leftover modes.
    return detail::zip2_by(t, guide,
                           make_range< 0, GR>{},
                           make_range<GR, TR>{});
  } else {
    // Guide terminal: t itself must be the rank-2 mode.
    static_assert(tuple_size<T>::value == 2, "Mismatched ranks");
    return t;
  }

  CUTE_GCC_UNREACHABLE;
}
|
||||
|
||||
} // end namespace cute
|
||||
Reference in New Issue
Block a user