CUTLASS v1.0 release
tools/util/command_line.h (new file, 254 lines)
@@ -0,0 +1,254 @@
/******************************************************************************
 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are not permitted.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief Utility for parsing command line arguments
 */

#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

#include <cuda_runtime.h>

namespace cutlass {

/******************************************************************************
 * command_line
 ******************************************************************************/

/**
 * Utility for parsing command line arguments
 */
struct CommandLine {
  std::vector<std::string> keys;
  std::vector<std::string> values;
  std::vector<std::string> args;

  /**
   * Constructor
   */
  CommandLine(int argc, const char** argv) {
    using namespace std;

    // Reserve capacity rather than constructing placeholder elements;
    // pre-sizing the vectors would leave empty entries that are scanned on
    // every lookup and counted by parsed_argc().
    keys.reserve(10);
    values.reserve(10);

    for (int i = 1; i < argc; i++) {
      string arg = argv[i];

      if ((arg.size() < 2) || (arg[0] != '-') || (arg[1] != '-')) {
        args.push_back(arg);
        continue;
      }

      string::size_type pos;
      string key, val;
      if ((pos = arg.find('=')) == string::npos) {
        key = string(arg, 2, arg.length() - 2);
        val = "";
      } else {
        key = string(arg, 2, pos - 2);
        val = string(arg, pos + 1, arg.length() - (pos + 1));
      }

      keys.push_back(key);
      values.push_back(val);
    }
  }

  /**
   * Checks whether a flag "--<flag>" is present in the commandline
   */
  bool check_cmd_line_flag(const char* arg_name) const {
    using namespace std;

    for (int i = 0; i < int(keys.size()); ++i) {
      if (keys[i] == string(arg_name)) return true;
    }
    return false;
  }

  /**
   * Returns number of naked (non-flag and non-key-value) commandline parameters
   */
  int num_naked_args() const {
    return int(args.size());
  }

  /**
   * Returns the naked commandline parameter for a given index (not including flags)
   */
  template <typename value_t>
  void get_cmd_line_argument(int index, value_t& val) const {
    using namespace std;
    if (index < int(args.size())) {
      istringstream str_stream(args[index]);
      str_stream >> val;
    }
  }

  /**
   * Obtains the boolean value specified for a given commandline parameter
   * --<flag>=<bool>; "0" and "false" parse as false, anything else
   * (including a bare --<flag>) parses as true
   */
  void get_cmd_line_argument(const char* arg_name, bool& val, bool _default = true) const {
    val = _default;
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);

      val = !(value == "0" || value == "false");
    }
  }

  /**
   * Returns the value specified for a given commandline parameter --<flag>=<value>
   */
  template <typename value_t>
  void get_cmd_line_argument(const char* arg_name,
                             value_t& val,
                             value_t const& _default = value_t()) const {
    using namespace std;

    val = _default;

    for (int i = 0; i < int(keys.size()); ++i) {
      if (keys[i] == string(arg_name)) {
        istringstream str_stream(values[i]);
        str_stream >> val;
      }
    }
  }

  /**
   * Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
   */
  template <typename value_t>
  void get_cmd_line_arguments(const char* arg_name,
                              std::vector<value_t>& vals,
                              char sep = ',') const {
    using namespace std;

    if (check_cmd_line_flag(arg_name)) {
      // Clear any default values
      vals.clear();

      // Recover from multi-value string
      for (int i = 0; i < int(keys.size()); ++i) {
        if (keys[i] == string(arg_name)) {
          string val_string(values[i]);
          istringstream str_stream(val_string);
          string::size_type old_pos = 0;
          string::size_type new_pos = 0;

          // Iterate <sep>-delimited values
          value_t val;
          while ((new_pos = val_string.find(sep, old_pos)) != string::npos) {
            if (new_pos != old_pos) {
              str_stream.width(new_pos - old_pos);
              str_stream >> val;
              vals.push_back(val);
            }

            // skip over delimiter
            str_stream.ignore(1);
            old_pos = new_pos + 1;
          }

          // Read last value
          str_stream >> val;
          vals.push_back(val);
        }
      }
    }
  }

  /**
   * Returns the values specified for a given commandline parameter
   * --<flag>=<key:value>,<key:value>*
   */
  void get_cmd_line_argument_pairs(const char* arg_name,
                                   std::vector<std::pair<std::string, std::string> >& tokens,
                                   char delim = ',',
                                   char sep = ':') const {
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);

      tokenize(tokens, value, delim, sep);
    }
  }

  /**
   * Returns the number of --<key>[=<value>] pairs parsed
   */
  int parsed_argc() const { return (int)keys.size(); }

  //-------------------------------------------------------------------------
  // Utility functions
  //-------------------------------------------------------------------------

  /// Tokenizes a comma-delimited list of string pairs delimited by ':'
  static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    // Home-built to avoid Boost dependency
    size_t s_idx = 0;
    size_t d_idx = std::string::npos;
    while (s_idx < str.size()) {
      d_idx = str.find_first_of(delim, s_idx);

      size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
      size_t sep_idx = str.find_first_of(sep, s_idx);
      size_t offset = 1;
      if (sep_idx == std::string::npos || sep_idx >= end_idx) {
        sep_idx = end_idx;
        offset = 0;
      }

      std::pair<std::string, std::string> item(
          str.substr(s_idx, sep_idx - s_idx),
          str.substr(sep_idx + offset, end_idx - sep_idx - offset));

      tokens.push_back(item);
      s_idx = end_idx + 1;
    }
  }

  /// Tokenizes a comma-delimited list of strings, keeping only the portion of
  /// each token preceding any ':' separator
  static void tokenize(std::vector<std::string>& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    typedef std::vector<std::pair<std::string, std::string> > TokenVector;
    typedef TokenVector::const_iterator token_iterator;

    std::vector<std::pair<std::string, std::string> > token_pairs;
    tokenize(token_pairs, str, delim, sep);
    for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
      tokens.push_back(tok->first);
    }
  }
};

} // namespace cutlass
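For reference, a minimal sketch of how CommandLine is driven (illustrative only, not part of the commit; the program name, flag names, and defaults are hypothetical):

// cmd_demo.cu (hypothetical) -- exercise CommandLine with:
//   ./cmd_demo --m=128 --alpha=2.5 --kernels=sgemm,dgemm --verbose input.bin
#include <iostream>
#include <string>
#include <vector>

#include <tools/util/command_line.h>

int main(int argc, const char** argv) {
  cutlass::CommandLine cmd(argc, argv);

  int m;
  float alpha;
  cmd.get_cmd_line_argument("m", m, 64);           // --m=128, default 64
  cmd.get_cmd_line_argument("alpha", alpha, 1.0f); // --alpha=2.5

  std::vector<std::string> kernels;
  cmd.get_cmd_line_arguments("kernels", kernels);  // --kernels=sgemm,dgemm

  bool verbose;
  cmd.get_cmd_line_argument("verbose", verbose, false); // bare --verbose => true

  std::string input_file;
  cmd.get_cmd_line_argument(0, input_file);        // first naked argument

  std::cout << "m=" << m << " alpha=" << alpha
            << " kernels=" << kernels.size()
            << " verbose=" << verbose
            << " input=" << input_file << std::endl;
  return 0;
}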
tools/util/device_memory.h (new file, 178 lines)
@@ -0,0 +1,178 @@
/******************************************************************************
 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are not permitted.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief C++ interface to CUDA device memory management functions.
 */

#include <memory>

#include <cutlass/util/debug.h>
#include <cutlass/util/platform.h>
#include <tools/util/exceptions.h>

namespace cutlass {
namespace device_memory {

/******************************************************************************
 * Allocation lifetime
 ******************************************************************************/

/// Allocate a buffer of \p count elements of type \p T on the current CUDA device
template <typename T>
T* allocate(size_t count = 1) {
  T* ptr = 0;
  size_t bytes = sizeof(T) * count;

  cudaError_t cuda_error = CUDA_PERROR(cudaMalloc((void**)&ptr, bytes));
  if (cuda_error != cudaSuccess) {
    throw cuda_exception("Failed to allocate memory", cuda_error);
  }

  return ptr;
}

/// Free the buffer pointed to by \p ptr
template <typename T>
void free(T* ptr) {
  if (ptr) {
    cudaError_t cuda_error = CUDA_PERROR(cudaFree(ptr));
    if (cuda_error != cudaSuccess) {
      throw cuda_exception("Failed to free device memory", cuda_error);
    }
  }
}

/******************************************************************************
 * Data movement
 ******************************************************************************/

/// Copies \p count elements of type \p T in the direction given by \p kind
template <typename T>
void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) {
  size_t bytes = count * sizeof(T);

  cudaError_t cuda_error = CUDA_PERROR(cudaMemcpy(dst, src, bytes, kind));
  if (cuda_error != cudaSuccess) {
    throw cuda_exception("cudaMemcpy() failed", cuda_error);
  }
}

/// Copies \p count elements from host to device memory
template <typename T>
void copy_to_device(T* dst, T const* src, size_t count = 1) {
  copy(dst, src, count, cudaMemcpyHostToDevice);
}

/// Copies \p count elements from device to host memory
template <typename T>
void copy_to_host(T* dst, T const* src, size_t count = 1) {
  copy(dst, src, count, cudaMemcpyDeviceToHost);
}

/// Copies \p count elements between two device buffers
template <typename T>
void copy_device_to_device(T* dst, T const* src, size_t count = 1) {
  copy(dst, src, count, cudaMemcpyDeviceToDevice);
}

/// Copies elements from device memory to host-side range
template <typename OutputIterator, typename T>
void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) {
  size_t elements = end - begin;
  copy_to_host(&*begin, device_begin, elements);
}

/// Copies elements to device memory from host-side range
template <typename T, typename InputIterator>
void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) {
  size_t elements = end - begin;
  copy_to_device(device_begin, &*begin, elements);
}

/******************************************************************************
 * "Smart" device memory allocation
 ******************************************************************************/

/// Device allocation abstraction that tracks size and capacity
template <typename T>
struct allocation {
  /// Delete functor for CUDA device memory
  struct deleter {
    void operator()(T* ptr) {
      cudaError_t cuda_error = CUDA_PERROR(cudaFree(ptr));
      if (cuda_error != cudaSuccess) {
        // noexcept
        // throw cuda_exception("cudaFree() failed", cuda_error);
        return;
      }
    }
  };

  /// Number of elements of T allocated on the current CUDA device
  size_t capacity;

  /// Smart pointer
  platform::unique_ptr<T, deleter> smart_ptr;

  //
  // Methods
  //

  /// Constructor: allocates no memory
  allocation() : capacity(0) {}

  /// Constructor: allocates \p capacity elements on the current CUDA device
  /// (initializer order matches the declaration order of the members)
  allocation(size_t _capacity) : capacity(_capacity), smart_ptr(allocate<T>(_capacity)) {}

  /// Destructor
  ~allocation() { reset(); }

  /// Returns a pointer to the managed object
  T* get() const { return smart_ptr.get(); }

  /// Releases the ownership of the managed object (without deleting) and resets capacity to zero
  T* release() {
    capacity = 0;
    return smart_ptr.release();
  }

  /// Deletes the managed object and resets capacity to zero
  void reset() {
    capacity = 0;
    smart_ptr.reset();
  }

  /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity
  void reset(T* _ptr, size_t _capacity) {
    smart_ptr.reset(_ptr);
    capacity = _capacity;
  }

  /// Returns a pointer to the object owned by *this
  T* operator->() const { return smart_ptr.get(); }

  /// Returns the deleter object which would be used for destruction of the managed object
  deleter& get_deleter() { return smart_ptr.get_deleter(); }

  /// Returns the deleter object which would be used for destruction of the managed object (const)
  const deleter& get_deleter() const { return smart_ptr.get_deleter(); }
};

} // namespace device_memory
} // namespace cutlass
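For reference, a short sketch showing how allocation and the copy helpers compose (illustrative only, not part of the commit; assumes a CUDA-capable device is present):

#include <vector>

#include <tools/util/device_memory.h>

void device_memory_demo() {
  using namespace cutlass;

  std::vector<float> host(256, 1.0f);

  // RAII allocation: freed automatically when 'buffer' leaves scope
  device_memory::allocation<float> buffer(host.size());

  device_memory::copy_to_device(buffer.get(), host.data(), host.size());
  device_memory::copy_to_host(host.data(), buffer.get(), host.size());

  // Iterator-range variant
  device_memory::insert_to_host(host.begin(), host.end(), buffer.get());
}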
tools/util/exceptions.h (new file, 62 lines)
@@ -0,0 +1,62 @@
/******************************************************************************
 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are not permitted.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief C++ exception semantics for CUDA error codes
 */

#include <cuda_runtime.h>
#include <ostream>
#include <stdexcept>

#include <cutlass/util/platform.h>

namespace cutlass {

/// C++ exception wrapper for CUDA \p cudaError_t
class cuda_exception : public std::exception {
 public:
  /// Constructor
  cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {}

  /// Returns the explanatory string; without this override, the
  /// std::exception default would ignore \p msg entirely
  virtual const char* what() const throw() { return msg; }

  /// Returns the underlying CUDA \p cudaError_t
  cudaError_t cudaError() const { return err; }

 protected:
  /// Explanatory string
  const char* msg;

  /// Underlying CUDA \p cudaError_t
  cudaError_t err;
};

/// Writes a cudaError_t to an output stream
inline std::ostream& operator<<(std::ostream& out, cudaError_t result) {
  return out << cudaGetErrorString(result);
}

/// Writes a cuda_exception instance to an output stream
inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) {
  return out << e.what() << ": " << e.cudaError();
}

} // namespace cutlass
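For reference, a sketch of the intended exception flow (illustrative only, not part of the commit; the oversized request is just a convenient way to force cudaMalloc to fail):

#include <iostream>

#include <tools/util/device_memory.h>
#include <tools/util/exceptions.h>

void exceptions_demo() {
  try {
    // Deliberately oversized request to provoke an allocation failure
    float* ptr = cutlass::device_memory::allocate<float>(~size_t(0) / sizeof(float));
    cutlass::device_memory::free(ptr);
  } catch (cutlass::cuda_exception const& e) {
    // operator<< prints "<message>: <cudaGetErrorString(err)>"
    std::cerr << e << std::endl;
  }
}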
tools/util/half.h (new file, 743 lines)
@@ -0,0 +1,743 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief Host-side implementation of half-precision float
*/

#pragma once

#include <stdint.h>
#include <cmath>
#include <limits>
#include <utility>

#include <iomanip>
#include <istream>
#include <ostream>

#include <cuda_fp16.h>

namespace cutlass {

/// IEEE binary16 floating-point value
class half_t {
 public:
  half_t();
  half_t(int);    /// conversion from integer
  half_t(float);  /// conversion from fp32
  half_t(double); /// conversion from fp64

  static half_t bitcast(unsigned short); /// bitcast performs no conversion

  static half_t convert(float const&);         /// FP conversion - round toward nearest even
  static float convert(unsigned short const&); /// floating point conversion to fp32

  static half_t zero() { return bitcast(0); }         /// +zero
  static half_t one() { return bitcast(0x3c00); }     /// one
  static half_t nan() { return bitcast(0x7fff); }     /// canonical not a number
  static half_t inf() { return bitcast(0x7c00); }     /// +infinity
  static half_t ninf() { return bitcast(0xfc00); }    /// -infinity
  static half_t epsilon() { return bitcast(0x1400); } /// Machine epsilon, 2^-10 (0x1000 would encode 2^-11)

  bool signbit() const;            /// sign bit - true: negative, false: positive
  int exponent() const;            /// unbiased exponent
  unsigned short mantissa() const; /// mantissa bits

  bool isfinite() const; /// true if neither inf nor nan
  bool isinf() const;    /// true if value is + or - infinity
  bool isnan() const;    /// true if value is not a number
  bool isnormal() const; /// true if nonzero value is normalized
  bool iszero() const;   /// true if value is + or - zero

  bool operator==(half_t const&) const;
  bool operator!=(half_t const&) const;
  bool operator==(float const&) const;
  bool operator!=(float const&) const;

  bool operator<(half_t const&) const;
  bool operator<=(half_t const&) const;
  bool operator>(half_t const&) const;
  bool operator>=(half_t const&) const;

  half_t operator+(half_t const&) const;
  half_t operator-() const;
  half_t operator-(half_t const&) const;
  half_t operator*(half_t const&) const;
  half_t operator/(half_t const&) const;

  half_t& operator+=(half_t const&);
  half_t& operator-=(half_t const&);
  half_t& operator*=(half_t const&);
  half_t& operator/=(half_t const&);

  half_t& operator++();
  half_t& operator--();
  half_t operator++(int);
  half_t operator--(int);

  operator bool() const;  /// false if zero
  operator int() const;   /// conversion to int
  operator float() const; /// conversion to fp32
  operator half() const;  /// conversion to half

  uint16_t& raw() { return x; }
  uint16_t raw() const { return x; }

 public:
  /// data
  unsigned short x;
};

/// Packed pair of half-precision elements
class half2_t {
 public:
  half2_t();
  half2_t(half_t lo, half_t hi);
  half2_t(std::pair<float, float> const&);
  explicit half2_t(unsigned data);

  half2_t operator+(half2_t const&) const;
  half2_t operator-(half2_t const&) const;
  half2_t operator*(half2_t const&) const;
  half2_t operator/(half2_t const&) const;

  half2_t& operator+=(half2_t const&);
  half2_t& operator-=(half2_t const&);
  half2_t& operator*=(half2_t const&);
  half2_t& operator/=(half2_t const&);

  float dot(half2_t const&) const;        /// dot product with single-precision accumulation
  float dot(half2_t const&, float) const; /// dot product with single-precision accumulation

  half_t doth(half2_t const&) const;         /// dot product with half-precision accumulation
  half_t doth(half2_t const&, half_t) const; /// dot product with half-precision accumulation

  unsigned packed() const;

  operator std::pair<float, float>() const;
  operator unsigned() const;

 public:
  half_t lo;
  half_t hi;
};

template <typename Dest, typename Src>
Dest bitcast(Src const&);
template <>
float bitcast<float, unsigned>(unsigned const&);
template <>
float bitcast<float, int>(int const&);
template <>
unsigned bitcast<unsigned, float>(float const&);
template <>
half_t bitcast<half_t, unsigned short>(unsigned short const&);
template <>
unsigned short bitcast<unsigned short, half_t>(half_t const&);
template <>
half bitcast<half, unsigned short>(unsigned short const&);
} // namespace cutlass

cutlass::half_t operator+(float, cutlass::half_t const&);
cutlass::half_t operator-(float, cutlass::half_t const&);
cutlass::half_t operator*(float, cutlass::half_t const&);
cutlass::half_t operator/(float, cutlass::half_t const&);

std::ostream& operator<<(std::ostream&, cutlass::half_t const&); /// writes a half_t
std::istream& operator>>(std::istream&, cutlass::half_t&);       /// reads a half_t

#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {

/// lexical cast from string to half_t
template <>
cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg);

/// lexical cast from half_t to string
template <>
std::string lexical_cast<std::string>(cutlass::half_t const& arg);
} // namespace boost
#endif

/// Number of mantissa digits for binary16, counting the implicit leading bit
/// (the analog of FLT_MANT_DIG == 24; the stored mantissa field is 10 bits wide)
#define HLF_MANT_DIG 11

namespace std {

cutlass::half_t abs(cutlass::half_t const&); /// absolute value

bool isnan(cutlass::half_t const&); /// true if argument is NaN

bool isfinite(cutlass::half_t const&); /// true if argument is neither NaN nor infinity

cutlass::half_t nanh(const char* = 0); /// returns a not-a-number

bool isinf(cutlass::half_t const&); /// returns true if argument is infinity (+ or -)

bool isnormal(
    cutlass::half_t const&); /// returns true if argument is normal (neither zero nor infinity)

int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating-point value

bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive

cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t

/// Numeric limits
template <>
struct numeric_limits<cutlass::half_t> {
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_infinity = true;
  static bool const has_quiet_NaN = true;
  static bool const has_signaling_NaN = false;
  static std::float_denorm_style const has_denorm = std::denorm_present;
  static bool const has_denorm_loss = true;
  static std::float_round_style const round_style = std::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;
  static int const digits = HLF_MANT_DIG;

  /// Returns the smallest positive normal value, 2^-14 (0x0001 would be the
  /// smallest subnormal, which belongs in denorm_min)
  static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0400); }

  /// Returns the most negative finite value
  static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); }

  /// Returns the largest finite value
  static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); }

  /// Returns the machine epsilon
  static cutlass::half_t epsilon() { return cutlass::half_t::epsilon(); }

  /// Returns the maximum rounding error
  static cutlass::half_t round_error() { return cutlass::half_t(0.5f); }

  /// Returns positive infinity
  static cutlass::half_t infinity() { return cutlass::half_t::inf(); }

  /// Returns a quiet NaN
  static cutlass::half_t quiet_NaN() { return cutlass::half_t::nan(); }

  /// Returns a NaN (signaling NaNs are not supported; see has_signaling_NaN)
  static cutlass::half_t signaling_NaN() { return cutlass::half_t::nan(); }

  /// Returns the smallest positive subnormal value
  static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
};
} // namespace std

//
//
//

inline cutlass::half_t cutlass::half_t::bitcast(unsigned short _x) {
  half_t h;
  h.x = _x;
  return h;
}

/// FP32 -> FP16 conversion - rounds to nearest even
inline cutlass::half_t cutlass::half_t::convert(float const& flt) {
  // software implementation rounds toward nearest even
  unsigned const& s = *reinterpret_cast<unsigned const*>(&flt);
  uint16_t sign = uint16_t((s >> 16) & 0x8000);
  int16_t exp = uint16_t(((s >> 23) & 0xff) - 127);
  int mantissa = s & 0x7fffff;
  uint16_t u = 0;

  if ((s & 0x7fffffff) == 0) {
    // sign-preserving zero
    return cutlass::half_t::bitcast(sign);
  }

  if (exp > 15) {
    if (exp == 128 && mantissa) {
      // not a number
      u = 0x7fff;
    } else {
      // overflow to infinity
      u = sign | 0x7c00;
    }
    return cutlass::half_t::bitcast(u);
  }

  int sticky_bit = 0;

  if (exp >= -14) {
    // normal fp32 to normal fp16
    exp = uint16_t(exp + uint16_t(15));
    u = uint16_t(((exp & 0x1f) << 10));
    u = uint16_t(u | (mantissa >> 13));
  } else {
    // normal single-precision to subnormal half_t-precision representation
    int rshift = (-14 - exp);
    if (rshift < 32) {
      mantissa |= (1 << 23);

      sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0);

      mantissa = (mantissa >> rshift);
      u = (uint16_t(mantissa >> 13) & 0x3ff);
    } else {
      mantissa = 0;
      u = 0;
    }
  }

  // round to nearest even
  int round_bit = ((mantissa >> 12) & 1);
  sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0);

  if ((round_bit && sticky_bit) || (round_bit && (u & 1))) {
    u = uint16_t(u + 1);
  }

  u |= sign;

  return cutlass::half_t::bitcast(u);
}

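// A hand-worked tie case makes the round-to-nearest-even logic above concrete
// (editorial example; arithmetic only):
//   convert(1.00048828125f), i.e. 1 + 2^-11, bit pattern 0x3F801000:
//     sign = 0, exp = 0, mantissa = 0x001000
//     normal path: u = (15 << 10) | (0x001000 >> 13) = 0x3C00
//     round_bit  = (0x001000 >> 12) & 1 = 1
//     sticky_bit = (0x001000 & 0xfff) != 0 = 0
//   This is an exact tie between 0x3C00 (1.0) and 0x3C01 (1 + 2^-10); u is
//   even, so no increment occurs and the result is 0x3C00 == 1.0.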
inline float cutlass::half_t::convert(unsigned short const& h) {
  int sign = ((h >> 15) & 1);
  int exp = ((h >> 10) & 0x1f);
  int mantissa = (h & 0x3ff);
  unsigned f = 0;

  if (exp > 0 && exp < 31) {
    // normal
    exp += 112;
    f = (sign << 31) | (exp << 23) | (mantissa << 13);
  } else if (exp == 0) {
    if (mantissa) {
      // subnormal
      exp += 113;
      while ((mantissa & (1 << 10)) == 0) {
        mantissa <<= 1;
        exp--;
      }
      mantissa &= 0x3ff;
      f = (sign << 31) | (exp << 23) | (mantissa << 13);
    } else {
      // sign-preserving zero
      f = (sign << 31);
    }
  } else if (exp == 31) {
    if (mantissa) {
      f = 0x7fffffff;  // not a number
    } else {
      f = (0xff << 23) | (sign << 31);  // inf
    }
  }
  return *reinterpret_cast<float const*>(&f);
}

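// The subnormal branch can be checked the same way (editorial example):
//   convert(0x0001), the smallest positive subnormal:
//     exp = 0, mantissa = 1 -> subnormal branch, exp starts at 113;
//     the normalization loop shifts mantissa left 10 times (1 -> 1 << 10),
//     decrementing exp to 103, and mantissa &= 0x3ff clears the leading bit;
//     f = 103 << 23, i.e. 2^(103 - 127) = 2^-24 ~= 5.96e-8, as expected.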
inline cutlass::half_t::half_t() {}

inline cutlass::half_t::half_t(int i) { x = convert(float(i)).x; }

inline cutlass::half_t::half_t(float f) { x = convert(f).x; }

inline cutlass::half_t::half_t(double d) { x = convert(float(d)).x; }

inline bool cutlass::half_t::signbit() const { return (x >> 15) & 1; }

inline int cutlass::half_t::exponent() const { return ((x >> 10) & 0x1f) - 15; }

inline unsigned short cutlass::half_t::mantissa() const { return x & 0x3ff; }

inline cutlass::half_t::operator bool() const { return (x & 0x7fff) != 0; }

inline cutlass::half_t::operator int() const { return static_cast<int>(convert(x)); }

inline cutlass::half_t::operator float() const { return convert(x); }

inline cutlass::half_t::operator half() const { return cutlass::bitcast<half, unsigned short>(x); }

inline bool cutlass::half_t::operator==(cutlass::half_t const& h) const {
  if (iszero() && h.iszero()) {
    return true;
  }
  return x == h.x;
}

inline bool cutlass::half_t::operator!=(cutlass::half_t const& h) const {
  if (iszero() && h.iszero()) {
    return false;
  }
  return x != h.x;
}

inline bool cutlass::half_t::operator==(float const& b) const { return x == half_t(b).x; }

inline bool cutlass::half_t::operator!=(float const& b) const { return x != half_t(b).x; }

inline bool cutlass::half_t::iszero() const { return (x & 0x7fff) == 0; }

inline bool cutlass::half_t::isfinite() const { return (exponent() < 16); }

inline bool cutlass::half_t::isnan() const {
  int exp = ((x >> 10) & 0x1f);
  if (exp == 0x1f) {
    return (x & 0x3ff) != 0;
  }
  return false;
}

inline bool cutlass::half_t::isinf() const {
  int exp = ((x >> 10) & 0x1f);
  if (exp == 0x1f) {
    return (x & 0x3ff) == 0;
  }
  return false;
}

inline bool cutlass::half_t::isnormal() const {
  int exp = exponent();
  return exp > -15 && exp < 16;
}

inline bool cutlass::half_t::operator<(half_t const& h) const {
  if (iszero() && h.iszero()) {
    return false;  // +0 and -0 compare equal
  }
  int sign = ((x >> 15) & 1);
  int h_sign = ((h.x >> 15) & 1);
  if (sign == h_sign) {
    // for two negative values, the larger magnitude is the smaller value
    return sign ? (x & 0x7fff) > (h.x & 0x7fff) : (x & 0x7fff) < (h.x & 0x7fff);
  }
  return sign != 0;  // negative < positive
}

inline bool cutlass::half_t::operator<=(half_t const& h) const {
  if (iszero() && h.iszero()) {
    return true;  // +0 and -0 compare equal
  }
  int sign = ((x >> 15) & 1);
  int h_sign = ((h.x >> 15) & 1);
  if (sign == h_sign) {
    return sign ? (x & 0x7fff) >= (h.x & 0x7fff) : (x & 0x7fff) <= (h.x & 0x7fff);
  }
  return sign != 0;
}

inline bool cutlass::half_t::operator>(half_t const& h) const {
  if (iszero() && h.iszero()) {
    return false;
  }
  int sign = ((x >> 15) & 1);
  int h_sign = ((h.x >> 15) & 1);
  if (sign == h_sign) {
    return sign ? (x & 0x7fff) < (h.x & 0x7fff) : (x & 0x7fff) > (h.x & 0x7fff);
  }
  return h_sign != 0;  // positive > negative
}

inline bool cutlass::half_t::operator>=(half_t const& h) const {
  if (iszero() && h.iszero()) {
    return true;
  }
  int sign = ((x >> 15) & 1);
  int h_sign = ((h.x >> 15) & 1);
  if (sign == h_sign) {
    return sign ? (x & 0x7fff) <= (h.x & 0x7fff) : (x & 0x7fff) >= (h.x & 0x7fff);
  }
  return h_sign != 0;
}

inline cutlass::half_t cutlass::half_t::operator+(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) + float(b));
}

inline cutlass::half_t cutlass::half_t::operator-() const { return bitcast(x ^ 0x8000); }

inline cutlass::half_t cutlass::half_t::operator-(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) - float(b));
}

inline cutlass::half_t cutlass::half_t::operator*(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) * float(b));
}

inline cutlass::half_t cutlass::half_t::operator/(cutlass::half_t const& b) const {
  return cutlass::half_t(float(*this) / float(b));
}

inline cutlass::half_t& cutlass::half_t::operator+=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) + float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator-=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) - float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator*=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) * float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator/=(cutlass::half_t const& b) {
  *this = cutlass::half_t(float(*this) / float(b));
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator++() {
  *this = cutlass::half_t(float(*this) + 1.0f);
  return *this;
}

inline cutlass::half_t& cutlass::half_t::operator--() {
  *this = cutlass::half_t(float(*this) - 1.0f);
  return *this;
}

inline cutlass::half_t cutlass::half_t::operator++(int) {
  half_t h = *this;
  *this = cutlass::half_t(float(*this) + 1.0f);
  return h;
}

inline cutlass::half_t cutlass::half_t::operator--(int) {
  half_t h = *this;
  *this = cutlass::half_t(float(*this) - 1.0f);
  return h;
}

inline cutlass::half_t operator+(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a + float(b));
}

inline cutlass::half_t operator-(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a - float(b));
}

inline cutlass::half_t operator*(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a * float(b));
}

inline cutlass::half_t operator/(float a, cutlass::half_t const& b) {
  return cutlass::half_t(a / float(b));
}

//
//
//

inline cutlass::half2_t::half2_t() {}

inline cutlass::half2_t::half2_t(half_t lo, half_t hi) : lo(lo), hi(hi) {}

inline cutlass::half2_t::half2_t(std::pair<float, float> const& p) : lo(p.first), hi(p.second) {}

inline cutlass::half2_t::half2_t(unsigned data)
    : lo(half_t::bitcast(uint16_t(data & 0x0ffff))),
      hi(half_t::bitcast(uint16_t((data >> 16) & 0x0ffff))) {}

inline cutlass::half2_t cutlass::half2_t::operator+(half2_t const& b) const {
  return half2_t(lo + b.lo, hi + b.hi);
}

inline cutlass::half2_t cutlass::half2_t::operator-(half2_t const& b) const {
  return half2_t(lo - b.lo, hi - b.hi);
}

inline cutlass::half2_t cutlass::half2_t::operator*(half2_t const& b) const {
  return half2_t(lo * b.lo, hi * b.hi);
}

inline cutlass::half2_t cutlass::half2_t::operator/(half2_t const& b) const {
  return half2_t(lo / b.lo, hi / b.hi);
}

inline cutlass::half2_t& cutlass::half2_t::operator+=(half2_t const& b) {
  lo += b.lo;
  hi += b.hi;
  return *this;
}

inline cutlass::half2_t& cutlass::half2_t::operator-=(half2_t const& b) {
  lo -= b.lo;
  hi -= b.hi;
  return *this;
}

inline cutlass::half2_t& cutlass::half2_t::operator*=(half2_t const& b) {
  lo *= b.lo;
  hi *= b.hi;
  return *this;
}

inline cutlass::half2_t& cutlass::half2_t::operator/=(half2_t const& b) {
  lo /= b.lo;
  hi /= b.hi;
  return *this;
}

inline float cutlass::half2_t::dot(half2_t const& b) const {
  return float(lo) * float(b.lo) + float(hi) * float(b.hi);
}

inline float cutlass::half2_t::dot(half2_t const& b, float c) const { return c + dot(b); }

inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b) const {
  return cutlass::half_t(dot(b));
}

inline cutlass::half_t cutlass::half2_t::doth(half2_t const& b, half_t c) const {
  return cutlass::half_t(dot(b, float(c)));
}

inline cutlass::half2_t::operator std::pair<float, float>() const {
  return std::pair<float, float>(float(lo), float(hi));
}

inline unsigned cutlass::half2_t::packed() const { return (lo.x | (hi.x << 16)); }

inline cutlass::half2_t::operator unsigned() const { return packed(); }

//
//
//

template <>
inline float cutlass::bitcast<float, unsigned>(unsigned const& u) {
  return *reinterpret_cast<float const*>(&u);
}

template <>
inline float cutlass::bitcast<float, int>(int const& i) {
  return *reinterpret_cast<float const*>(&i);
}

template <>
inline unsigned cutlass::bitcast<unsigned, float>(float const& f) {
  return *reinterpret_cast<unsigned const*>(&f);
}

template <>
inline cutlass::half_t cutlass::bitcast<cutlass::half_t, unsigned short>(unsigned short const& s) {
  return *reinterpret_cast<cutlass::half_t const*>(&s);
}

template <>
inline unsigned short cutlass::bitcast<unsigned short, cutlass::half_t>(cutlass::half_t const& h) {
  return *reinterpret_cast<unsigned short const*>(&h);
}

template <>
inline half cutlass::bitcast<half, unsigned short>(unsigned short const& s) {
  return *reinterpret_cast<half const*>(&s);
}

//
// Lexical casts
//

#ifdef BOOST_LEXICAL_CAST_INCLUDED
namespace boost {
template <>
cutlass::half_t lexical_cast<cutlass::half_t>(std::string const& arg) {
  return cutlass::half_t(boost::lexical_cast<float>(arg));
}

template <>
std::string lexical_cast<std::string>(cutlass::half_t const& arg) {
  return boost::lexical_cast<std::string>(float(arg));
}
} // namespace boost
#endif

//
// Standard Library Operations
//

// std
namespace std {

inline cutlass::half_t abs(cutlass::half_t const& h) {
  return cutlass::half_t::bitcast(h.x & 0x7fff);
}

inline bool isnan(cutlass::half_t const& h) { return h.isnan(); }

inline bool isfinite(cutlass::half_t const& h) { return h.isfinite(); }

inline cutlass::half_t nanh(const char*) { return cutlass::half_t::nan(); }

inline bool isinf(cutlass::half_t const& h) { return h.isinf(); }

inline bool isnormal(cutlass::half_t const& h) { return h.isnormal(); }

inline int fpclassify(cutlass::half_t const& h) {
  int exp = h.exponent();
  unsigned short mantissa = h.mantissa();
  if (exp < -14) {
    if (mantissa == 0) {
      return FP_ZERO;
    } else {
      return FP_SUBNORMAL;
    }
  } else if (exp > 15) {
    if (mantissa == 0) {
      return FP_INFINITE;
    } else {
      return FP_NAN;
    }
  }
  return FP_NORMAL;
}

inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }

inline cutlass::half_t sqrt(cutlass::half_t const& h) {
  return cutlass::half_t(std::sqrt(float(h)));
}
} // namespace std

//
// Stream interactions
//

/// put to stream - half-precision values are written bitcast as unsigned shorts if base is hexadecimal
inline std::ostream& operator<<(std::ostream& out, cutlass::half_t const& h) {
  if (out.flags() & std::ios::hex) {
    return out << h.x;
  } else {
    return out << float(h);
  }
}

/// read from stream - half-precision values are parsed as unsigned shorts if base is hexadecimal
inline std::istream& operator>>(std::istream& in, cutlass::half_t& h) {
  if (in.flags() & std::ios::hex) {
    unsigned short u = 0;
    in >> u;
    h = cutlass::half_t::bitcast(u);
  } else {
    float f = 0;
    in >> f;
    h = cutlass::half_t(f);
  }
  return in;
}
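For reference, a host-side smoke test of half_t and half2_t (illustrative only, not part of the commit; every value below is exactly representable in binary16, so the asserts are exact):

#include <cassert>
#include <iostream>

#include <tools/util/half.h>

void half_demo() {
  cutlass::half_t a(1.5f);
  cutlass::half_t b(0.25f);

  assert(float(a + b) == 1.75f);                   // exact in binary16
  assert(a > b && b < a);
  assert(std::isfinite(a) && !std::isnan(a));

  cutlass::half2_t v(a, b);
  assert(v.dot(v) == 1.5f * 1.5f + 0.25f * 0.25f); // fp32 accumulation

  std::cout << a << " " << std::hex << a << std::endl; // prints "1.5 3e00"
}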
tools/util/host_tensor.h (new file, 362 lines)
@@ -0,0 +1,362 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

/*! \file
    \brief Template class to perform computations on tensors and manage memory.
*/

#include <cutlass/cutlass.h>
#include <cutlass/matrix_traits.h>
#include <tools/util/device_memory.h>
#include <tools/util/host_tensor_view.h>
#include <tools/util/type_traits.h>
#include <vector>

namespace cutlass {

template <typename T, bool DeviceBacked_ = true>
class HostTensor : public HostTensorView<T> {
 public:
  /// Type used for device-side allocations
  typedef typename TypeTraits<T>::device_type DeviceType;

  /// Base class
  typedef HostTensorView<T> Base;

  /// If true, allocates device side memory
  static bool const DeviceBacked = DeviceBacked_;

  /// Rank of tensor
  static int const Rank = Base::Rank;

  /// Type used to compute the offset of an element to the base of a tensor
  typedef typename Base::Offset_t Offset_t;

  /// Tensor reference to host memory
  typedef typename Base::TensorRef_t TensorRef_t;

  /// Tensor reference to device memory
  typedef TensorRef<DeviceType, TensorRef_t::Rank> DeviceTensorRef;

  /// Tensor reference to constant device memory
  typedef TensorRef<DeviceType const, TensorRef_t::Rank> ConstDeviceTensorRef;

  /// Coordinate into tensor
  typedef typename Base::Coord_t Coord_t;

 private:
  /// Host-side memory allocation
  std::vector<T> host_;

  /// Device-side memory
  cutlass::device_memory::allocation<DeviceType> device_;

 public:
  //
  // Device and Host Methods
  //

  /// Default constructor
  HostTensor() {}

  /// Constructs a HostTensor from stride and size
  HostTensor(Coord_t const& _stride, Coord_t const& _size) { reset(_stride, _size); }

  /// Constructs a HostTensor from size - infers packed strides
  HostTensor(Coord_t const& _size) {
    Coord_t _stride = make_Coord(
        _size.at(2) * _size.at(1) * _size.at(0), _size.at(1) * _size.at(0), _size.at(0), 1);
    reset(_stride, _size);
  }

  /// Returns the number of elements needed to back the tensor
  size_t capacity() { return Base::capacity(); }

  /// Returns true if the HostTensor is bound to some memory
  bool good() const { return Base::good(); }

  /// Updates the reference and size, reallocating host and (if DeviceBacked) device memory
  void reset(Coord_t const& _stride, Coord_t const& _size) {
    size_t _capacity = _size.at(0) * _stride.at(0);

    DeviceType* _device_memory = nullptr;
    if (DeviceBacked) {
      _device_memory = cutlass::device_memory::allocate<DeviceType>(_capacity);
    }

    host_.clear();
    host_.resize(_capacity);
    device_.reset(_device_memory, _capacity);

    Base::reset(TensorRef_t(host_.data(), _stride), _size);
  }

  /// Initializes the host tensor as a matrix
  void resize_matrix(int rows, int columns, MatrixLayout::Kind layout) {
    bool col_major = (layout == MatrixLayout::kColumnMajor);
    int ldm = (col_major ? rows : columns);

    Coord_t stride = make_Coord(rows * columns, col_major ? 1 : ldm, col_major ? ldm : 1, 1);

    Coord_t size = make_Coord(1, rows, columns, 1);

    reset(stride, size);
  }

  /// Simplifies resizing the host tensor to a vector of \p elements
  void resize(int elements) { resize_matrix(1, elements, MatrixLayout::kColumnMajor); }

  /// Gets pointer to host data
  T const* host_data() const { return host_.data(); }

  /// Gets pointer to host data
  T* host_data() { return host_.data(); }

  /// Gets pointer to device data
  DeviceType* device_data() const { return device_.get(); }

  /// Copies data from device to host
  void sync_host() {
    if (DeviceBacked) {
      device_memory::copy_to_host(
          host_.data(), reinterpret_cast<T const*>(device_.get()), host_.size());
    }
  }

  /// Copies data from host to device
  void sync_device() {
    if (DeviceBacked) {
      device_memory::copy_to_device(
          device_.get(), reinterpret_cast<DeviceType const*>(host_.data()), host_.size());
    }
  }

  /// Copies data from a caller-supplied device pointer into host memory
  void copy_to_host(DeviceType const *ptr_device) {
    device_memory::copy_to_host(
        host_.data(), reinterpret_cast<T const *>(ptr_device), host_.size());
  }

  /// Copies data to a caller-supplied device pointer
  void copy_to_device(DeviceType *ptr_device) {
    device_memory::copy_to_device(
        ptr_device, reinterpret_cast<DeviceType const *>(host_.data()), host_.size());
  }

  /// Accesses the tensor reference pointing to host data
  TensorRef_t& host_ref() { return Base::ref(); }

  /// Accesses the tensor reference pointing to host data
  TensorRef_t const& host_ref() const { return Base::ref(); }

  /// Accesses the tensor reference pointing to device data
  DeviceTensorRef device_ref() const { return DeviceTensorRef(device_data(), stride()); }

  /// Returns a tensor ref to const-qualified memory on the device
  ConstDeviceTensorRef const_device_ref() const {
    return ConstDeviceTensorRef(device_data(), stride());
  }

/// Accesses the size
|
||||
Coord_t const& size() const { return Base::size(); }
|
||||
|
||||
/// Accesses the size
|
||||
int size(int dim) const { return Base::size(dim); }
|
||||
|
||||
/// Accesses the size
|
||||
Coord_t const& stride() const { return Base::stride(); }
|
||||
|
||||
/// Accesses the size
|
||||
int stride(int dim) const { return Base::stride(dim); }
|
||||
|
||||
/// Returns the index of an element
|
||||
Offset_t offset(Coord_t const& coord) const { return Base::offset(coord); }
|
||||
|
||||
/// Determines whether a location is within a tensor
|
||||
bool contains(Coord_t const& coord) const { return Base::contains(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
T& at(Coord_t const& coord) const { return Base::at(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
T& operator[](Coord_t const& coord) { return at(coord); }
|
||||
|
||||
/// Element-wise accessor with basic offset
|
||||
T& at(int idx) const { return Base::at(idx); }
|
||||
|
||||
/// Returns a Tensor_view given location and size quantities
|
||||
TensorView<T> subview(Coord_t const& _location, Coord_t _size) const {
|
||||
return Base::subview(_location, _size);
|
||||
}
|
||||
|
||||
/// Recurses through all dimensions and applies a unary operation
|
||||
template <typename F>
|
||||
void elementwise_in_place(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
|
||||
Base::elementwise_in_place(op, dim, dst_offset_base);
|
||||
}
|
||||
|
||||
/// Recurses through all dimensions and applies a unary operator, supplying the logical
|
||||
/// coordinate within the tensor as an argument
|
||||
template <typename F>
|
||||
void elementwise_stream(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
|
||||
    Base::elementwise_stream(op, dim, dst_offset_base);
  }

  /// Recurses through all dimensions and applies a unary operator, supplying the logical
  /// coordinate within the tensor as an argument
  template <typename F>
  void elementwise_generate(F& op,
                            int dim = 0,
                            Offset_t dst_offset_base = 0,
                            Coord_t coord = Coord_t(0)) {
    Base::elementwise_generate(op, dim, dst_offset_base, coord);
  }

  /// Recurses through all dimensions and applies a binary operation
  template <typename Src, typename F>
  bool elementwise_in_place(F& op,
                            int dim,
                            TensorView<Src> const& tensor,
                            Offset_t dst_offset_base = 0,
                            Offset_t src_offset_base = 0) {
    return Base::elementwise_in_place(op, dim, tensor, dst_offset_base, src_offset_base);
  }

  /// Accumulate in place
  template <typename Src>
  TensorView<T>& operator+=(TensorView<Src> const& tensor) {
    Base::operator+=(tensor);
    sync_device();
    return *this;
  }

  /// Subtract in place
  template <typename Src>
  TensorView<T>& operator-=(TensorView<Src> const& tensor) {
    Base::operator-=(tensor);
    sync_device();
    return *this;
  }

  /// Multiply in place
  template <typename Src>
  TensorView<T>& operator*=(TensorView<Src> const& tensor) {
    Base::operator*=(tensor);
    sync_device();
    return *this;
  }

  /// Divide in place
  template <typename Src>
  TensorView<T>& operator/=(TensorView<Src> const& tensor) {
    Base::operator/=(tensor);
    sync_device();
    return *this;
  }

  /// Equality with epsilon tolerance
  bool equals(TensorView<T> const& tensor, T epsilon) const {
    return Base::equals(tensor, epsilon);
  }

  /// Equality with ulps tolerance
  bool bit_equals(TensorView<T> const& tensor, long long ulps_threshold = 0) {
    return Base::bit_equals(tensor, ulps_threshold);
  }

  /// Computes general matrix product among select dimensions of a tensor
  /// Assumes:
  ///   D: number of independent GEMMs to compute
  ///   H: height of matrix
  ///   W: width of matrix
  template <
      /// Data type of A matrix elements
      typename A,
      /// Data type of B matrix elements
      typename B,
      /// Data type of "compute" type (i.e. accumulator)
      typename Ctype,
      /// Data type of scale factors
      typename Stype>
  void gemm(TensorView<A> const& tensor_a, TensorView<B> const& tensor_b, Stype alpha, Stype beta) {
    Base::template gemm<A, B, Ctype, Stype>(tensor_a, tensor_b, alpha, beta);
  }

  /// Fills with random data
  template <typename Gen>
  void fill_random(Gen generator) {
    Base::fill_random(generator);
    sync_device();
  }

  /// Procedurally assigns elements
  template <typename Gen>
  void generate(Gen generator) {
    Base::generate(generator);
    sync_device();
  }

  /// Procedurally visits elements
  template <typename Gen>
  void visit(Gen& generator) const {
    Base::visit(generator);
  }

  /// Initializes with identity
  void fill_identity() {
    Base::fill_identity();
    sync_device();
  }

  /// Computes elements as a linear combination of their coordinates
  void fill_linear(Coord_t v, T offset = T(0)) {
    Base::fill_linear(v, offset);
    sync_device();
  }

  /// Fills elements sequentially: element i receives the value i * v + offset
  void fill_sequential(T v = T(1), T offset = T(0)) {
    Base::fill_sequential(v, offset);
    sync_device();
  }

  /// Fills with a value
  void fill(T val = T(0)) {
    Base::fill(val);
    sync_device();
  }

  /// Copies from external data source and performs type conversion
  template <typename Src>
  void fill(TensorView<Src> const& tensor) {
    Base::fill(tensor);
    sync_device();
  }

  /// Computes the norm of the matrix in double-precision
  double norm() const { return Base::norm(); }
};

} // namespace cutlass
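
// Usage sketch: every mutating method above pairs the host-side operation with
// sync_device(), so host initialization becomes visible to device code without an
// explicit copy. Assuming the enclosing class is the HostTensor<T> utility declared
// earlier in this file (its declaration is above this excerpt):
//
//   cutlass::HostTensor<float> tensor;   // hypothetical instantiation
//   tensor.fill_sequential();            // host write + implicit sync_device()
//   tensor.fill_random(my_generator);    // 'my_generator' is a placeholder functor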

542
tools/util/host_tensor_view.h
Normal file
@ -0,0 +1,542 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Host-side implementation of useful operations
*/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/tensor_view.h>
#include <tools/util/type_traits.h>

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Numeric conversion by static_cast
template <typename SrcType, typename DstType>
struct Cast {
  static inline DstType apply(SrcType src) { return static_cast<DstType>(src); }
};

/// Saturating conversion from float to int8_t
template <>
struct Cast<float, int8_t> {
  static inline int8_t apply(float src) {
    return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
  }
};

/// Saturating conversion from float to uint8_t
template <>
struct Cast<float, uint8_t> {
  static inline uint8_t apply(float src) {
    return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
  }
};
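
// A brief sketch of the saturating conversions above: Cast<float, int8_t> clamps to the
// representable range before narrowing, unlike a raw static_cast.
//
//   int8_t a = cutlass::Cast<float, int8_t>::apply(300.0f);    // yields 127
//   int8_t b = cutlass::Cast<float, int8_t>::apply(-300.0f);   // yields -128
//   uint8_t c = cutlass::Cast<float, uint8_t>::apply(-4.0f);   // yields 0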

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
class HostTensorView : public TensorView<T> {
 public:
  /// Base class
  typedef TensorView<T> TensorView_t;

  /// Convention: depth is the first dimension
  static int const Dim_D = 0;

  /// Convention: height is the second dimension
  static int const Dim_H = 1;

  /// Convention: width is the third dimension
  static int const Dim_W = 2;

  /// Convention: channel is the fourth dimension
  static int const Dim_C = 3;

  /// Rank of tensor
  static int const Rank = TensorView_t::Rank;

  /// Type used to compute the offset of an element to the base of a tensor
  typedef typename TensorView_t::Offset_t Offset_t;

  /// Reference and stride
  typedef typename TensorView_t::TensorRef_t TensorRef_t;

  /// Coordinate into tensor
  typedef typename TensorView_t::Coord_t Coord_t;

 public:
  //
  // Device and Host Methods
  //

  /// Default constructor
  HostTensorView() {}

  /// Constructs a TensorView from a TensorRef and size
  HostTensorView(TensorRef_t const& _ref, Coord_t const& _size) : TensorView_t(_ref, _size) {}

  /// Accesses the size
  Coord_t const& size() const { return TensorView_t::size(); }

  /// Accesses the size of a specified dimension
  int size(int dim) const { return size().at(dim); }

  /// Accesses the stride
  Coord_t const& stride() const { return TensorView_t::stride(); }

  /// Accesses the stride along a specified dimension
  int stride(int dim) const { return stride().at(dim); }

  /// Returns the number of scalar elements needed to store the tensor
  size_t capacity() const { return size(3) * stride(3) * stride(2) * stride(1) * stride(0); }

  /// Returns true if the TensorView is bound to some memory
  bool good() const { return TensorView_t::good(); }

  /// Updates the reference and size of a TensorView object
  void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) {
    return TensorView_t::reset(_ref, _size);
  }

  /// Accesses the tensor reference pointing to data
  TensorRef_t& ref() { return TensorView_t::ref(); }

  /// Accesses the tensor reference pointing to data
  TensorRef_t const& ref() const { return TensorView_t::ref(); }

  /// Assigns a tensor view
  HostTensorView& operator=(TensorView_t const& _tensor) {
    reset(_tensor.ref(), _tensor.size());
    return *this;
  }

  /// Returns the index of an element
  Offset_t offset(Coord_t const& coord) const { return TensorView_t::offset(coord); }

  /// Determines whether a location is within a tensor
  bool contains(Coord_t const& coord) const { return TensorView_t::contains(coord); }

  /// Element-wise accessor
  T& at(Coord_t const& coord) const { return TensorView_t::at(coord); }

  /// Element-wise accessor
  T& operator[](Coord_t const& coord) const { return at(coord); }

  /// Accesses an element with a raw offset
  T& at(int idx) const { return TensorView_t::at(idx); }

  /// Accesses an element with a raw offset
  T& operator[](int idx) const { return at(idx); }

  /// Returns a TensorView given location and size quantities
  TensorView_t subview(Coord_t const& location, Coord_t size) const {
    return TensorView_t::subview(location, size);
  }

  /// Recurses through all dimensions and applies a unary operation in place.
  /// (const-qualified: element access is shallow, as in elementwise_visit(), which
  /// allows const callers such as norm() to traverse the view.)
  template <typename F>
  void elementwise_in_place(F& op, int dim = 0, Offset_t dst_offset_base = 0) const {
    Offset_t dst_offset = dst_offset_base;

    for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
      if (dim < Rank - 1) {
        elementwise_in_place(op, dim + 1, dst_offset);
      } else {
        op(ref().data()[dst_offset]);
      }
    }
  }

  /// Recurses through all dimensions and applies a unary operator with no arguments
  template <typename F>
  void elementwise_stream(F& op, int dim = 0, Offset_t dst_offset_base = 0) {
    Offset_t dst_offset = dst_offset_base;

    for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
      if (dim < Rank - 1) {
        elementwise_stream(op, dim + 1, dst_offset);
      } else {
        ref().data()[dst_offset] = op();
      }
    }
  }

  /// Recurses through all dimensions and applies a unary operator, supplying the logical
  /// coordinate within the tensor as an argument
  template <typename F>
  void elementwise_generate(F& op,
                            int dim = 0,
                            Offset_t dst_offset_base = 0,
                            Coord_t coord = Coord_t(0)) {
    Offset_t dst_offset = dst_offset_base;

    for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
      coord.at(dim) = idx;

      if (dim < Rank - 1) {
        elementwise_generate(op, dim + 1, dst_offset, coord);
      } else {
        ref().data()[dst_offset] = op(coord);
      }
    }
  }

  /// Recurses through all dimensions and applies a unary operator, supplying both the
  /// element and its logical coordinate within the tensor as arguments
  template <typename F>
  void elementwise_visit(F& op,
                         int dim = 0,
                         Offset_t dst_offset_base = 0,
                         Coord_t coord = Coord_t(0)) const {
    Offset_t dst_offset = dst_offset_base;

    for (int idx = 0; idx < size(dim); ++idx, dst_offset += stride(dim)) {
      coord.at(dim) = idx;

      if (dim < Rank - 1) {
        elementwise_visit(op, dim + 1, dst_offset, coord);
      } else {
        op(ref().data()[dst_offset], coord);
      }
    }
  }
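
// Illustrative functors for the traversal helpers above (an assumed example, not part of
// the original header): elementwise_generate() expects op(coord) returning a value, while
// elementwise_visit() expects op(element, coord).
//
//   struct Checkerboard {
//     float operator()(cutlass::HostTensorView<float>::Coord_t const& coord) {
//       return float((coord.at(1) + coord.at(2)) % 2);  // alternate 0/1 by row+column
//     }
//   };
//
//   Checkerboard gen;
//   view.elementwise_generate(gen);  // 'view' is a hypothetical HostTensorView<float>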

  /// Recurses through all dimensions and applies a binary operation
  template <typename Src, typename F>
  bool elementwise_in_place(F& op,
                            TensorView<Src> const& tensor,
                            int dim = 0,
                            Offset_t dst_offset_base = 0,
                            Offset_t src_offset_base = 0) const {
    Offset_t dst_offset = dst_offset_base;
    Offset_t src_offset = src_offset_base;

    if (size().at(dim) != tensor.size().at(dim)) {
      return false;
    }

    for (int idx = 0; idx < size(dim);
         ++idx, dst_offset += stride(dim), src_offset += tensor.stride(dim)) {
      if (dim < Rank - 1) {
        // propagate a size mismatch detected in an inner dimension
        if (!elementwise_in_place(op, tensor, dim + 1, dst_offset, src_offset)) {
          return false;
        }
      } else {
        op(data()[dst_offset], tensor.data()[src_offset]);
      }
    }

    return true;
  }

  template <typename Src>
  struct LambdaBinaryAddition {
    void operator()(T& a, Src b) const { a += T(b); }
  };

  template <typename Src>
  struct LambdaBinarySubtraction {
    void operator()(T& a, Src b) const { a -= T(b); }
  };

  template <typename Src>
  struct LambdaBinaryMultiplication {
    void operator()(T& a, Src b) const { a *= T(b); }
  };

  template <typename Src>
  struct LambdaBinaryDivision {
    void operator()(T& a, Src b) const { a /= T(b); }
  };

  /// Accumulate in place
  template <typename Src>
  TensorView<T>& operator+=(TensorView<Src> const& tensor) {
    LambdaBinaryAddition<Src> op;
    elementwise_in_place(op, tensor);

    return *this;
  }

  /// Subtract in place
  template <typename Src>
  TensorView<T>& operator-=(TensorView<Src> const& tensor) {
    LambdaBinarySubtraction<Src> op;
    elementwise_in_place(op, tensor);

    return *this;
  }

  /// Multiply in place
  template <typename Src>
  TensorView<T>& operator*=(TensorView<Src> const& tensor) {
    LambdaBinaryMultiplication<Src> op;
    elementwise_in_place(op, tensor);

    return *this;
  }

  /// Divide in place
  template <typename Src>
  TensorView<T>& operator/=(TensorView<Src> const& tensor) {
    LambdaBinaryDivision<Src> op;
    elementwise_in_place(op, tensor);

    return *this;
  }

  /// Comparison operator with relative epsilon tolerance
  struct EqualsOperator {
    bool equal;
    T eps;

    EqualsOperator(T _epsilon) : equal(true), eps(_epsilon) {}

    void operator()(T a, T b) {
      if (std::abs(T(a - b)) > eps * std::max(std::abs(a), std::abs(b))) {
        equal = false;
      }
    }
  };

  /// Equality with epsilon tolerance
  bool equals(TensorView<T> const& tensor, T epsilon) const {
    EqualsOperator comparison_op(epsilon);
    bool equal_size = elementwise_in_place(comparison_op, tensor);

    return equal_size && comparison_op.equal;
  }

  /// Compares two values whose bit patterns fit within a long long
  struct BitEqualsOperator {
    bool equal;
    long long eps;
    uint64_t index;

    BitEqualsOperator(long long _ulps_threshold) : equal(true), eps(_ulps_threshold), index(0) {}

    void operator()(T a, T b) {
      // copy each bit pattern into the low bytes of a zero-initialized integer
      long long bits_a = 0;
      long long bits_b = 0;

      *reinterpret_cast<T*>(&bits_a) = TypeTraits<T>::remove_negative_zero(a);
      *reinterpret_cast<T*>(&bits_b) = TypeTraits<T>::remove_negative_zero(b);

      // compute the difference in units of last place
      long long ulps = bits_a - bits_b;
      if (std::abs(ulps) > eps) {
        equal = false;
      }
      index++;
    }
  };

  /// Equality with ulps tolerance
  bool bit_equals(TensorView<T> const& tensor, long long ulps_threshold = 0) {
    BitEqualsOperator comparison_op(ulps_threshold);
    bool equal_size = elementwise_in_place(comparison_op, tensor);

    return equal_size && comparison_op.equal;
  }
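
// Comparison usage sketch (tensor names are assumptions): equals() applies a relative
// epsilon test per element, while bit_equals() measures distance in units of last place.
//
//   bool close = computed.equals(reference, 1e-5f);   // relative tolerance
//   bool exact = computed.bit_equals(reference);      // 0 ulps: bitwise match
//   bool near  = computed.bit_equals(reference, 4);   // allow 4 ulps of drift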

  /// Gets naked pointer to data
  T* data() const { return TensorView_t::data(); }

  /// Computes general matrix product among select dimensions of a tensor
  /// Assumes:
  ///   D: number of independent GEMMs to compute
  ///   H: height of matrix
  ///   W: width of matrix
  ///   C: "channels" of each element
  template <typename A, typename B, typename Ctype, typename Stype>
  void gemm(TensorView<A> const& tensor_a, TensorView<B> const& tensor_b, Stype alpha, Stype beta) {
    int const Batch = size(Dim_D);
    int const M = size(Dim_H);
    int const N = size(Dim_W);
    int const K = tensor_a.size(Dim_W);
    int const C = tensor_a.size(Dim_C);

    // Sizes must match
    if (tensor_a.size(Dim_H) != M || tensor_b.size(Dim_W) != N || tensor_b.size(Dim_C) != C ||
        tensor_b.size(Dim_H) != K) {
      return;
    }

    int const Mblock = 32;
    int const Nblock = 32;

    for (int batch = 0; batch < Batch; ++batch) {
      for (int row_block = 0; row_block < M; row_block += Mblock) {
        for (int col_block = 0; col_block < N; col_block += Nblock) {
          Ctype accum[Mblock][Nblock];

          for (int j = 0; j < Nblock; j++) {
            for (int i = 0; i < Mblock; i++) {
              accum[i][j] = Ctype(0);
            }
          }

          for (int k_block = 0; k_block < K; ++k_block) {
            for (int j = 0; j < Nblock; j++) {
              for (int i = 0; i < Mblock; i++) {
                int row = row_block + i;
                int col = col_block + j;

                if (row < M && col < N) {
                  for (int channel = 0; channel < C; ++channel) {
                    Ctype a(tensor_a.at(make_Coord(batch, row, k_block, channel)));
                    Ctype b(tensor_b.at(make_Coord(batch, k_block, col, channel)));

                    accum[i][j] += a * b;
                  }
                }
              }
            }
          }

          for (int j = 0; j < Nblock; j++) {
            for (int i = 0; i < Mblock; i++) {
              int row = row_block + i;
              int col = col_block + j;

              Coord_t coord = make_Coord(batch, row, col, 0);
              if (row < M && col < N) {
                at(coord) =
                    Cast<Stype, T>::apply(alpha * Stype(accum[i][j]) + beta * Stype(at(coord)));
              }
            }
          }
        }
      }
    }
  }
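
// Usage sketch for the reference GEMM above (all tensor names assumed): the calling view
// holds C with shape (D, M, N, 1); A is (D, M, K, C) and B is (D, K, N, C). Accumulation
// happens in Ctype, and the result is scaled as alpha * AB + beta * C before being Cast
// back to T.
//
//   tensor_c.gemm<float, float, double, double>(tensor_a, tensor_b, 1.0, 0.0);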

  /// Fills with random data
  template <typename Gen>
  void fill_random(Gen generator) {
    elementwise_stream(generator);
  }

  /// Procedurally assigns elements
  template <typename Gen>
  void generate(Gen generator) {
    elementwise_generate(generator);
  }

  /// Procedurally visits elements
  template <typename Gen>
  void visit(Gen& generator) const {
    elementwise_visit(generator);
  }

  /// Generator to fill a tensor with the identity matrix
  struct LambdaFillIdentity {
    T operator()(Coord_t const& coord) { return (coord.at(1) == coord.at(2) ? T(1) : T(0)); }
  };

  /// Initializes with identity
  void fill_identity() {
    LambdaFillIdentity op;
    elementwise_generate(op);
  }

  /// Lambda for fill_linear()
  struct LambdaFillLinear {
    Coord_t v_;
    T offset_;

    LambdaFillLinear(Coord_t const& _v, T _offset) : v_(_v), offset_(_offset) {}

    T operator()(Coord_t const& coord) { return T(v_.template dot<int>(coord)) + offset_; }
  };

  /// Computes elements as a linear combination of their coordinates
  void fill_linear(Coord_t v, T offset = T(0)) {
    LambdaFillLinear lambda(v, offset);
    elementwise_generate(lambda);
  }
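
// fill_linear() usage sketch: the coordinate weights choose the layout. For a 1 x H x W x 1
// view ('view' is a hypothetical HostTensorView<float>), weighting the row index by W and
// the column index by 1 yields row-major values 0, 1, 2, ...
//
//   view.fill_linear(cutlass::make_Coord(0, view.size(2), 1, 0));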

  /// Fills elements sequentially: element i receives the value i * v + offset
  void fill_sequential(T v = T(1), T offset = T(0)) {
    int const count = size().count();
    for (int i = 0; i < count; ++i) {
      data()[i] = T(i) * v + offset;
    }
  }

  /// Returns a constant value
  struct LambdaFillValue {
    T value;

    LambdaFillValue(T _value) : value(_value) {}

    T operator()() { return value; }
  };

  /// Fills with a value
  void fill(T val = T(0)) {
    LambdaFillValue op(val);
    elementwise_stream(op);
  }

  /// Conversion from Src to T
  template <typename Src>
  struct LambdaAssign {
    void operator()(T& a, Src b) const { a = T(b); }
  };

  /// Copies from external data source and performs type conversion
  template <typename Src>
  void fill(TensorView<Src> const& tensor) {
    LambdaAssign<Src> op;
    elementwise_in_place(op, tensor);
  }

  /// Computes a norm
  struct LambdaNorm {
    double sum;

    LambdaNorm() : sum(0) {}

    void operator()(T const& element) {
      double value(element);
      double conj(element);  // TODO - conjugates for complex

      sum += value * conj;
    }
  };

  /// Computes the norm of the matrix in double-precision
  double norm() const {
    LambdaNorm op;
    elementwise_in_place(op);

    return std::sqrt(op.sum);
  }
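
// norm() usage sketch: a common verification pattern (tensor names assumed) compares the
// Frobenius norm of the elementwise error against the norm of the reference.
//
//   diff.fill(reference);      // copy reference into diff
//   diff -= computed;          // diff now holds the elementwise error
//   double relative_error = diff.norm() / reference.norm();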
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

61
tools/util/tensor_view_io.h
Normal file
@ -0,0 +1,61 @@
/***************************************************************************************************
 * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cutlass/core_io.h>
#include <cutlass/tensor_view.h>

/// Streams one element; the int8_t specialization prints a number rather than a character
template <typename T>
inline std::ostream& tensor_view_output(std::ostream& out, T t) {
  out << t;
  return out;
}

template <>
inline std::ostream& tensor_view_output<int8_t>(std::ostream& out, int8_t t) {
  out << int(t);
  return out;
}

/// Writes a rank-4 TensorView to an output stream, one bracketed block per batch
template <typename T>
inline std::ostream& operator<<(std::ostream& out, cutlass::TensorView<T> const& tensor) {
  for (int batch = 0; batch < tensor.size(0); ++batch) {
    out << "[\n ";
    for (int h = 0; h < tensor.size(1); ++h) {
      for (int w = 0; w < tensor.size(2); ++w) {
        for (int c = 0; c < tensor.size(3); ++c) {
          out << ((c || w) ? ", " : "");
          tensor_view_output(out, tensor.at(cutlass::make_Coord(batch, h, w, c)));
        }
      }
      if (h + 1 < tensor.size(1)) {
        out << " ;\n ";
      }
    }
    out << " ]";
  }

  return out;
}
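
// Printing sketch: with the overload above, any cutlass::TensorView streams directly; rows
// are separated by ';' and each batch is bracketed. 'view' is an assumed HostTensorView.
//
//   std::cout << view << std::endl;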

161
tools/util/type_traits.h
Normal file
@ -0,0 +1,161 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Type traits for common CUDA types
*/

#pragma once

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <stdint.h>

#include "half.h"

namespace cutlass {
struct half_t;

template <typename T>
struct TypeTraits;

template <>
struct TypeTraits<int8_t> {
  static cudaDataType_t const cublas_type = CUDA_R_8I;
  typedef int8_t host_type;
  typedef int8_t device_type;
  typedef int8_t integer_type;
  typedef uint8_t unsigned_type;
  static inline int8_t remove_negative_zero(int8_t x) { return x; }
  static inline int to_print(int8_t x) { return (int)x; }
};

template <>
struct TypeTraits<uint8_t> {
  static cudaDataType_t const cublas_type = CUDA_R_8I;
  typedef uint8_t host_type;
  typedef uint8_t device_type;
  typedef uint8_t integer_type;
  typedef uint8_t unsigned_type;
  static inline uint8_t remove_negative_zero(uint8_t x) { return x; }
  static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; }
};

template <>
struct TypeTraits<int> {
  static cudaDataType_t const cublas_type = CUDA_R_32I;
  typedef int host_type;
  typedef int device_type;
  typedef int32_t integer_type;
  typedef uint32_t unsigned_type;
  static inline int32_t remove_negative_zero(int32_t x) { return x; }
  static inline int to_print(int x) { return x; }
};

template <>
struct TypeTraits<unsigned> {
  static cudaDataType_t const cublas_type = CUDA_R_32I;
  typedef unsigned host_type;
  typedef unsigned device_type;
  typedef uint32_t integer_type;
  typedef uint32_t unsigned_type;
  static inline uint32_t remove_negative_zero(uint32_t x) { return x; }
  static inline uint32_t to_print(uint32_t x) { return x; }
};

template <>
struct TypeTraits<half> {
  static cudaDataType_t const cublas_type = CUDA_R_16F;
  typedef half_t host_type;
  typedef half device_type;
  typedef int16_t integer_type;
  typedef uint16_t unsigned_type;
  static inline half remove_negative_zero(half x) {
    // inspect the bit pattern as an unsigned integer so the comparison with 0x8000
    // (negative zero) is well defined; a signed int16_t would never equal 0x8000
    unsigned_type h_int = reinterpret_cast<unsigned_type const&>(x);
    if (h_int == 0x8000) {
      h_int = 0;
    }
    x = reinterpret_cast<half const&>(h_int);
    return x;
  }
  static inline half to_print(half x) { return x; }
};

template <>
struct TypeTraits<int64_t> {
  // placeholder: cudaDataType_t defines no 64-bit integer type here
  static cudaDataType_t const cublas_type = CUDA_R_8I;
  typedef int64_t host_type;
  typedef int64_t device_type;
  typedef int64_t integer_type;
  typedef uint64_t unsigned_type;
  static inline int64_t remove_negative_zero(int64_t x) { return x; }
  static inline int64_t to_print(int64_t x) { return x; }
};

template <>
struct TypeTraits<uint64_t> {
  // placeholder: cudaDataType_t defines no 64-bit integer type here
  static cudaDataType_t const cublas_type = CUDA_R_8I;
  typedef uint64_t host_type;
  typedef uint64_t device_type;
  typedef uint64_t integer_type;
  typedef uint64_t unsigned_type;
  static inline uint64_t remove_negative_zero(uint64_t x) { return x; }
  static inline uint64_t to_print(uint64_t x) { return x; }
};

template <>
struct TypeTraits<cutlass::half_t> {
  static cudaDataType_t const cublas_type = CUDA_R_16F;
  typedef half_t host_type;
  typedef half device_type;
  typedef int16_t integer_type;
  typedef uint16_t unsigned_type;
  static inline half_t remove_negative_zero(half_t x) {
    return (x.raw() == 0x8000 ? half_t::bitcast(0) : x);
  }
  static inline half_t to_print(half_t x) { return x; }
};

template <>
struct TypeTraits<float> {
  static cudaDataType_t const cublas_type = CUDA_R_32F;
  typedef float host_type;
  typedef float device_type;
  typedef int32_t integer_type;
  typedef uint32_t unsigned_type;
  static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; }
  static inline float to_print(float x) { return x; }
};

template <>
struct TypeTraits<double> {
  static cudaDataType_t const cublas_type = CUDA_R_64F;
  typedef double host_type;
  typedef double device_type;
  typedef int64_t integer_type;
  typedef uint64_t unsigned_type;
  static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; }
  static inline double to_print(double x) { return x; }
};
} // namespace cutlass
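
// TypeTraits usage sketch: the integer_type / remove_negative_zero pair underpins the
// ulps-based comparisons in host_tensor_view.h. For example (assumed values):
//
//   typedef cutlass::TypeTraits<float>::integer_type bits_t;               // int32_t
//   float x = cutlass::TypeTraits<float>::remove_negative_zero(-0.0f);     // +0.0f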