Files
cutlass/include/cutlass/fast_math.h
Manish Gupta 1ac4559d12 Cutlass 2.6 Update 1 (#301)
* cutlass 2.6 update

* remove debug prints
2021-07-27 17:58:30 -07:00

748 lines
20 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#include <cmath>
#include <type_traits>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/uint128.h"
#include "cutlass/coord.h"
#include "cutlass/numeric_types.h"
/**
* \file
* \brief Math utilities
*/
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Static math utilities
******************************************************************************/
/// Mixed precision dot product
template <typename Index, typename LongIndex, int N>
CUTLASS_HOST_DEVICE LongIndex dot(
Coord<N, Index> const &coord,
Coord<N, LongIndex> const &stride,
LongIndex acc = LongIndex()) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < N; ++n) {
acc += LongIndex(coord[n]) * stride[n];
}
return acc;
}
/**
* Statically determine if N is a power-of-two
*/
template <int N>
struct is_pow2 {
static bool const value = ((N & (N - 1)) == 0);
};
/**
* Statically determine log2(N), rounded down
*/
template <int N, int CurrentVal = N, int Count = 0>
struct log2_down {
/// Static logarithm value
enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
};
// Base case
template <int N, int Count>
struct log2_down<N, 1, Count> {
enum { value = Count };
};
/**
* Statically determine log2(N), rounded up
*/
template <int N, int CurrentVal = N, int Count = 0>
struct log2_up {
/// Static logarithm value
enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
};
// Base case
template <int N, int Count>
struct log2_up<N, 1, Count> {
enum { value = ((1 << Count) < N) ? Count + 1 : Count };
};
/**
* Statically estimate sqrt(N) to the nearest power-of-two
*/
template <int N>
struct sqrt_est {
enum { value = 1 << (log2_up<N>::value / 2) };
};
/**
* For performing a constant-division with a compile-time assertion that the
* Divisor evenly-divides the Dividend.
*/
template <int Dividend, int Divisor>
struct divide_assert {
enum { value = Dividend / Divisor };
static_assert((Dividend % Divisor == 0), "Not an even multiple");
};
/******************************************************************************
* Rounding
******************************************************************************/
/**
* Round dividend up to the nearest multiple of divisor
*/
template <typename dividend_t, typename divisor_t>
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
return ((dividend + divisor - 1) / divisor) * divisor;
}
/**
* Greatest common divisor
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
for (;;) {
if (a == 0) return b;
b %= a;
if (b == 0) return a;
a %= b;
}
}
/**
* Least common multiple
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
value_t temp = gcd(a, b);
return temp ? (a / temp * b) : 0;
}
/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
constexpr int round_up(int a, int b) {
return ((a + b - 1) / b) * b;
}
/// Returns the ceiling of (a / b)
CUTLASS_HOST_DEVICE
constexpr int ceil_div(int a, int b) {
return (a + b - 1) / b;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/**
* log2 computation, what's the
* difference between the below codes and
* log2_up/down codes?
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
for (int i = 31; i >= 0; --i) {
if ((1 << i) & x) return 31 - i;
}
return 32;
}
template <typename value_t>
CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {
int a = int(31 - clz(x));
a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2.
return a;
}
/**
* Find divisor, using find_log2
*/
CUTLASS_HOST_DEVICE
void find_divisor(unsigned int& mul, unsigned int& shr, unsigned int denom) {
if (denom == 1) {
mul = 0;
shr = 0;
} else {
unsigned int p = 31 + find_log2(denom);
unsigned m = unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom));
mul = m;
shr = p - 32;
}
}
/**
* Find quotient and remainder using device-side intrinsics
*/
CUTLASS_HOST_DEVICE
void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigned int shr) {
#if defined(__CUDA_ARCH__)
// Use IMUL.HI if div != 1, else simply copy the source.
quo = (div != 1) ? __umulhi(src, mul) >> shr : src;
#else
quo = int((div != 1) ? int(((int64_t)src * mul) >> 32) >> shr : src);
#endif
// The remainder.
rem = src - (quo * div);
}
// For long int input
CUTLASS_HOST_DEVICE
void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, unsigned int shr) {
#if defined(__CUDA_ARCH__)
// Use IMUL.HI if div != 1, else simply copy the source.
quo = (div != 1) ? __umulhi(src, mul) >> shr : src;
#else
quo = int((div != 1) ? ((src * mul) >> 32) >> shr : src);
#endif
// The remainder.
rem = src - (quo * div);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Object to encapsulate the fast division+modulus operation.
///
/// This object precomputes two values used to accelerate the computation and is best used
/// when the divisor is a grid-invariant. In this case, it may be computed in host code and
/// marshalled along other kernel arguments using the 'Params' pattern.
///
/// Example:
///
///
/// int quotient, remainder, dividend, divisor;
///
/// FastDivmod divmod(divisor);
///
/// divmod(quotient, remainder, dividend);
///
/// // quotient = (dividend / divisor)
/// // remainder = (dividend % divisor)
///
struct FastDivmod {
int divisor;
unsigned int multiplier;
unsigned int shift_right;
/// Construct the FastDivmod object, in host code ideally.
///
/// This precomputes some values based on the divisor and is computationally expensive.
CUTLASS_HOST_DEVICE
FastDivmod(): divisor(0), multiplier(0), shift_right(0) { }
CUTLASS_HOST_DEVICE
FastDivmod(int divisor_): divisor(divisor_) {
find_divisor(multiplier, shift_right, divisor);
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int &remainder, int dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
///
/// Simply returns the quotient
CUTLASS_HOST_DEVICE
int divmod(int &remainder, int dividend) const {
int quotient;
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
return quotient;
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int64_t &remainder, int64_t dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
int divmod(int64_t &remainder, int64_t dividend) const {
int quotient;
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
return quotient;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Object to encapsulate the fast division+modulus operation for 64b integer division.
///
/// This object precomputes two values used to accelerate the computation and is best used
/// when the divisor is a grid-invariant. In this case, it may be computed in host code and
/// marshalled along other kernel arguments using the 'Params' pattern.
///
/// Example:
///
///
/// uint64_t quotient, remainder, dividend, divisor;
///
/// FastDivmodU64 divmod(divisor);
///
/// divmod(quotient, remainder, dividend);
///
/// // quotient = (dividend / divisor)
/// // remainder = (dividend % divisor)
///
struct FastDivmodU64 {
uint64_t divisor;
uint64_t multiplier;
unsigned int shift_right;
unsigned int round_up;
//
// Static methods
//
/// Computes b, where 2^b is the greatest power of two that is less than or equal to x
CUTLASS_HOST_DEVICE
static uint32_t integer_log2(uint64_t x) {
uint32_t n = 0;
while (x >>= 1) {
++n;
}
return n;
}
/// Default ctor
CUTLASS_HOST_DEVICE
FastDivmodU64(): divisor(0), multiplier(0), shift_right(0), round_up(0) { }
/// Construct the FastDivmod object, in host code ideally.
///
/// This precomputes some values based on the divisor and is computationally expensive.
CUTLASS_HOST_DEVICE
FastDivmodU64(uint64_t divisor_): divisor(divisor_), multiplier(1), shift_right(0), round_up(0) {
if (divisor) {
shift_right = integer_log2(divisor);
if ((divisor & (divisor - 1)) == 0) {
multiplier = 0;
}
else {
uint64_t power_of_two = (uint64_t(1) << shift_right);
uint64_t multiplier_lo = uint128_t(0, power_of_two) / divisor;
multiplier = uint128_t(power_of_two, power_of_two) / divisor;
round_up = (multiplier_lo == multiplier ? 1 : 0);
}
}
}
/// Returns the quotient of floor(dividend / divisor)
CUTLASS_HOST_DEVICE
uint64_t divide(uint64_t dividend) const {
uint64_t quotient = 0;
#ifdef __CUDA_ARCH__
uint64_t x = dividend;
if (multiplier) {
x = __umul64hi(dividend + round_up, multiplier);
}
quotient = (x >> shift_right);
#else
// TODO - use proper 'fast' division here also. No reason why x86-code shouldn't be optimized.
quotient = dividend / divisor;
#endif
return quotient;
}
/// Computes the remainder given a computed quotient and dividend
CUTLASS_HOST_DEVICE
uint64_t modulus(uint64_t quotient, uint64_t dividend) const {
return uint32_t(dividend - quotient * divisor);
}
/// Returns the quotient of floor(dividend / divisor) and computes the remainder
CUTLASS_HOST_DEVICE
uint64_t divmod(uint64_t &remainder, uint64_t dividend) const {
uint64_t quotient = divide(dividend);
remainder = modulus(quotient, dividend);
return quotient;
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(uint64_t &quotient, uint64_t &remainder, uint64_t dividend) const {
quotient = divmod(remainder, dividend);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes the coordinate decomposition from a linear index (64-bit linear index => coord<int32_t>)
///
/// This decomposition is accelerated by the FastDivmodU64 object. It is assumed that
/// a coordinate of <Rank> indices can be decomposed by <Rank - 1> div/mod operations.
/// Note, is assumed that element divmod[0] divides by extent[1].
///
/// For example, assume 4-D coordinate (n, p, q, c) is mapped to a linear index `npqc`. This
/// can be decomposed via three divide and modulus operations:
///
/// c = npqc % C; | divmod[2] = FastDivmodU64(C)
/// npq = npqc / C; | coord[3] = c
///
/// q = npq % Q; | divmod[1] = FastDivmodU64(Q)
/// np = npq / Q; | coord[2] = q
///
/// p = np % P; | divmod[0] = FastDivmodU64(P)
/// n = np / P; | coord[1] = p
///
/// | coord[0] = n
///
template <int Rank>
CUTLASS_HOST_DEVICE Coord<Rank> CoordinateDecomposition(
uint64_t linear_idx, ///< Linear index to decompose
FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects
static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater.");
Coord<Rank> coord;
CUTLASS_PRAGMA_UNROLL
for (int i = Rank; i > 1; --i) {
uint64_t remainder;
linear_idx = divmod[i - 2].divmod(remainder, linear_idx);
coord[i - 1] = int(remainder);
}
coord[0] = int(linear_idx);
return coord;
}
/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord<int32_t>)
template <int Rank>
CUTLASS_HOST_DEVICE Coord<Rank> CoordinateDecomposition(
int linear_idx, ///< Linear index to decompose
FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects
static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater.");
Coord<Rank> coord;
CUTLASS_PRAGMA_UNROLL
for (int i = Rank; i > 1; --i) {
int remainder;
linear_idx = divmod[i - 2].divmod(remainder, linear_idx);
coord[i - 1] = int(remainder);
}
coord[0] = int(linear_idx);
return coord;
}
template <int Rank>
CUTLASS_HOST_DEVICE Coord<Rank> CoordinateDecompositionLittleEndian(
uint64_t linear_idx, ///< Linear index to decompose
FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects
static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater.");
Coord<Rank> coord;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank - 1; ++i) {
uint64_t remainder;
linear_idx = divmod[i].divmod(remainder, linear_idx);
coord[i] = int(remainder);
}
coord[Rank - 1] = int(linear_idx);
return coord;
}
/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord<int32_t>)
template <int Rank>
CUTLASS_HOST_DEVICE Coord<Rank> CoordinateDecompositionLittleEndian(
int linear_idx, ///< Linear index to decompose
FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects
static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater.");
Coord<Rank> coord;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank - 1; ++i) {
int remainder;
linear_idx = divmod[i].divmod(remainder, linear_idx);
coord[i] = int(remainder);
}
coord[Rank - 1] = int(linear_idx);
return coord;
}
/// Safely computes the offset of a linear index in bytes for all types
template <typename Element>
CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) {
static_assert(
(sizeof_bits<Element>::value >= 8 && !(sizeof_bits<Element>::value % 8)) ||
(sizeof_bits<Element>::value < 8 && !(8 % sizeof_bits<Element>::value)),
"Size of numeric type in bits must either be divisible by 8 bits, or 8 bits must be divisible by the size.");
if (sizeof_bits<Element>::value >= 8) {
return index * (sizeof_bits<Element>::value / 8);
}
else {
int const kElementsPerByte = ((8 / sizeof_bits<Element>::value) + ((sizeof_bits<Element>::value >= 8) ? 1 : 0));
return index / kElementsPerByte;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Min/Max
/////////////////////////////////////////////////////////////////////////////////////////////////
template <int A, int B>
struct Min {
static int const kValue = (A < B) ? A : B;
};
template <int A, int B>
struct Max {
static int const kValue = (A > B) ? A : B;
};
CUTLASS_HOST_DEVICE
constexpr int const_min(int a, int b) {
return (b < a ? b : a);
}
CUTLASS_HOST_DEVICE
constexpr int const_max(int a, int b) {
return (b > a ? b : a);
}
template <typename T>
CUTLASS_HOST_DEVICE
T fast_min(T a, T b) {
return (b < a ? b : a);
}
template <>
CUTLASS_HOST_DEVICE
float fast_min(float a, float b) {
return fminf(a, b);
}
template <typename T>
CUTLASS_HOST_DEVICE
T fast_max(T a, T b) {
return (a < b ? b : a);
}
template <>
CUTLASS_HOST_DEVICE
float fast_max(float a, float b) {
return fmaxf(a, b);
}
CUTLASS_HOST_DEVICE
float fast_cos(float theta) {
#if defined(__CUDA_ARCH__)
return ::cosf(theta);
#else
return std::cos(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_cos(double theta) {
#if defined(__CUDA_ARCH__)
return ::cos(theta);
#else
return std::cos(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_sin(float theta) {
#if defined(__CUDA_ARCH__)
return ::sinf(theta);
#else
return std::sin(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_sin(double theta) {
#if defined(__CUDA_ARCH__)
return ::sin(theta);
#else
return std::sin(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_acos(float theta) {
#if defined(__CUDA_ARCH__)
return ::acosf(theta);
#else
return std::acos(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_acos(double theta) {
#if defined(__CUDA_ARCH__)
return ::acos(theta);
#else
return std::acos(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_asin(float theta) {
#if defined(__CUDA_ARCH__)
return ::asinf(theta);
#else
return std::asin(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_asin(double theta) {
#if defined(__CUDA_ARCH__)
return ::asin(theta);
#else
return std::asin(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_sqrt(float theta) {
#if defined(__CUDA_ARCH__)
return ::sqrtf(theta);
#else
return std::sqrt(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_sqrt(double theta) {
#if defined(__CUDA_ARCH__)
return ::sqrt(theta);
#else
return std::sqrt(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_exp(float x) {
#if defined(__CUDA_ARCH__)
return ::exp(x);
#else
return std::exp(x);
#endif
}
CUTLASS_HOST_DEVICE
double fast_exp(double x) {
#if defined(__CUDA_ARCH__)
return ::exp(x);
#else
return std::exp(x);
#endif
}
CUTLASS_HOST_DEVICE
float fast_log(float x) {
#if defined(__CUDA_ARCH__)
return ::logf(x);
#else
return std::log(x);
#endif
}
CUTLASS_HOST_DEVICE
double fast_log(double x) {
#if defined(__CUDA_ARCH__)
return ::log(x);
#else
return std::log(x);
#endif
}
CUTLASS_HOST_DEVICE
float fast_tanh(float x) {
#if defined(__CUDA_ARCH__)
return ::tanhf(x);
#else
return std::tanh(x);
#endif
}
CUTLASS_HOST_DEVICE
double fast_tanh(double x) {
#if defined(__CUDA_ARCH__)
return ::tanh(x);
#else
return std::tanh(x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////