CUTLASS 2.6 (#298)

CUTLASS 2.6
This commit is contained in:
Manish Gupta
2021-07-22 21:40:53 -07:00
committed by GitHub
parent 6c29fe20ba
commit e5d51840e8
308 changed files with 32408 additions and 4722 deletions

View File

@ -38,6 +38,8 @@
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"
#include "cutlass/fast_math.h"
#if !defined(__CUDACC_RTC__)
#include <iosfwd>
#endif
@ -442,16 +444,16 @@ CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
/// Computes the complex exponential of z.
///
/// Applies Euler's formula: for z = a + bi,
///   exp(z) = exp(a) * cos(b) + i * exp(a) * sin(b),
/// where a = real(z) and b = imag(z).
///
/// The real part must be passed through the exponential before scaling the
/// cos/sin terms; using real(z) directly as the scale factor does not yield e^z.
template <typename T>
CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
  return complex<T>(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z)));
}
/// Computes the log of z.
///
/// The real part of the result is log(abs(z)) and the imaginary part is arg(z).
template <typename T>
CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
  auto const magnitude = abs(z);
  auto const phase = arg(z);
  return complex<T>(log(magnitude), phase);
}
/// Computes the log base 10 of z
template <typename T>
CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
return log(z) / T(log(T(10)));
@ -484,6 +486,9 @@ template <typename T>
struct RealType< complex<T> > {
using Type = T;
/// Number of elements
static int const kExtent = 2;
CUTLASS_HOST_DEVICE
static complex<T> from_real(double x) {
return complex<T>(static_cast<T>(x));