v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -235,6 +235,63 @@ raw_pointer_cast(counting_iterator<T> const& x) {
return x.n_;
}
//
// transform_iterator
//
template <class Fn, class Iter>
struct transform_iter
{
using iterator = Iter;
// using reference = typename iterator_traits<iterator>::reference;
// using element_type = typename iterator_traits<iterator>::element_type;
// using value_type = typename iterator_traits<iterator>::value_type;
Fn fn_;
iterator ptr_;
CUTE_HOST_DEVICE constexpr
transform_iter(Fn fn, iterator ptr = {}) : fn_(fn), ptr_(ptr) {}
CUTE_HOST_DEVICE constexpr
decltype(auto) operator*() const { return fn_(*ptr_); }
template <class Index>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator[](Index const& i) const { return fn_(ptr_[i]); }
template <class Index>
CUTE_HOST_DEVICE constexpr
auto operator+(Index const& i) const { return transform_iter<Fn, decltype(ptr_+i)>{fn_, ptr_+i}; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator==(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ == y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator!=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ != y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator< (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ < y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator<=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ <= y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator> (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ > y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator>=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ >= y.ptr_; }
};
template <class Fn, class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_transform_iter(Fn const& fn, Iterator const& ptr)
{
return transform_iter<Fn,Iterator>(fn,ptr);
}
//
// Display utilities
//
@ -251,12 +308,24 @@ CUTE_HOST_DEVICE void print(counting_iterator<T> ptr)
printf("counting_iter("); print(ptr.n_); printf(")");
}
template <class Fn, class Iterator>
CUTE_HOST_DEVICE void print(transform_iter<Fn,Iterator> ptr)
{
printf("trans_"); print(ptr.ptr_);
}
#if !defined(__CUDACC_RTC__)
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
{
return os << "counting_iter(" << ptr.n_ << ")";
}
template <class Fn, class Iterator>
CUTE_HOST std::ostream& operator<<(std::ostream& os, transform_iter<Fn,Iterator> ptr)
{
return os << "trans_" << ptr.ptr_;
}
#endif // !defined(__CUDACC_RTC__)
} // end namespace cute