CUTLASS v1.0 release

akerr
2018-05-16 11:44:56 -07:00
parent 901287175f
commit 2028ebe120
1830 changed files with 308993 additions and 11173 deletions

tools/util/command_line.h Normal file

@@ -0,0 +1,254 @@
/******************************************************************************
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Utility for parsing command line arguments
*/
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#include <cuda_runtime.h>
namespace cutlass {
/******************************************************************************
* command_line
******************************************************************************/
/**
* Utility for parsing command line arguments
*/
struct CommandLine {
std::vector<std::string> keys;
std::vector<std::string> values;
std::vector<std::string> args;
/**
* Constructor
*/
CommandLine(int argc, const char** argv) {
using namespace std;
for (int i = 1; i < argc; i++) {
string arg = argv[i];
if (arg.size() < 2 || (arg[0] != '-') || (arg[1] != '-')) {
args.push_back(arg);
continue;
}
string::size_type pos;
string key, val;
if ((pos = arg.find('=')) == string::npos) {
key = string(arg, 2, arg.length() - 2);
val = "";
} else {
key = string(arg, 2, pos - 2);
val = string(arg, pos + 1, arg.length() - pos - 1);
}
keys.push_back(key);
values.push_back(val);
}
}
/**
* Checks whether a flag "--<flag>" is present in the commandline
*/
bool check_cmd_line_flag(const char* arg_name) const {
using namespace std;
for (int i = 0; i < int(keys.size()); ++i) {
if (keys[i] == string(arg_name)) return true;
}
return false;
}
/**
* Returns number of naked (non-flag and non-key-value) commandline parameters
*/
int num_naked_args() const {
return int(args.size());
}
/**
* Returns the commandline parameter for a given index (not including flags)
*/
template <typename value_t>
void get_cmd_line_argument(int index, value_t& val) const {
using namespace std;
if (index < int(args.size())) {
istringstream str_stream(args[index]);
str_stream >> val;
}
}
/**
* Obtains the boolean value specified for a given commandline parameter --<flag>=<bool>
*/
void get_cmd_line_argument(const char* arg_name, bool& val, bool _default = true) const {
val = _default;
if (check_cmd_line_flag(arg_name)) {
std::string value;
get_cmd_line_argument(arg_name, value);
val = !(value == "0" || value == "false");
}
}
/**
* Returns the value specified for a given commandline parameter --<flag>=<value>
*/
template <typename value_t>
void get_cmd_line_argument(const char* arg_name,
value_t& val,
value_t const& _default = value_t()) const {
using namespace std;
val = _default;
for (int i = 0; i < int(keys.size()); ++i) {
if (keys[i] == string(arg_name)) {
istringstream str_stream(values[i]);
str_stream >> val;
}
}
}
/**
* Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
*/
template <typename value_t>
void get_cmd_line_arguments(const char* arg_name,
std::vector<value_t>& vals,
char sep = ',') const {
using namespace std;
if (check_cmd_line_flag(arg_name)) {
// Clear any default values
vals.clear();
// Recover values from the <sep>-delimited string
for (int i = 0; i < int(keys.size()); ++i) {
if (keys[i] == string(arg_name)) {
istringstream str_stream(values[i]);
string token;
// Split on <sep> and convert each token to value_t
while (getline(str_stream, token, sep)) {
istringstream token_stream(token);
value_t val;
if (token_stream >> val) {
vals.push_back(val);
}
}
}
}
}
}
/**
* Returns the values specified for a given commandline parameter
* --<flag>=<key:value>,<key:value>*
*/
void get_cmd_line_argument_pairs(const char* arg_name,
std::vector<std::pair<std::string, std::string> >& tokens,
char delim = ',',
char sep = ':') const {
if (check_cmd_line_flag(arg_name)) {
std::string value;
get_cmd_line_argument(arg_name, value);
tokenize(tokens, value, delim, sep);
}
}
/**
* Returns the number of parsed --key[=value] pairs
*/
int parsed_argc() const { return (int)keys.size(); }
//-------------------------------------------------------------------------
// Utility functions
//-------------------------------------------------------------------------
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
std::string const& str,
char delim = ',',
char sep = ':') {
// Home-built to avoid Boost dependency
size_t s_idx = 0;
size_t d_idx = std::string::npos;
while (s_idx < str.size()) {
d_idx = str.find_first_of(delim, s_idx);
size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
size_t sep_idx = str.find_first_of(sep, s_idx);
size_t offset = 1;
if (sep_idx == std::string::npos || sep_idx >= end_idx) {
sep_idx = end_idx;
offset = 0;
}
std::pair<std::string, std::string> item(
str.substr(s_idx, sep_idx - s_idx),
str.substr(sep_idx + offset, end_idx - sep_idx - offset));
tokens.push_back(item);
s_idx = end_idx + 1;
}
}
/// Tokenizes a comma-delimited list of ':'-separated pairs, keeping only the first element of each
static void tokenize(std::vector<std::string>& tokens,
std::string const& str,
char delim = ',',
char sep = ':') {
typedef std::vector<std::pair<std::string, std::string> > TokenVector;
typedef TokenVector::const_iterator token_iterator;
std::vector<std::pair<std::string, std::string> > token_pairs;
tokenize(token_pairs, str, delim, sep);
for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
tokens.push_back(tok->first);
}
}
};
} // namespace cutlass
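
For reference, a minimal usage sketch of CommandLine (hypothetical invocation and flag names, illustrative only):

#include <iostream>
#include <string>
#include <vector>
#include <tools/util/command_line.h>

int main(int argc, const char** argv) {
// e.g. invoked as: ./tool --m=128 --alpha=1.5 --tags=gemm,fp16 --verbose
cutlass::CommandLine cmd(argc, argv);
int m;
cmd.get_cmd_line_argument("m", m, 64); // 64 if --m is absent
float alpha;
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
std::vector<std::string> tags;
cmd.get_cmd_line_arguments("tags", tags); // splits on ','
bool verbose = cmd.check_cmd_line_flag("verbose");
std::cout << m << " " << alpha << " " << tags.size() << " " << verbose << std::endl;
return 0;
}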

tools/util/device_memory.h Normal file

@@ -0,0 +1,178 @@
/******************************************************************************
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief C++ interface to CUDA device memory management functions.
*/
#include <memory>
#include <cutlass/util/debug.h>
#include <cutlass/util/platform.h>
#include <tools/util/exceptions.h>
namespace cutlass {
namespace device_memory {
/******************************************************************************
* Allocation lifetime
******************************************************************************/
/// Allocate a buffer of \p count elements of type \p T on the current CUDA device
template <typename T>
T* allocate(size_t count = 1) {
T* ptr = 0;
size_t bytes = sizeof(T) * count;
cudaError_t cuda_error = CUDA_PERROR(cudaMalloc((void**)&ptr, bytes));
if (cuda_error != cudaSuccess) {
throw cuda_exception("Failed to allocate memory", cuda_error);
}
return ptr;
}
/// Free the buffer pointed to by \p ptr
template <typename T>
void free(T* ptr) {
if (ptr) {
cudaError_t cuda_error = CUDA_PERROR(cudaFree(ptr));
if (cuda_error != cudaSuccess) {
throw cuda_exception("Failed to free device memory", cuda_error);
}
}
}
/******************************************************************************
* Data movement
******************************************************************************/
template <typename T>
void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) {
size_t bytes = count * sizeof(T);
cudaError_t cuda_error = CUDA_PERROR(cudaMemcpy(dst, src, bytes, kind));
if (cuda_error != cudaSuccess) {
throw cuda_exception("cudaMemcpy() failed", cuda_error);
}
}
template <typename T>
void copy_to_device(T* dst, T const* src, size_t count = 1) {
copy(dst, src, count, cudaMemcpyHostToDevice);
}
template <typename T>
void copy_to_host(T* dst, T const* src, size_t count = 1) {
copy(dst, src, count, cudaMemcpyDeviceToHost);
}
template <typename T>
void copy_device_to_device(T* dst, T const* src, size_t count = 1) {
copy(dst, src, count, cudaMemcpyDeviceToDevice);
}
/// Copies elements from device memory to host-side range
template <typename OutputIterator, typename T>
void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) {
size_t elements = end - begin;
copy_to_host(&*begin, device_begin, elements);
}
/// Copies elements to device memory from host-side range
template <typename T, typename InputIterator>
void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) {
size_t elements = end - begin;
copy_to_device(device_begin, &*begin, elements);
}
/******************************************************************************
* "Smart" device memory allocation
******************************************************************************/
/// Device allocation abstraction that tracks size and capacity
template <typename T>
struct allocation {
/// Delete functor for CUDA device memory
struct deleter {
void operator()(T* ptr) {
// Deleters must not throw: intentionally swallow any cudaFree() error
// rather than raising cuda_exception during cleanup.
CUDA_PERROR(cudaFree(ptr));
}
};
/// Number of elements of T allocated on the current CUDA device
size_t capacity;
/// Smart pointer
platform::unique_ptr<T, deleter> smart_ptr;
//
//
//
/// Constructor: allocates no memory
allocation() : capacity(0) {}
/// Constructor: allocates \p capacity elements on the current CUDA device
allocation(size_t _capacity) : capacity(_capacity), smart_ptr(allocate<T>(_capacity)) {}
/// Destructor
~allocation() { reset(); }
/// Returns a pointer to the managed object
T* get() const { return smart_ptr.get(); }
/// Releases the ownership of the managed object (without deleting) and resets capacity to zero
T* release() {
capacity = 0;
return smart_ptr.release();
}
/// Deletes the managed object and resets capacity to zero
void reset() {
capacity = 0;
smart_ptr.reset();
}
/// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity
void reset(T* _ptr, size_t _capacity) {
smart_ptr.reset(_ptr);
capacity = _capacity;
}
/// Returns a pointer to the object owned by *this
T* operator->() const { return smart_ptr.get(); }
/// Returns the deleter object which would be used for destruction of the managed object.
deleter& get_deleter() { return smart_ptr.get_deleter(); }
/// Returns the deleter object which would be used for destruction of the managed object (const)
const deleter& get_deleter() const { return smart_ptr.get_deleter(); }
};
} // namespace device_memory
} // namespace cutlass
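
For reference, a minimal round-trip sketch of the helpers above (hypothetical example):

#include <vector>
#include <tools/util/device_memory.h>

void round_trip_example() {
std::vector<float> host_src(1024, 1.0f);
std::vector<float> host_dst(1024, 0.0f);
// RAII device buffer - freed automatically when 'buffer' leaves scope
cutlass::device_memory::allocation<float> buffer(host_src.size());
cutlass::device_memory::copy_to_device(buffer.get(), host_src.data(), host_src.size());
cutlass::device_memory::copy_to_host(host_dst.data(), buffer.get(), host_dst.size());
}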

tools/util/exceptions.h Normal file

@@ -0,0 +1,62 @@
/******************************************************************************
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief C++ exception semantics for CUDA error codes
*/
#include <cuda_runtime.h>
#include <iosfwd>
#include <stdexcept>
#include <cutlass/util/platform.h>
namespace cutlass {
/// C++ exception wrapper for CUDA \p cudaError_t
class cuda_exception : public std::exception {
public:
/// Constructor
cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {}
/// Returns the explanatory string (overrides std::exception::what())
virtual const char* what() const throw() { return msg; }
/// Returns the underlying CUDA \p cudaError_t
cudaError_t cudaError() const { return err; }
protected:
/// Explanatory string
const char* msg;
/// Underlying CUDA \p cudaError_t
cudaError_t err;
};
/// Writes a cudaError_t to an output stream
inline std::ostream& operator<<(std::ostream& out, cudaError_t result) {
return out << cudaGetErrorString(result);
}
/// Writes a cuda_exception instance to an output stream
inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) {
return out << e.what() << ": " << e.cudaError();
}
} // namespace cutlass
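
A minimal usage sketch (illustrative; an oversized request is used to provoke a failure):

#include <iostream>
#include <tools/util/device_memory.h>
#include <tools/util/exceptions.h>

int main() {
try {
// Deliberately impossible allocation so that allocate<>() throws
float* ptr = cutlass::device_memory::allocate<float>(~size_t(0) / sizeof(float));
cutlass::device_memory::free(ptr);
} catch (cutlass::cuda_exception const& e) {
std::cerr << e << std::endl; // prints "<what()>: <cudaGetErrorString()>"
return 1;
}
return 0;
}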

tools/util/half.h Normal file

@@ -0,0 +1,743 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Host-side implementation of half-precision float
*/
#pragma once
#include <stdint.h>
#include <cmath>
#include <limits>
#include <utility>
#include <iomanip>
#include <istream>
#include <ostream>
#include <cuda_fp16.h>
namespace cutlass {
/// IEEE binary16 floating-point value
class half_t {
public:
half_t();
half_t(int); /// conversion from integer
half_t(float); /// conversion from fp32
half_t(double); /// conversion from fp64
static half_t bitcast(unsigned short); /// bitcast performs no conversion
static half_t convert(float const&); /// FP conversion - round toward nearest even
static float convert(unsigned short const&); /// floating point conversion to fp32
static half_t zero() { return bitcast(0); } /// +zero
static half_t one() { return bitcast(0x3c00); } /// one
static half_t nan() { return bitcast(0x7fff); } /// canonical not a number
static half_t inf() { return bitcast(0x7c00); } /// +infinity
static half_t ninf() { return bitcast(0xfc00); } /// -infinity
static half_t epsilon() { return bitcast(0x1400); } /// machine epsilon (2^-10)
bool signbit() const; /// sign bit - true: negative, false: positive
int exponent() const; /// unbiased exponent
unsigned short mantissa() const; /// mantissa bits
bool isfinite() const; /// true if neither inf nor nan
bool isinf() const; /// true if value is + or - infinity
bool isnan() const; /// true if value is not a number
bool isnormal() const; /// true if nonzero value is normalized
bool iszero() const; /// true if value is + or - zero
bool operator==(half_t const&) const;
bool operator!=(half_t const&) const;
bool operator==(float const&) const;
bool operator!=(float const&) const;
bool operator<(half_t const&) const;
bool operator<=(half_t const&) const;
bool operator>(half_t const&) const;
bool operator>=(half_t const&) const;
half_t operator+(half_t const&) const;
half_t operator-() const;
half_t operator-(half_t const&) const;
half_t operator*(half_t const&) const;
half_t operator/(half_t const&) const;
half_t& operator+=(half_t const&);
half_t& operator-=(half_t const&);
half_t& operator*=(half_t const&);
half_t& operator/=(half_t const&);
half_t& operator++();
half_t& operator--();
half_t operator++(int);
half_t operator--(int);
operator bool() const; /// false if zero
operator int() const; /// conversion to int
operator float() const; /// conversion to fp32
operator half() const; /// conversion to half
uint16_t& raw() { return x; }
uint16_t raw() const { return x; }
public:
/// data
unsigned short x;
};
/// Packed pair of half-precision elements
class half2_t {
public:
half2_t();
half2_t(half_t lo, half_t hi);
half2_t(std::pair<float, float> const&);
explicit half2_t(unsigned data);
half2_t operator+(half2_t const&) const;
half2_t operator-(half2_t const&) const;
half2_t operator*(half2_t const&) const;
half2_t operator/(half2_t const&) const;
half2_t& operator+=(half2_t const&);
half2_t& operator-=(half2_t const&);
half2_t& operator*=(half2_t const&);
half2_t& operator/=(half2_t const&);
float dot(half2_t const&) const; /// dot product with single-precision accumulation
float dot(half2_t const&, float) const; /// dot product plus addend, single-precision accumulation
half_t doth(half2_t const&) const; /// dot product accumulated in fp32, then rounded to half_t
half_t doth(half2_t const&, half_t) const; /// dot product plus addend, rounded to half_t
unsigned packed() const;
operator std::pair<float, float>() const;
operator unsigned() const;
public:
half_t lo;
half_t hi;
};
template <typename Dest, typename Src>
Dest bitcast(Src const&);
template <>
float bitcast<float, unsigned>(unsigned const&);
template <>
float bitcast<float, int>(int const&);
template <>
unsigned bitcast<unsigned, float>(float const&);
template <>
half_t bitcast<half_t, unsigned short>(unsigned short const&);
template <>
unsigned short bitcast<unsigned short, half_t>(half_t const&);
template <>
half bitcast<half, unsigned short>(unsigned short const&);
} // namespace cutlass
cutlass::half_t operator+(float, cutlass::half_t const&);
cutlass::half_t operator-(float, cutlass::half_t const&);
cutlass::half_t operator*(float, cutlass::half_t const&);
cutlass::half_t operator/(float, cutlass::half_t const&);
std::ostream& operator<<(std::ostream&, cutlass::half_t const&); /// writes a half_t
std::istream& operator>>(std::istream&, cutlass::half_t&); /// reads a half_t
#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {
/// lexical cast from string to half_t
template <>
cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg);
/// lexical cast from half_t to string
template <>
std::string lexical_cast<std::string>(cutlass::half_t const& arg);
} // namespace boost
#endif
#define HLF_MANT_DIG 11 // significand digits including the implicit bit (cf. FLT_MANT_DIG == 24)
namespace std {
cutlass::half_t abs(cutlass::half_t const&); /// absolute value
bool isnan(cutlass::half_t const&); /// true if argument is NaN
bool isfinite(cutlass::half_t const&); /// true if argument is neither NaN nor infinity
cutlass::half_t nanh(const char* = 0); /// returns a not-a-number
bool isinf(cutlass::half_t const&); /// returns true if argument is infinity (+ or -)
bool isnormal(cutlass::half_t const&); /// returns true if argument is normal (not zero, subnormal, inf, or NaN)
int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating-point value
bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive
cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t
/// Numeric limits
template <>
struct numeric_limits<cutlass::half_t> {
static bool const is_specialized = true;
static bool const is_signed = true;
static bool const is_integer = false;
static bool const is_exact = false;
static bool const has_infinity = true;
static bool const has_quiet_NaN = true;
static bool const has_signaling_NaN = false;
static std::float_denorm_style const has_denorm = std::denorm_present;
static bool const has_denorm_loss = true;
static std::float_round_style const round_style = std::round_to_nearest;
static bool const is_iec559 = false;
static bool const is_bounded = true;
static bool const is_modulo = false;
static int const digits = HLF_MANT_DIG;
/// Returns smallest positive normal value
static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0400); }
/// Returns most negative finite value
static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); }
/// Returns largest finite value
static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); }
/// Returns machine epsilon
static cutlass::half_t epsilon() { return cutlass::half_t::epsilon(); }
/// Returns largest possible rounding error (0.5)
static cutlass::half_t round_error() { return cutlass::half_t(0.5f); }
/// Returns positive infinity
static cutlass::half_t infinity() { return cutlass::half_t::inf(); }
/// Returns quiet NaN
static cutlass::half_t quiet_NaN() { return cutlass::half_t::nan(); }
/// Returns signaling NaN (no distinct encoding; aliases quiet NaN)
static cutlass::half_t signaling_NaN() { return cutlass::half_t::nan(); }
/// Returns smallest positive subnormal value
static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
};
} // namespace std
//
//
//
inline cutlass::half_t cutlass::half_t::bitcast(unsigned short _x) {
half_t h;
h.x = _x;
return h;
}
/// FP32 -> FP16 conversion - rounds to nearest even
inline cutlass::half_t cutlass::half_t::convert(float const& flt) {
// software implementation rounds toward nearest even
unsigned const& s = *reinterpret_cast<unsigned const*>(&flt);
uint16_t sign = uint16_t((s >> 16) & 0x8000);
int exp = int((s >> 23) & 0xff) - 127;
int mantissa = s & 0x7fffff;
uint16_t u = 0;
if ((s & 0x7fffffff) == 0) {
// sign-preserving zero
return cutlass::half_t::bitcast(sign);
}
if (exp > 15) {
if (exp == 128 && mantissa) {
// not a number
u = 0x7fff;
} else {
// overflow to infinity
u = sign | 0x7c00;
}
return cutlass::half_t::bitcast(u);
}
int sticky_bit = 0;
if (exp >= -14) {
// normal fp32 to normal fp16
exp = uint16_t(exp + uint16_t(15));
u = uint16_t(((exp & 0x1f) << 10));
u = uint16_t(u | (mantissa >> 13));
} else {
// normal single-precision to subnormal half_t-precision representation
int rshift = (-14 - exp);
if (rshift < 32) {
mantissa |= (1 << 23);
sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0);
mantissa = (mantissa >> rshift);
u = (uint16_t(mantissa >> 13) & 0x3ff);
} else {
mantissa = 0;
u = 0;
}
}
// round to nearest even
int round_bit = ((mantissa >> 12) & 1);
sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0);
if ((round_bit && sticky_bit) || (round_bit && (u & 1))) {
u = uint16_t(u + 1);
}
u |= sign;
return cutlass::half_t::bitcast(u);
}
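// Worked example of the round-to-nearest-even path above (illustrative):
// flt = 1.00048828125f (1 + 2^-11) has bits 0x3f801000: sign 0, exp 0, mantissa 0x001000
// u = (15 << 10) | (0x001000 >> 13) = 0x3c00; round_bit = 1, sticky_bit = 0
// Tie with even LSB (u & 1 == 0), so no increment: result 0x3c00 (exactly 1.0)
// flt = 1.0009765625f (1 + 2^-10) has mantissa 0x002000: u = 0x3c01, round_bit = 0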
inline float cutlass::half_t::convert(unsigned short const& h) {
int sign = ((h >> 15) & 1);
int exp = ((h >> 10) & 0x1f);
int mantissa = (h & 0x3ff);
unsigned f = 0;
if (exp > 0 && exp < 31) {
// normal
exp += 112;
f = (sign << 31) | (exp << 23) | (mantissa << 13);
} else if (exp == 0) {
if (mantissa) {
// subnormal
exp += 113;
while ((mantissa & (1 << 10)) == 0) {
mantissa <<= 1;
exp--;
}
mantissa &= 0x3ff;
f = (sign << 31) | (exp << 23) | (mantissa << 13);
} else {
// sign-preserving zero
f = (sign << 31);
}
} else if (exp == 31) {
if (mantissa) {
f = 0x7fffffff; // not a number
} else {
f = (0xff << 23) | (sign << 31); // inf
}
}
return *reinterpret_cast<float const*>(&f);
}
inline cutlass::half_t::half_t() {}
inline cutlass::half_t::half_t(int i) { x = convert(float(i)).x; }
inline cutlass::half_t::half_t(float f) { x = convert(f).x; }
inline cutlass::half_t::half_t(double d) { x = convert(float(d)).x; }
inline bool cutlass::half_t::signbit() const { return (x >> 15) & 1; }
inline int cutlass::half_t::exponent() const { return ((x >> 10) & 0x1f) - 15; }
inline unsigned short cutlass::half_t::mantissa() const { return x & 0x3ff; }
inline cutlass::half_t::operator bool() const { return (x & 0x7fff) != 0; }
inline cutlass::half_t::operator int() const { return static_cast<int>(convert(x)); }
inline cutlass::half_t::operator float() const { return convert(x); }
inline cutlass::half_t::operator half() const { return cutlass::bitcast<half, unsigned short>(x); }
inline bool cutlass::half_t::operator==(cutlass::half_t const& h) const {
// Compare as fp32 for IEEE semantics: +0 == -0 and NaN compares unequal
return float(*this) == float(h);
}
inline bool cutlass::half_t::operator!=(cutlass::half_t const& h) const { return !(*this == h); }
inline bool cutlass::half_t::operator==(float const& b) const { return *this == half_t(b); }
inline bool cutlass::half_t::operator!=(float const& b) const { return *this != half_t(b); }
inline bool cutlass::half_t::iszero() const { return (x & 0x7fff) == 0; }
inline bool cutlass::half_t::isfinite() const { return (exponent() < 16); }
inline bool cutlass::half_t::isnan() const {
int exp = ((x >> 10) & 0x1f);
if (exp == 0x1f) {
return (x & 0x3ff) != 0;
}
return false;
}
inline bool cutlass::half_t::isinf() const {
int exp = ((x >> 10) & 0x1f);
if (exp == 0x1f) {
return (x & 0x3ff) == 0;
}
return false;
}
inline bool cutlass::half_t::isnormal() const {
int exp = exponent();
return exp > -15 && exp < 16;
}
inline bool cutlass::half_t::operator<(half_t const& h) const {
// Sign-magnitude bit comparison misorders negative values, so compare as fp32
return float(*this) < float(h);
}
inline bool cutlass::half_t::operator<=(half_t const& h) const { return float(*this) <= float(h); }
inline bool cutlass::half_t::operator>(half_t const& h) const { return float(*this) > float(h); }
inline bool cutlass::half_t::operator>=(half_t const& h) const { return float(*this) >= float(h); }
inline cutlass::half_t cutlass::half_t::operator+(cutlass::half_t const& b) const {
return cutlass::half_t(float(*this) + float(b));
}
inline cutlass::half_t cutlass::half_t::operator-() const { return bitcast(x ^ 0x8000); }
inline cutlass::half_t cutlass::half_t::operator-(cutlass::half_t const& b) const {
return cutlass::half_t(float(*this) - float(b));
}
inline cutlass::half_t cutlass::half_t::operator*(cutlass::half_t const& b) const {
return cutlass::half_t(float(*this) * float(b));
}
inline cutlass::half_t cutlass::half_t::operator/(cutlass::half_t const& b) const {
return cutlass::half_t(float(*this) / float(b));
}
inline cutlass::half_t& cutlass::half_t::operator+=(cutlass::half_t const& b) {
*this = cutlass::half_t(float(*this) + float(b));
return *this;
}
inline cutlass::half_t& cutlass::half_t::operator-=(cutlass::half_t const& b) {
*this = cutlass::half_t(float(*this) - float(b));
return *this;
}
inline cutlass::half_t& cutlass::half_t::operator*=(cutlass::half_t const& b) {
*this = cutlass::half_t(float(*this) * float(b));
return *this;
}
inline cutlass::half_t& cutlass::half_t::operator/=(cutlass::half_t const& b) {
*this = cutlass::half_t(float(*this) / float(b));
return *this;
}
inline cutlass::half_t& cutlass::half_t::operator++() {
*this = cutlass::half_t(float(*this) + 1.0f);
return *this;
}
inline cutlass::half_t& cutlass::half_t::operator--() {
*this = cutlass::half_t(float(*this) - 1.0f);
return *this;
}
inline cutlass::half_t cutlass::half_t::operator++(int) {
half_t h = *this;
*this = cutlass::half_t(float(*this) + 1.0f);
return h;
}
inline cutlass::half_t cutlass::half_t::operator--(int) {
half_t h = *this;
*this = cutlass::half_t(float(*this) - 1.0f);
return h;
}
inline cutlass::half_t operator+(float a, cutlass::half_t const& b) {
return cutlass::half_t(a + float(b));
}
inline cutlass::half_t operator-(float a, cutlass::half_t const& b) {
return cutlass::half_t(a - float(b));
}
inline cutlass::half_t operator*(float a, cutlass::half_t const& b) {
return cutlass::half_t(a * float(b));
}
inline cutlass::half_t operator/(float a, cutlass::half_t const& b) {
return cutlass::half_t(a / float(b));
}
//
//
//
inline cutlass::half2_t::half2_t() {}
inline cutlass::half2_t::half2_t(half_t lo, half_t hi) : lo(lo), hi(hi) {}
inline cutlass::half2_t::half2_t(std::pair<float, float> const& p) : lo(p.first), hi(p.second) {}
inline cutlass::half2_t::half2_t(unsigned data)
: lo(half_t::bitcast(uint16_t(data & 0x0ffff))),
hi(half_t::bitcast(uint16_t((data >> 16) & 0x0ffff))) {}
inline cutlass::half2_t cutlass::half2_t::operator+(half2_t const& b) const {
return half2_t(lo + b.lo, hi + b.hi);
}
inline cutlass::half2_t cutlass::half2_t::operator-(half2_t const& b) const {
return half2_t(lo - b.lo, hi - b.hi);
}
inline cutlass::half2_t cutlass::half2_t::operator*(half2_t const& b) const {
return half2_t(lo * b.lo, hi * b.hi);
}
inline cutlass::half2_t cutlass::half2_t::operator/(half2_t const& b) const {
return half2_t(lo / b.lo, hi / b.hi);
}
inline cutlass::half2_t& cutlass::half2_t::operator+=(half2_t const& b) {
lo += b.lo;
hi += b.hi;
return *this;
}
inline cutlass::half2_t& cutlass::half2_t::operator-=(half2_t const& b) {
lo -= b.lo;
hi -= b.hi;
return *this;
}
inline cutlass::half2_t& cutlass::half2_t::operator*=(half2_t const& b) {
lo *= b.lo;
hi *= b.hi;
return *this;
}
inline cutlass::half2_t& cutlass::half2_t::operator/=(half2_t const& b) {
lo /= b.lo;
hi /= b.hi;
return *this;
}
inline float cutlass::half2_t::dot(half2_t const& b) const {
return float(lo) * float(b.lo) + float(hi) * float(b.hi);
}
inline float cutlass::half2_t::dot(half2_t const& b, float c) const { return c + dot(b); }
inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b) const {
return cutlass::half_t(dot(b));
}
inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b, half_t c) const {
return cutlass::half_t(dot(b, float(c)));
}
inline cutlass::half2_t::operator std::pair<float, float>() const {
return std::pair<float, float>(float(lo), float(hi));
}
inline unsigned cutlass::half2_t::packed() const { return (lo.x | (hi.x << 16)); }
inline cutlass::half2_t::operator unsigned() const { return packed(); }
//
//
//
template <>
inline float cutlass::bitcast<float, unsigned>(unsigned const& u) {
return *reinterpret_cast<float const*>(&u);
}
template <>
inline float cutlass::bitcast<float, int>(int const& i) {
return *reinterpret_cast<float const*>(&i);
}
template <>
inline unsigned cutlass::bitcast<unsigned, float>(float const& f) {
return *reinterpret_cast<unsigned const*>(&f);
}
template <>
inline cutlass::half_t cutlass::bitcast<cutlass::half_t, unsigned short>(unsigned short const& s) {
return *reinterpret_cast<cutlass::half_t const*>(&s);
}
template <>
inline unsigned short cutlass::bitcast<unsigned short, cutlass::half_t>(cutlass::half_t const& h) {
return *reinterpret_cast<unsigned short const*>(&h);
}
template <>
inline half cutlass::bitcast<half, unsigned short>(unsigned short const& s) {
return *reinterpret_cast<half const*>(&s);
}
//
// Lexical casts
//
#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {
template <>
cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg) {
return cutlass::half_t(boost::lexical_cast<float>(arg));
}
template <>
std::string lexical_cast<std::string>(cutlass::half_t const& arg) {
return boost::lexical_cast<std::string>(float(arg));
}
} // namespace boost
#endif
//
// Standard Library Operations
//
// std
namespace std {
inline cutlass::half_t abs(cutlass::half_t const& h) {
return cutlass::half_t::bitcast(h.x & 0x7fff);
}
inline bool isnan(cutlass::half_t const& h) { return h.isnan(); }
inline bool isfinite(cutlass::half_t const& h) { return h.isfinite(); }
inline cutlass::half_t nanh(const char*) { return cutlass::half_t::nan(); }
inline bool isinf(cutlass::half_t const& h) { return h.isinf(); }
inline bool isnormal(cutlass::half_t const& h) { return h.isnormal(); }
inline int fpclassify(cutlass::half_t const& h) {
int exp = h.exponent();
unsigned short mantissa = h.mantissa();
if (exp < -14) {
if (mantissa == 0) {
return FP_ZERO;
} else {
return FP_SUBNORMAL;
}
} else if (exp > 15) {
if (mantissa == 0) {
return FP_INFINITE;
} else {
return FP_NAN;
}
}
return FP_NORMAL;
}
inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }
inline cutlass::half_t sqrt(cutlass::half_t const& h) {
return cutlass::half_t(std::sqrt(float(h)));
}
} // namespace std
//
// Stream interactions
//
/// put to stream - half_t-precision types bitcast as unsigned shorts if base is hexadecimal
inline std::ostream& operator<<(std::ostream& out, cutlass::half_t const& h) {
if (out.flags() & std::ios::hex) {
return out << h.x;
} else {
return out << float(h);
}
}
/// read from stream - half_t-precision types parsed as unsigned shorts if base is hexadecimal
inline std::istream& operator>>(std::istream& in, cutlass::half_t& h) {
if (in.flags() & std::ios::hex) {
unsigned short u = 0;
in >> u;
h = cutlass::half_t::bitcast(u);
} else {
float f = 0;
in >> f;
h = cutlass::half_t(f);
}
return in;
}
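
A brief sanity-check sketch of half_t (hypothetical test, exercising only interfaces declared above):

#include <cassert>
#include <iostream>
#include <tools/util/half.h>

int main() {
cutlass::half_t h(1.5f); // exactly representable: 0x3e00
assert(h.raw() == 0x3e00);
assert(float(h) == 1.5f); // round trip is exact
assert(std::isnan(cutlass::half_t::nan()));
assert(!std::isfinite(cutlass::half_t::inf()));
std::cout << h << std::endl; // prints 1.5
std::cout << std::hex << h << std::endl; // prints raw bits: 3e00
return 0;
}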

tools/util/host_tensor.h Normal file

@@ -0,0 +1,362 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/*! \file
\brief Template class to perform computations on tensors and manage memory.
*/
#include <cutlass/cutlass.h>
#include <cutlass/matrix_traits.h>
#include <tools/util/device_memory.h>
#include <tools/util/host_tensor_view.h>
#include <tools/util/type_traits.h>
#include <vector>
namespace cutlass {
template <typename T, bool DeviceBacked_ = true>
class HostTensor : public HostTensorView<T> {
public:
/// Type used for device-side allocations
typedef typename TypeTraits<T>::device_type DeviceType;
/// Base class
typedef HostTensorView<T> Base;
/// If true, allocates device side memory
static bool const DeviceBacked = DeviceBacked_;
/// Rank of tensor
static int const Rank = Base::Rank;
/// Type used to compute the offset of an element to the base of a tensor
typedef typename Base::Offset_t Offset_t;
/// Tensor reference to host memory
typedef typename Base::TensorRef_t TensorRef_t;
/// Tensor reference to device memory
typedef TensorRef<DeviceType, TensorRef_t::Rank> DeviceTensorRef;
/// Tensor reference to constant device memory
typedef TensorRef<DeviceType const, TensorRef_t::Rank> ConstDeviceTensorRef;
/// Coordinate into tensor
typedef typename Base::Coord_t Coord_t;
private:
/// Host-side memory allocation
std::vector<T> host_;
/// Device-side memory
cutlass::device_memory::allocation<DeviceType> device_;
public:
//
// Device and Host Methods
//
/// Default constructor
HostTensor() {}
/// Constructs a HostTensor from stride and size
HostTensor(Coord_t const& _stride, Coord_t const& _size) { reset(_stride, _size); }
/// Constructs a HostTensor from size - infers packed strides with the last dimension innermost
HostTensor(Coord_t const& _size) {
Coord_t _stride = make_Coord(
_size.at(1) * _size.at(2) * _size.at(3), _size.at(2) * _size.at(3), _size.at(3), 1);
reset(_stride, _size);
}
/// Returns the number of elements needed to back vector
size_t capacity() const { return Base::capacity(); }
/// Returns true if the Tensor_view is bound to some memory
bool good() const { return Base::good(); }
/// Updates the reference and size of a Tensor_view object
void reset(Coord_t const& _stride, Coord_t const& _size) {
size_t _capacity = _size.at(0) * _stride.at(0);
DeviceType* _device_memory = nullptr;
if (DeviceBacked) {
_device_memory = cutlass::device_memory::allocate<DeviceType>(_capacity);
}
host_.clear();
host_.resize(_capacity);
device_.reset(_device_memory, _capacity);
Base::reset(TensorRef_t(host_.data(), _stride), _size);
}
/// Initializes the host tensor as a matrix
void resize_matrix(int rows, int columns, MatrixLayout::Kind layout) {
bool col_major = (layout == MatrixLayout::kColumnMajor);
int ldm = (col_major ? rows : columns);
Coord_t stride = make_Coord(rows * columns, col_major ? 1 : ldm, col_major ? ldm : 1, 1);
Coord_t size = make_Coord(1, rows, columns, 1);
reset(stride, size);
}
/// Simplifies resizing the host tensor
void resize(int elements) { resize_matrix(1, elements, MatrixLayout::kColumnMajor); }
/// Gets pointer to host data
T const* host_data() const { return &host_[0]; }
/// Gets pointer to host data
T* host_data() { return &host_[0]; }
/// Gets pointer to device data
DeviceType* device_data() const { return device_.get(); }
/// Copies data from device to host
void sync_host() {
if (DeviceBacked) {
device_memory::copy_to_host(
host_.data(), reinterpret_cast<T const*>(device_.get()), host_.size());
}
}
/// Copies data from host to device
void sync_device() {
if (DeviceBacked) {
device_memory::copy_to_device(
device_.get(), reinterpret_cast<DeviceType const*>(host_.data()), host_.size());
}
}
/// Copy data from a caller-supplied device pointer
void copy_to_host(DeviceType const *ptr_device) {
device_memory::copy_to_host(
host_.data(), reinterpret_cast<T const *>(ptr_device), host_.size());
}
/// Copies data to a caller-supplied device pointer
void copy_to_device(DeviceType *ptr_device) {
device_memory::copy_to_device(
ptr_device, reinterpret_cast<DeviceType const *>(host_.data()), host_.size());
}
/// Accesses the tensor reference pointing to data
TensorRef_t& host_ref() { return Base::ref(); }
/// Accesses the tensor reference pointing to data
TensorRef_t const& host_ref() const { return Base::ref(); }
/// Returns a tensor ref to memory on the device
DeviceTensorRef device_ref() const { return DeviceTensorRef(device_data(), stride()); }
/// Returns a tensor ref to constant memory on the device
ConstDeviceTensorRef const_device_ref() const {
return ConstDeviceTensorRef(device_data(), stride());
}
/// Accesses the size
Coord_t const& size() const { return Base::size(); }
/// Accesses the size of a specified dimension
int size(int dim) const { return Base::size(dim); }
/// Accesses the stride
Coord_t const& stride() const { return Base::stride(); }
/// Accesses the stride along a specified dimension
int stride(int dim) const { return Base::stride(dim); }
/// Returns the index of an element
Offset_t offset(Coord_t const& coord) const { return Base::offset(coord); }
/// Determines whether a location is within a tensor
bool contains(Coord_t const& coord) const { return Base::contains(coord); }
/// Element-wise accessor
T& at(Coord_t const& coord) const { return Base::at(coord); }
/// Element-wise accessor
T& operator[](Coord_t const& coord) { return at(coord); }
/// Element-wise accessor with basic offset
T& at(int idx) const { return Base::at(idx); }
/// Returns a Tensor_view given location and size quantities
TensorView<T> subview(Coord_t const& _location, Coord_t _size) const {
return Base::subview(_location, _size);
}
/// Recurses through all dimensions and applies a unary operation
template <typename F>
void elementwise_in_place(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
Base::elementwise_in_place(op, dim, dst_offset_base);
}
/// Recurses through all dimensions and applies a unary operator, supplying the logical
/// coordinate within the tensor as an argument
template <typename F>
void elementwise_stream(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
Base::elementwise_stream(op, dim, dst_offset_base);
}
/// Recurses through all dimensions and applies a unary operator, supplying the logical
/// coordinate within the tensor as an argument
template <typename F>
void elementwise_generate(F& op,
int dim = 0,
Offset_t dst_offset_base = 0,
Coord_t coord = Coord_t(0)) {
Base::elementwise_generate(op, dim, dst_offset_base, coord);
}
/// Recurses through all dimensions and applies a binary operation
template <typename Src, typename F>
bool elementwise_in_place(F& op,
int dim,
TensorView<Src> const& tensor,
Offset_t dst_offset_base = 0,
Offset_t src_offset_base = 0) {
return Base::elementwise_in_place(op, dim, tensor, dst_offset_base, src_offset_base);
}
/// Accumulate in place
template <typename Src>
TensorView<T>& operator+=(TensorView<Src> const& tensor) {
Base::operator+=(tensor);
sync_device();
return *this;
}
/// Subtract in place
template <typename Src>
TensorView<T>& operator-=(TensorView<Src> const& tensor) {
Base::operator-=(tensor);
sync_device();
return *this;
}
/// Multiply in place
template <typename Src>
TensorView<T>& operator*=(TensorView<Src> const& tensor) {
Base::operator*=(tensor);
sync_device();
return *this;
}
/// Divide in place
template <typename Src>
TensorView<T>& operator/=(TensorView<Src> const& tensor) {
Base::operator/=(tensor);
sync_device();
return *this;
}
/// equality with epsilon tolerance
bool equals(TensorView<T> const& tensor, T epsilon) const {
return Base::equals(tensor, epsilon);
}
/// equality with ulps tolerance
bool bit_equals(TensorView<T> const& tensor, long long ulps_threshold = 0) {
return Base::bit_equals(tensor, ulps_threshold);
}
/// Computes general matrix product among select dimensions of a tensor
/// Assumes:
/// D: number of independent GEMMs to compute
/// H: height of matrix
/// W: width of matrix
template <
/// Data type of A matrix elements
typename A,
/// Data type of B matrix elements
typename B,
/// Data type of "compute" type (i.e. accumulator)
typename Ctype,
/// Data type of scale factors
typename Stype>
void gemm(TensorView<A> const& tensor_a, TensorView<B> const& tensor_b, Stype alpha, Stype beta) {
Base::template gemm<A, B, Ctype, Stype>(tensor_a, tensor_b, alpha, beta);
}
/// Fills with random data
template <typename Gen>
void fill_random(Gen generator) {
Base::fill_random(generator);
sync_device();
}
/// Procedurally assigns elements
template <typename Gen>
void generate(Gen generator) {
Base::generate(generator);
sync_device();
}
/// Procedurally visits elements
template <typename Gen>
void visit(Gen& generator) const {
Base::visit(generator);
}
/// initializes with identity
void fill_identity() {
Base::fill_identity();
sync_device();
}
/// computes elements as a linear combination of their coordinates
void fill_linear(Coord_t v, T offset = T(0)) {
Base::fill_linear(v, offset);
sync_device();
}
/// fills elements sequentially: element i receives v * i + offset
void fill_sequential(T v = T(1), T offset = T(0)) {
Base::fill_sequential(v, offset);
sync_device();
}
/// fills with a value
void fill(T val = T(0)) {
Base::fill(val);
sync_device();
}
/// Copies from external data source and performs type conversion
template <typename Src>
void fill(TensorView<Src> const& tensor) {
Base::fill(tensor);
sync_device();
}
/// Computes the norm of the matrix in double-precision
double norm() const { return Base::norm(); }
};
} // namespace cutlass
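
For reference, a minimal usage sketch of HostTensor (hypothetical example; assumes a CUDA device is available):

#include <tools/util/host_tensor.h>

void host_tensor_example() {
// 128x64 column-major matrix backed by both host and device allocations
cutlass::HostTensor<float> tensor;
tensor.resize_matrix(128, 64, cutlass::MatrixLayout::kColumnMajor);
tensor.fill_sequential(); // initializes host data and syncs to device
// ... launch kernels reading/writing tensor.device_data() ...
tensor.sync_host(); // copy device results back for inspection
float x = tensor.at(cutlass::make_Coord(0, 2, 3, 0));
(void)x;
}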

tools/util/host_tensor_view.h Normal file

@@ -0,0 +1,542 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Host-side implementation of useful operations
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <stdint.h>
#include <cutlass/cutlass.h>
#include <cutlass/tensor_view.h>
#include <tools/util/type_traits.h>
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename SrcType, typename DstType>
struct Cast {
static inline DstType apply(SrcType src) { return static_cast<DstType>(src); }
};
template <>
struct Cast<float, int8_t> {
static inline int8_t apply(float src) {
return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
}
};
template <>
struct Cast<float, uint8_t> {
static inline uint8_t apply(float src) {
return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
class HostTensorView : public TensorView<T> {
public:
/// Base class
typedef TensorView<T> TensorView_t;
/// Convention: depth is the first dimension
static int const Dim_D = 0;
/// Convention: height is the second dimension
static int const Dim_H = 1;
/// Convention: width is the third dimension
static int const Dim_W = 2;
/// Convention: channel is the fourth dimension
static int const Dim_C = 3;
/// Rank of tensor
static int const Rank = TensorView_t::Rank;
/// Type used to compute the offset of an element to the base of a tensor
typedef typename TensorView_t::Offset_t Offset_t;
/// Reference and stride
typedef typename TensorView_t::TensorRef_t TensorRef_t;
/// Coordinate into tensor
typedef typename TensorView_t::Coord_t Coord_t;
public:
//
// Device and Host Methods
//
/// Default constructor
HostTensorView() {}
/// Constructs a Tensor_view from a TensorRef and size
HostTensorView(TensorRef_t const& _ref, Coord_t const& _size) : TensorView_t(_ref, _size) {}
/// Accesses the size
Coord_t const& size() const { return TensorView_t::size(); }
/// Accesses the size of a specified dimension
int size(int dim) const { return size().at(dim); }
/// Accesses the stride
Coord_t const& stride() const { return TensorView_t::stride(); }
/// Accesses the stride along a specified dimension
int stride(int dim) const { return stride().at(dim); }
/// Returns the number of scalar elements needed to store the tensor (dimension 0 outermost)
size_t capacity() const { return size_t(size(Dim_D)) * stride(Dim_D); }
/// Returns true if the Tensor_view is bound to some memory
bool good() const { return TensorView_t::good(); }
/// Updates the reference and size of a TensorView object
void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) {
return TensorView_t::reset(_ref, _size);
}
/// Accesses the tensor reference pointing to data
TensorRef_t& ref() { return TensorView_t::ref(); }
/// Accesses the tensor reference pointing to data
TensorRef_t const& ref() const { return TensorView_t::ref(); }
/// Assigns a tensor view
HostTensorView& operator=(TensorView_t const& _tensor) {
reset(_tensor.ref(), _tensor.size());
return *this;
}
/// Returns the index of an element
Offset_t offset(Coord_t const& coord) const { return TensorView_t::offset(coord); }
/// Determines whether a location is within a tensor
bool contains(Coord_t const& coord) const { return TensorView_t::contains(coord); }
/// Element-wise accessor
T& at(Coord_t const& coord) const { return TensorView_t::at(coord); }
/// Element-wise accessor
T& operator[](Coord_t const& coord) const { return at(coord); }
/// Accesses an element with a raw offset
T& at(int idx) const { return TensorView_t::at(idx); }
/// Accesses an element with a raw offset
T& operator[](int idx) const { return at(idx); }
/// Returns a Tensor_view given location and size quantities
TensorView_t subview(Coord_t const& location, Coord_t size) const {
return TensorView_t::subview(location, size);
}
/// Recurses through all dimensions and applies a unary operation in place
template <typename F>
void elementwise_in_place(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
Offset_t dst_offset = dst_offset_base;
for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
if (dim < Rank - 1) {
elementwise_in_place(op, dim + 1, dst_offset);
} else {
op(ref().data()[dst_offset]);
}
}
}
/// Recurses through all dimensions and applies a unary operator with no arguments
template <typename F>
void elementwise_stream(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
Offset_t dst_offset = dst_offset_base;
for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
if (dim < Rank - 1) {
elementwise_stream(op, dim + 1, dst_offset);
} else {
ref().data()[dst_offset] = op();
}
}
}
/// Recurses through all dimensions and applies a unary operator, supplying the logical
/// coordinate within the tensor as an argument
template <typename F>
void elementwise_generate(F& op,
int dim = 0,
Offset_t dst_offset_base = 0,
Coord_t coord = Coord_t(0)) {
Offset_t dst_offset = dst_offset_base;
for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
coord.at(dim) = idx;
if (dim < Rank - 1) {
elementwise_generate(op, dim + 1, dst_offset, coord);
} else {
ref().data()[dst_offset] = op(coord);
}
}
}
/// Recurses through all dimensions and applies a unary operator, supplying the logical
/// coordinate within the tensor as an argument
template <typename F>
void elementwise_visit(F& op,
int dim = 0,
Offset_t dst_offset_base = 0,
Coord_t coord = Coord_t(0)) const {
Offset_t dst_offset = dst_offset_base;
for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
coord.at(dim) = idx;
if (dim < Rank - 1) {
elementwise_visit(op, dim + 1, dst_offset, coord);
} else {
op(ref().data()[dst_offset], coord);
}
}
}
/// Recurses through all dimensions and applies a binary operation
template <typename Src, typename F>
bool elementwise_in_place(F& op,
TensorView<Src> const& tensor,
int dim = 0,
Offset_t dst_offset_base = 0,
Offset_t src_offset_base = 0) {
Offset_t dst_offset = dst_offset_base;
Offset_t src_offset = src_offset_base;
if (size().at(dim) != tensor.size().at(dim)) {
return false;
}
bool equal_size = true;
for (int idx = 0; idx < size(dim);
++idx, dst_offset += stride(dim), src_offset += tensor.stride(dim)) {
if (dim < Rank - 1) {
// Propagate size mismatches detected at deeper dimensions
equal_size = elementwise_in_place(op, tensor, dim + 1, dst_offset, src_offset) && equal_size;
} else {
op(data()[dst_offset], tensor.data()[src_offset]);
}
}
return equal_size;
}
template <typename Src>
struct LambdaBinaryAddition {
void operator()(T& a, Src b) const { a += T(b); }
};
template <typename Src>
struct LambdaBinarySubtraction {
void operator()(T& a, Src b) const { a -= T(b); }
};
template <typename Src>
struct LambdaBinaryMultiplication {
void operator()(T& a, Src b) const { a *= T(b); }
};
template <typename Src>
struct LambdaBinaryDivision {
void operator()(T& a, Src b) const { a /= T(b); }
};
/// Accumulate in place
template <typename Src>
TensorView<T>& operator+=(TensorView<Src> const& tensor) {
LambdaBinaryAddition<Src> op;
elementwise_in_place(op, tensor);
return *this;
}
/// Subtract in place
template <typename Src>
TensorView<T>& operator-=(TensorView<Src> const& tensor) {
LambdaBinarySubtraction<Src> op;
elementwise_in_place(op, tensor);
return *this;
}
/// Multiply in place
template <typename Src>
TensorView<T>& operator*=(TensorView<Src> const& tensor) {
LambdaBinaryMultiplication<Src> op;
elementwise_in_place(op, tensor);
return *this;
}
/// Divide in place
template <typename Src>
TensorView<T>& operator/=(TensorView<Src> const& tensor) {
LambdaBinaryDivision<Src> op;
elementwise_in_place(op, tensor);
return *this;
}
/// Comparison operator
struct EqualsOperator {
bool equal;
T eps;
EqualsOperator(T _epsilon) : equal(true), eps(_epsilon) {}
void operator()(T a, T b) {
if (std::abs(T(a - b)) > eps * std::max(std::abs(a), std::abs(b))) {
equal = false;
}
}
};
/// equality with epsilon tolerance
bool equals(TensorView<T> const& tensor, T epsilon) const {
EqualsOperator comparison_op(epsilon);
bool equal_size = elementwise_in_place(comparison_op, tensor);
return equal_size && comparison_op.equal;
}
/// Compares two values which are smaller or equal to a long long int
struct BitEqualsOperator {
bool equal;
long long eps;
uint64_t index;
BitEqualsOperator(long long _ulps_threshold) : equal(true), eps(_ulps_threshold), index(0) {}
void operator()(T a, T b) {
// convert bits to integers
long long bits_a = 0;
long long bits_b = 0;
*reinterpret_cast<T*>(&bits_a) = TypeTraits<T>::remove_negative_zero(a);
*reinterpret_cast<T*>(&bits_b) = TypeTraits<T>::remove_negative_zero(b);
// compute diff
long long ulps = bits_a - bits_b;
if (std::abs(ulps) > eps) {
equal = false;
}
index++;
}
};
/// equality with ulps tolerance
bool bit_equals(TensorView<T> const& tensor, long long ulps_threshold = 0) {
BitEqualsOperator comparison_op(ulps_threshold);
bool equal_size = elementwise_in_place(comparison_op, tensor);
return equal_size && comparison_op.equal;
}
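// Usage sketch: comparing a computed tensor against a reference (names illustrative):
//
//   bool close = view_test.equals(view_ref, T(1e-5));  // relative-epsilon comparison
//   bool exact = view_test.bit_equals(view_ref);       // bit-for-bit, with -0 == +0
//   bool ulps4 = view_test.bit_equals(view_ref, 4);    // equal within 4 ULPs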
/// Gets naked pointer to data
T* data() const { return TensorView_t::data(); }
/// Computes a general matrix product among select dimensions of a rank-4 tensor.
/// Quietly returns without writing if the extents of the operands are incompatible.
/// Assumes the (D, H, W, C) layout:
///   D: number of independent GEMMs to compute
///   H: height (rows) of each matrix
///   W: width (columns) of each matrix
///   C: "channels" of each element, reduced together with the W dimension of A
template <typename A, typename B, typename Ctype, typename Stype>
void gemm(TensorView<A> const& tensor_a, TensorView<B> const& tensor_b, Stype alpha, Stype beta) {
int const Batch = size(Dim_D);
int const M = size(Dim_H);
int const N = size(Dim_W);
int const K = tensor_a.size(Dim_W);
int const C = tensor_a.size(Dim_C);
// Sizes must match
if (tensor_a.size(Dim_H) != M || tensor_b.size(Dim_W) != N || tensor_b.size(Dim_C) != C ||
tensor_b.size(Dim_H) != K) {
return;
}
int const Mblock = 32;
int const Nblock = 32;
for (int batch = 0; batch < Batch; ++batch) {
for (int row_block = 0; row_block < M; row_block += Mblock) {
for (int col_block = 0; col_block < N; col_block += Nblock) {
Ctype accum[Mblock][Nblock];
for (int j = 0; j < Nblock; j++) {
for (int i = 0; i < Mblock; i++) {
accum[i][j] = Ctype(0);
}
}
for (int k = 0; k < K; ++k) {
for (int j = 0; j < Nblock; j++) {
for (int i = 0; i < Mblock; i++) {
int row = row_block + i;
int col = col_block + j;
if (row < M && col < N) {
// accumulate over the K dimension and over channels
for (int channel = 0; channel < C; ++channel) {
Ctype a(tensor_a.at(make_Coord(batch, row, k, channel)));
Ctype b(tensor_b.at(make_Coord(batch, k, col, channel)));
accum[i][j] += a * b;
}
}
}
}
}
for (int j = 0; j < Nblock; j++) {
for (int i = 0; i < Mblock; i++) {
int row = row_block + i;
int col = col_block + j;
Coord_t coord = make_Coord(batch, row, col, 0);
if (row < M && col < N) {
at(coord) =
Cast<Stype, T>::apply(alpha * Stype(accum[i][j]) + beta * Stype(at(coord)));
}
}
}
}
}
}
}
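// Usage sketch: computes C = alpha * A * B + beta * C on rank-4 (D, H, W, C) views
// (names illustrative; the accumulator type Ctype must be given explicitly since it
// cannot be deduced from the arguments):
//
//   view_c.gemm<float, float, double>(view_a, view_b, 1.f, 0.f);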
/// Fills with random data
template <typename Gen>
void fill_random(Gen generator) {
elementwise_stream(generator);
}
/// Procedurally assigns elements
template <typename Gen>
void generate(Gen generator) {
elementwise_generate(generator);
}
/// Procedurally visits elements
template <typename Gen>
void visit(Gen& generator) const {
elementwise_visit(generator);
}
/// Generator to fill a tensor with the identity matrix
struct LambdaFillIdentity {
T operator()(Coord_t const& coord) { return (coord.at(1) == coord.at(2) ? T(1) : T(0)); }
};
/// initializes with identity
void fill_identity() {
LambdaFillIdentity op;
elementwise_generate(op);
}
/// Lambda for fill_linear()
struct LambdaFillLinear {
Coord_t v_;
T offset_;
LambdaFillLinear(Coord_t const& _v, T _offset) : v_(_v), offset_(_offset) {}
T operator()(Coord_t const& coord) { return T(v_.template dot<int>(coord)) + offset_; }
};
/// computes elements as a linear combination of their coordinates
void fill_linear(Coord_t v, T offset = T(0)) {
LambdaFillLinear lambda(v, offset);
elementwise_generate(lambda);
}
/// Fills elements sequentially in memory order: element i receives T(i) * v + offset
void fill_sequential(T v = T(1), T offset = T(0)) {
int const count = size().count();
for (int i = 0; i < count; ++i) {
data()[i] = T(i) * v + offset;
}
}
/// Returns a constant value
struct LambdaFillValue {
T value;
LambdaFillValue(T _value) : value(_value) {}
T operator()() { return value; }
};
/// fills with a value
void fill(T val = T(0)) {
LambdaFillValue op(val);
elementwise_stream(op);
}
/// Conversion from Src to T
template <typename Src>
struct LambdaAssign {
void operator()(T& a, Src b) const { a = T(b); }
};
/// copies from external data source and performs type conversion
template <typename Src>
void fill(TensorView<Src> const& tensor) {
LambdaAssign<Src> op;
elementwise_in_place(op, tensor);
}
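// Usage sketch: common host-side initialization patterns (names illustrative):
//
//   view.fill();                               // zero-fill
//   view.fill(T(1));                           // constant fill
//   view.fill_identity();                      // 1 on the H == W diagonal of each slice
//   view.fill_sequential();                    // 0, 1, 2, ... in memory order
//   view.fill_linear(make_Coord(0, 2, 1, 0));  // element at (d, h, w, c) = 2*h + w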
/// Accumulates the squared magnitude of each element
struct LambdaNorm {
double sum;
LambdaNorm() : sum(0) {}
void operator()(T const& element, Coord_t const&) {
double value(element);
double conj(element); // TODO - conjugates for complex
sum += value * conj;
}
};
/// Computes the Frobenius norm of the tensor in double precision
double norm() const {
LambdaNorm op;
elementwise_visit(op);
return std::sqrt(op.sum);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -0,0 +1,61 @@
/***************************************************************************************************
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cutlass/core_io.h>
#include <cutlass/tensor_view.h>
template <typename T>
inline std::ostream& tensor_view_output(std::ostream& out, T t) {
out << t;
return out;
}
template <>
inline std::ostream& tensor_view_output<int8_t>(std::ostream& out, int8_t t) {
out << int(t);
return out;
}
template <typename T>
inline std::ostream& operator<<(std::ostream& out, cutlass::TensorView<T> const& tensor) {
for (int batch = 0; batch < tensor.size(0); ++batch) {
out << "[\n ";
for (int h = 0; h < tensor.size(1); ++h) {
for (int w = 0; w < tensor.size(2); ++w) {
for (int c = 0; c < tensor.size(3); ++c) {
out << ((c || w) ? ", " : "");  // separator before all but the first element of a row
tensor_view_output(out, tensor.at(cutlass::make_Coord(batch, h, w, c)));
}
}
if (h + 1 < tensor.size(1)) {
out << " ;\n ";
}
}
out << " ]";
}
return out;
}
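// Usage sketch: streaming a rank-4 (D, H, W, C) view bound to host-accessible
// memory (name and factory are illustrative); rows within a batch are separated
// by ';' and each batch is bracketed:
//
//   cutlass::TensorView<float> view = make_view();  // hypothetical factory
//   std::cout << view << std::endl;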

161
tools/util/type_traits.h Normal file
View File

@ -0,0 +1,161 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Type traits for common CUDA types
*/
#pragma once
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include "half.h"
namespace cutlass {
struct half_t;
template <typename T>
struct TypeTraits;
template <>
struct TypeTraits<int8_t> {
static cudaDataType_t const cublas_type = CUDA_R_8I;
typedef int8_t host_type;
typedef int8_t device_type;
typedef int8_t integer_type;
typedef uint8_t unsigned_type;
static inline int8_t remove_negative_zero(int8_t x) { return x; }
static inline int to_print(int8_t x) { return (int)x; }
};
template <>
struct TypeTraits<uint8_t> {
static cudaDataType_t const cublas_type = CUDA_R_8U;  // 8-bit unsigned integer
typedef uint8_t host_type;
typedef uint8_t device_type;
typedef uint8_t integer_type;
typedef uint8_t unsigned_type;
static inline uint8_t remove_negative_zero(uint8_t x) { return x; }
static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; }
};
template <>
struct TypeTraits<int> {
static cudaDataType_t const cublas_type = CUDA_R_32I;
typedef int host_type;
typedef int device_type;
typedef int32_t integer_type;
typedef uint32_t unsigned_type;
static inline int32_t remove_negative_zero(int32_t x) { return x; }
static inline int to_print(int x) { return x; }
};
template <>
struct TypeTraits<unsigned> {
static cudaDataType_t const cublas_type = CUDA_R_32U;  // 32-bit unsigned integer
typedef unsigned host_type;
typedef unsigned device_type;
typedef uint32_t integer_type;
typedef uint32_t unsigned_type;
static inline uint32_t remove_negative_zero(uint32_t x) { return x; }
static inline uint32_t to_print(uint32_t x) { return x; }
};
template <>
struct TypeTraits<half> {
static cudaDataType_t const cublas_type = CUDA_R_16F;
typedef half_t host_type;
typedef half device_type;
typedef int16_t integer_type;
typedef uint16_t unsigned_type;
static inline half remove_negative_zero(half x) {
// compare through the unsigned view; a signed 16-bit load of the negative-zero
// pattern sign-extends and would never equal the literal 0x8000
unsigned_type h_bits = reinterpret_cast<unsigned_type const&>(x);
if (h_bits == 0x8000) {
h_bits = 0;
}
x = reinterpret_cast<half const&>(h_bits);
return x;
}
static inline half to_print(half x) { return x; }
};
template <>
struct TypeTraits<int64_t> {
static cudaDataType_t const cublas_type = CUDA_R_8I;  // placeholder: no 64-bit integer cudaDataType_t is available
typedef int64_t host_type;
typedef int64_t device_type;
typedef int64_t integer_type;
typedef uint64_t unsigned_type;
static inline int64_t remove_negative_zero(int64_t x) { return x; }
static inline int64_t to_print(int64_t x) { return x; }
};
template <>
struct TypeTraits<uint64_t> {
static cudaDataType_t const cublas_type = CUDA_R_8I;  // placeholder: no 64-bit integer cudaDataType_t is available
typedef uint64_t host_type;
typedef uint64_t device_type;
typedef uint64_t integer_type;
typedef uint64_t unsigned_type;
static inline uint64_t remove_negative_zero(uint64_t x) { return x; }
static inline uint64_t to_print(uint64_t x) { return x; }
};
template <>
struct TypeTraits<cutlass::half_t> {
static cudaDataType_t const cublas_type = CUDA_R_16F;
typedef half_t host_type;
typedef half device_type;
typedef int16_t integer_type;
typedef uint16_t unsigned_type;
static inline half_t remove_negative_zero(half_t x) {
return (x.raw() == 0x8000 ? half_t::bitcast(0) : x);
}
static inline half_t to_print(half_t x) { return x; }
};
template <>
struct TypeTraits<float> {
static cudaDataType_t const cublas_type = CUDA_R_32F;
typedef float host_type;
typedef float device_type;
typedef int32_t integer_type;
typedef uint32_t unsigned_type;
static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; }
static inline float to_print(float x) { return x; }
};
template <>
struct TypeTraits<double> {
static cudaDataType_t const cublas_type = CUDA_R_64F;
typedef double host_type;
typedef double device_type;
typedef int64_t integer_type;
typedef uint64_t unsigned_type;
static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; }
static inline double to_print(double x) { return x; }
};
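// Usage sketch: TypeTraits maps an element type to its cuBLAS data type enum and to
// same-sized integer types for inspecting bit patterns:
//
//   cudaDataType_t dt = cutlass::TypeTraits<float>::cublas_type;        // CUDA_R_32F
//   typedef cutlass::TypeTraits<float>::unsigned_type Bits;             // uint32_t
//   float x = cutlass::TypeTraits<float>::remove_negative_zero(-0.0f);  // +0.0f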
} // namespace cutlass