@ -38,6 +38,8 @@
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/tfloat32.h"
|
||||
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <iosfwd>
|
||||
#endif
|
||||
@ -442,16 +444,16 @@ CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
|
||||
/// Computes the complex exponential of z.
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
|
||||
return complex<T>(real(z) * cos(imag(z)), real(z) * sin(imag(z)));
|
||||
return complex<T>(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z)));
|
||||
}
|
||||
|
||||
/// Computes the complex exponential of z.
|
||||
/// Computes the log of z
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
|
||||
return complex<T>(log(abs(z)), arg(z));
|
||||
}
|
||||
|
||||
/// Computes the complex exponential of z.
|
||||
/// Computes the log base 10 of z
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
|
||||
return log(z) / T(log(T(10)));
|
||||
@ -484,6 +486,9 @@ template <typename T>
|
||||
struct RealType< complex<T> > {
|
||||
using Type = T;
|
||||
|
||||
/// Number of elements
|
||||
static int const kExtent = 2;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static complex<T> from_real(double x) {
|
||||
return complex<T>(static_cast<T>(x));
|
||||
|
||||
Reference in New Issue
Block a user