CUTLASS 3.1 (#915)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2023-04-14 20:19:34 -07:00
committed by GitHub
parent 9b8166e3f0
commit d572cc1aab
482 changed files with 37184 additions and 16419 deletions

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
#pragma once
#include <cstdint>
#include <cute/util/type_traits.hpp>
//#if defined(__CUDA_ARCH__)
//# include <cuda/std/complex>
@ -38,13 +38,37 @@
//# include <complex>
//#endif
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
// on line 656 of thrust/detail/type_traits.h.
// These pragmas suppress the warnings.
// Suppress warnings for code in Thrust headers.
#if defined(_MSC_VER)
// We check for MSVC first, because MSVC also defines __GNUC__.
// It's common for non-GCC compilers that emulate GCC's behavior
// to define __GNUC__.
//
// thrust/complex.h triggers MSVC's warning on conversion
// from double to float (or const float) ("possible loss of data").
// MSVC treats this as an error by default (at least with
// CUTLASS's default CMake configuration).
#pragma warning( push )
#pragma warning( disable : 4244 )
#elif defined(__GNUC__)
// With GCC + CUDA 11.4, builds show spurious "-Wconversion"
// warnings on line 656 of thrust/detail/type_traits.h.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
#endif
#if defined(__CUDACC_RTC__)
#include <cuda/std/complex>
#else
#include <thrust/complex.h>
#endif
#if defined(_MSC_VER)
#pragma warning( pop )
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
#include <cute/config.hpp>
@ -62,7 +86,11 @@ namespace cute
//template <class T>
//using complex = thrust::complex<T>;
#if defined(__CUDACC_RTC__)
using cuda::std::complex;
#else
using thrust::complex;
#endif
template <class T>
CUTE_HOST_DEVICE
@ -147,6 +175,7 @@ struct is_complex<complex<T>> {
//////////////////////////////////////////////////////////////////////////////////////////////////
// Display utilities
#if !defined(__CUDACC_RTC__)
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
{
@ -159,5 +188,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
return os << _r;
}
}
#endif
} // end namespace cute