CUTLASS 3.2 (#1024)

* CUTLASS 3.2
This commit is contained in:
ANIKET SHIVAM
2023-08-07 14:50:32 -10:00
committed by GitHub
parent a0d787b746
commit 4575443d44
392 changed files with 47559 additions and 7940 deletions

View File

@ -51,6 +51,13 @@ template <class T>
struct has_dereference<T, void_t<decltype(*declval<T>())>> : true_type {
};
template <class T>
CUTE_HOST_DEVICE constexpr
T*
raw_pointer_cast(T* ptr) {
return ptr;
}
//
// Pointer categories
//
@ -92,13 +99,20 @@ struct device_ptr
CUTE_HOST_DEVICE constexpr friend
ptrdiff_t operator-(device_ptr<T,DerivedType> const& a,
device_ptr<T,DerivedType> const& b) {
device_ptr<T,DerivedType> const& b) {
return a.ptr_ - b.ptr_;
}
T* ptr_;
};
template <class T, class D>
CUTE_HOST_DEVICE constexpr
T*
raw_pointer_cast(device_ptr<T,D> ptr) {
return ptr.get();
}
//
// gmem_ptr
//
@ -122,6 +136,24 @@ make_gmem_ptr(void* ptr) {
return {reinterpret_cast<T*>(ptr)};
}
template <class T>
CUTE_HOST_DEVICE constexpr
gmem_ptr<T const>
make_gmem_ptr(void const* ptr) {
return {reinterpret_cast<T const*>(ptr)};
}
// nullptr_t overloads are needed because otherwise,
// make_gmem_ptr<float>(nullptr) will be ambiguous,
// as std::nullptr_t can be converted to any pointer
// or pointer to member type.
template <class T>
CUTE_HOST_DEVICE constexpr
gmem_ptr<T>
make_gmem_ptr(decltype(nullptr)) { // nullptr_t
return {static_cast<T*>(nullptr)};
}
template <class T>
struct is_gmem<gmem_ptr<T>> : true_type {};
@ -148,6 +180,13 @@ make_smem_ptr(void* ptr) {
return {reinterpret_cast<T*>(ptr)};
}
template <class T>
CUTE_HOST_DEVICE constexpr
smem_ptr<T const>
make_smem_ptr(void const* ptr) {
return {reinterpret_cast<T const*>(ptr)};
}
template <class T>
struct is_smem<smem_ptr<T>> : true_type {};
@ -174,6 +213,13 @@ make_rmem_ptr(void* ptr) {
return {reinterpret_cast<T*>(ptr)};
}
template <class T>
CUTE_HOST_DEVICE constexpr
rmem_ptr<T const>
make_rmem_ptr(void const* ptr) {
return {reinterpret_cast<T const*>(ptr)};
}
template <class T>
struct is_rmem<rmem_ptr<T>> : true_type {};