v4.0 update. (#2371)
This commit is contained in:
@ -235,6 +235,63 @@ raw_pointer_cast(counting_iterator<T> const& x) {
|
||||
return x.n_;
|
||||
}
|
||||
|
||||
//
|
||||
// transform_iterator
|
||||
//
|
||||
|
||||
template <class Fn, class Iter>
|
||||
struct transform_iter
|
||||
{
|
||||
using iterator = Iter;
|
||||
// using reference = typename iterator_traits<iterator>::reference;
|
||||
// using element_type = typename iterator_traits<iterator>::element_type;
|
||||
// using value_type = typename iterator_traits<iterator>::value_type;
|
||||
|
||||
Fn fn_;
|
||||
iterator ptr_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
transform_iter(Fn fn, iterator ptr = {}) : fn_(fn), ptr_(ptr) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator*() const { return fn_(*ptr_); }
|
||||
|
||||
template <class Index>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator[](Index const& i) const { return fn_(ptr_[i]); }
|
||||
|
||||
template <class Index>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto operator+(Index const& i) const { return transform_iter<Fn, decltype(ptr_+i)>{fn_, ptr_+i}; }
|
||||
|
||||
template <class IterY>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
friend bool operator==(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ == y.ptr_; }
|
||||
template <class IterY>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
friend bool operator!=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ != y.ptr_; }
|
||||
template <class IterY>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
friend bool operator< (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ < y.ptr_; }
|
||||
template <class IterY>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
friend bool operator<=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ <= y.ptr_; }
|
||||
template <class IterY>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
friend bool operator> (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ > y.ptr_; }
|
||||
template <class IterY>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
friend bool operator>=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ >= y.ptr_; }
|
||||
};
|
||||
|
||||
template <class Fn, class Iterator>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_transform_iter(Fn const& fn, Iterator const& ptr)
|
||||
{
|
||||
return transform_iter<Fn,Iterator>(fn,ptr);
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
@ -251,12 +308,24 @@ CUTE_HOST_DEVICE void print(counting_iterator<T> ptr)
|
||||
printf("counting_iter("); print(ptr.n_); printf(")");
|
||||
}
|
||||
|
||||
template <class Fn, class Iterator>
|
||||
CUTE_HOST_DEVICE void print(transform_iter<Fn,Iterator> ptr)
|
||||
{
|
||||
printf("trans_"); print(ptr.ptr_);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
|
||||
{
|
||||
return os << "counting_iter(" << ptr.n_ << ")";
|
||||
}
|
||||
|
||||
template <class Fn, class Iterator>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, transform_iter<Fn,Iterator> ptr)
|
||||
{
|
||||
return os << "trans_" << ptr.ptr_;
|
||||
}
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
Reference in New Issue
Block a user