CUTLASS 3.8 Release (#2059)
* CUTLASS 3.8 Release
* update
* Update README.md
* Revert "Update README.md"
This reverts commit b353e36fe8.
* update
* update
---------
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -284,6 +284,96 @@ recast_ptr(rmem_ptr<P> const& ptr) {
|
||||
return make_rmem_ptr(recast_ptr<NewT>(ptr.get()));
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// tmem_ptr -- a typed, word-addressed, non-dereferencable "pointer"
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct tmem_ptr
|
||||
{
|
||||
using value_type = remove_cv_t<T>;
|
||||
using element_type = T;
|
||||
using reference = T;
|
||||
|
||||
// Right-shift value for the offset scaling -- TMEM uses word-addressing
|
||||
static constexpr int32_t OffsetShift = log_2(trait_ratio(sizeof_bits<uint32_t>{}, sizeof_bits<T>{}));
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tmem_ptr(uint32_t addr = 0) : addr_(addr) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint32_t const& get() const {
|
||||
return addr_;
|
||||
}
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint32_t& get() {
|
||||
return addr_;
|
||||
}
|
||||
|
||||
template <class T_ = T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
value_type operator*() const {
|
||||
static_assert(dependent_false<T_>, "Attempting to dereference a tmem_ptr, want raw_pointer_cast() for address instead?");
|
||||
return value_type{};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator[](uint32_t const& i) const { return *(*this + i); }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tmem_ptr operator+(uint32_t const& i) const {
|
||||
//return {addr_ + shiftr(i, OffsetShift)}; // Shift the offset for word-addressing
|
||||
return {addr_ + rotr(i, OffsetShift)}; // Rotate the offset to keep subword indices in the unused high 8bits for debug
|
||||
}
|
||||
|
||||
// TMEM "Address" with active mask 0x007F.01FF
|
||||
// The upper 16 bits, the 0x007F portion, refers to the 128 DP lanes
|
||||
// The lower 16 bits, the 0x01FF portion, refers to the 512 COL lanes
|
||||
union {
|
||||
uint32_t addr_;
|
||||
struct {
|
||||
uint16_t col_;
|
||||
uint8_t dp_;
|
||||
uint8_t idx_; // Hijack the top 8bits for the sub-word idx to avoid an extra reg.
|
||||
// Assert this is 0 on every access?
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
template <class T, class = void>
|
||||
struct is_tmem : false_type {};
|
||||
template <class T> // Found the tmem
|
||||
struct is_tmem<tmem_ptr<T>> : true_type {};
|
||||
template <class P> // Recurse on ::iterator, if possible
|
||||
struct is_tmem<P, void_t<typename P::iterator>> : is_tmem<typename P::iterator> {};
|
||||
template <class P>
|
||||
constexpr bool is_tmem_v = is_tmem<P>::value;
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tmem_ptr<T>
|
||||
make_tmem_ptr(uint32_t addr = 0) {
|
||||
return tmem_ptr<T>(addr);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint32_t
|
||||
raw_pointer_cast(tmem_ptr<T> const& ptr) {
|
||||
return ptr.get();
|
||||
}
|
||||
|
||||
// TMEM accounts for subword/superword elements already due to the offset shift based on sizeof_bits
|
||||
// Thus, this is a trivial recast equivalent to reinterpret_cast<NewT*>
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast_ptr(tmem_ptr<T> const& ptr) {
|
||||
return tmem_ptr<NewT>{ptr.addr_};
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
@ -306,6 +396,14 @@ CUTE_HOST_DEVICE void print(rmem_ptr<T> ptr)
|
||||
printf("rmem_"); print(ptr.get());
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(tmem_ptr<T> ptr)
|
||||
{
|
||||
printf("tmem_["); print(sizeof_bits<T>::value); printf("b](0x%04x.%04x)", ptr.addr_ >> 16, ptr.addr_ & 0xFFFF);
|
||||
}
|
||||
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> ptr)
|
||||
@ -325,6 +423,13 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> ptr)
|
||||
return os << "rmem_[" << int(sizeof_bits<iter_value_t<T>>::value) << "b]";
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, tmem_ptr<T> ptr)
|
||||
{
|
||||
return os << "tmem_[" << int(sizeof_bits<T>::value) << "b](" << ptr.addr_ << ")";
|
||||
}
|
||||
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
Reference in New Issue
Block a user