CUTLASS 2.1 (#83)

CUTLASS 2.1 contributes:
- BLAS-style host-side API added to CUTLASS Library
- Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
- Minor enhancements and bug fixes
This commit is contained in:
Andrew Kerr
2020-04-07 13:51:25 -07:00
committed by GitHub
parent 7c0cd26d13
commit 96dab34ad9
196 changed files with 20653 additions and 1995 deletions

View File

@ -52,13 +52,15 @@ install(
#
cutlass_add_library(
cutlass_lib
SHARED
src/library.cu
cutlass_library_objs
OBJECT
src/handle.cu
src/manifest.cpp
src/operation_table.cu
src/singleton.cu
src/util.cu
)
add_library(nvidia::cutlass::library ALIAS cutlass_lib)
set_target_properties(cutlass_lib PROPERTIES EXPORT_NAME library)
file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/scripts/*.py)
@ -66,16 +68,19 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU
# auto-instantiation of CUTLASS kernels
#
# set cutlass generator compiler version to filter kernels in the generator not supported by a specific toolkit.
set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
execute_process(
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/scripts/generator.py
--operations all
--operations "${CUTLASS_LIBRARY_OPERATIONS}"
--build-dir ${PROJECT_BINARY_DIR}
--curr-build-dir ${CMAKE_CURRENT_BINARY_DIR}
--generator-target library
--architectures "${CUTLASS_NVCC_ARCHS_ENABLED}"
--kernels "${CUTLASS_LIBRARY_KERNELS}"
--cuda-version "${CMAKE_CUDA_COMPILER_VERSION}"
--cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}"
RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT
OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log
@ -95,35 +100,70 @@ else()
endif()
target_include_directories(
cutlass_lib
cutlass_library_objs
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
${CMAKE_CURRENT_BINARY_DIR}/include
)
set_target_properties(
cutlass_lib
PROPERTIES
OUTPUT_NAME cutlass
WINDOWS_EXPORT_ALL_SYMBOLS 1
)
target_link_libraries(
cutlass_lib
cutlass_library_objs
PUBLIC
cutlass_library_includes
)
function(cutlass_add_cutlass_library)
set(options)
set(oneValueArgs NAME TYPE EXPORT_NAME)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_add_library(
${__NAME}
${__TYPE}
EXPORT_NAME ${__EXPORT_NAME}
$<TARGET_OBJECTS:cutlass_library_objs>
)
target_link_libraries(
${__NAME}
PUBLIC
cutlass_library_includes
)
set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX ${CUTLASS_LIBRARY_DEBUG_POSTFIX})
set(OUTPUT_NAME cutlass)
if (WIN32 AND ${__TYPE} STREQUAL "STATIC")
set(OUTPUT_NAME "${OUTPUT_NAME}.static")
endif()
set_target_properties(
${__NAME}
PROPERTIES
OUTPUT_NAME ${OUTPUT_NAME}
WINDOWS_EXPORT_ALL_SYMBOLS 1
)
endfunction()
cutlass_add_cutlass_library(NAME cutlass_lib TYPE SHARED EXPORT_NAME library)
cutlass_add_cutlass_library(NAME cutlass_library_static TYPE STATIC EXPORT_NAME library_static)
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
install(
DIRECTORY
${CMAKE_CURRENT_SOURCE_DIR}/include/
DESTINATION
${CMAKE_INSTALL_INCLUDEDIR}
)
install(
TARGETS cutlass_lib cutlass_library_includes
TARGETS
cutlass_lib
cutlass_library_static
cutlass_library_includes
EXPORT NvidiaCutlass
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -0,0 +1,284 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief BLAS-like handle used to launch operations on the CUDA device.
*/
#pragma once
#include <memory>
#include "cutlass/library/library.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// BLAS-style handle used to launch CUTLASS library operations on the CUDA device.
/// Owns a device workspace allocation and tracks the stream, scalar pointer mode,
/// and the most recently executed operation.
class Handle {
private:
/// Host workspace size in bytes (4 KiB)
static int const kHostWorkspaceSize = (4 << 10);
/// CUDA device properties
cudaDeviceProp device_;
/// CUDA stream
cudaStream_t stream_;
/// Device workspace
void *workspace_;
/// Size of device workspace in bytes
size_t workspace_size_;
/// Indicates whether scalars are host or device pointers
ScalarPointerMode scalar_pointer_mode_;
/// Pointer to the most recently executed operation
Operation const *last_operation_;
public:
/// Constructor (default device workspace size is 4 MiB)
Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20));
/// Destructor
~Handle();
/// Move constructor
Handle(Handle && handle);
/// Move assignment operator
Handle &operator=(Handle && handle);
//
// Persistent state accessors
//
/// Returns compute capability of the selected device
int compute_capability() const;
/// Sets the current CUDA stream
void set_stream(cudaStream_t stream);
/// Gets the current CUDA stream
cudaStream_t get_stream() const;
/// Gets the device workspace size
size_t get_workspace_size() const;
/// Gets a pointer to the device workspace allocation in Global Memory
void *get_workspace() const;
/// Sets the size of device workspace, invalidating pointers previously returned by get_workspace()
void set_workspace_size(size_t bytes);
/// Gets the scalar pointer mode
ScalarPointerMode get_scalar_pointer_mode() const;
/// Sets the scalar pointer mode
void set_scalar_pointer_mode(ScalarPointerMode mode);
/// Gets the most recently executed operation
Operation const *get_last_operation() const;
//
// Computations
//
/// Executes a GEMM computation: D <= alpha * A*B + beta * C
Status gemm(
int M, /// GEMM M dimension
int N, /// GEMM N dimension
int K, /// GEMM K dimension
NumericTypeID element_compute, /// Data type of internal accumulation
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
void const *alpha, /// Pointer to alpha scalar
NumericTypeID element_A, /// Data type of A matrix elements
LayoutTypeID layout_A, /// Layout of A matrix
ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices
void const * ptr_A, /// Pointer to A matrix in Global Memory
int lda, /// Leading dimension of A matrix
NumericTypeID element_B, /// Data type of B matrix elements
LayoutTypeID layout_B, /// Layout of B matrix
ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices
void const * ptr_B, /// Pointer to B matrix in Global Memory
int ldb, /// Leading dimension of B matrix
void const * beta, /// Pointer to beta scalar
NumericTypeID element_C, /// Data type of C and D matrices
void const * ptr_C, /// Pointer to C matrix
int ldc, /// Leading dimension of C matrix
void * ptr_D, /// Pointer to D matrix
int ldd /// Leading dimension of D matrix
);
/// Planar complex GEMM
///
/// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel.
///
Status gemm_planar_complex(
int M, /// GEMM M dimension
int N, /// GEMM N dimension
int K, /// GEMM K dimension
NumericTypeID element_compute, /// Data type of internal accumulation
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
void const *alpha, /// Pointer to alpha scalar
NumericTypeID element_A, /// Data type of A matrix elements
LayoutTypeID layout_A, /// Layout of A matrix
ComplexTransform transform_A, /// Complex transformation applied to A matrix
void const * ptr_A_real, /// Pointer to real part of A matrix
void const * ptr_A_imag, /// Pointer to imaginary part of A matrix
int lda_real, /// Leading dimension of real part of A matrix
int lda_imag, /// Leading dimension of imaginary part of A matrix
NumericTypeID element_B, /// Data type of B matrix elements
LayoutTypeID layout_B, /// Layout of B matrix
ComplexTransform transform_B, /// Complex transformation applied to B matrix
void const * ptr_B_real, /// Pointer to real part of B matrix
void const * ptr_B_imag, /// Pointer to imaginary part of B matrix
int ldb_real, /// Leading dimension of real part of B matrix
int ldb_imag, /// Leading dimension of imaginary part of B matrix
void const * beta, /// Pointer to beta scalar
NumericTypeID element_C, /// Data type of C and D matrix
void const * ptr_C_real, /// Pointer to real part of C matrix
void const * ptr_C_imag, /// Pointer to imaginary part of C matrix
int ldc_real, /// Leading dimension of real part of C matrix
int ldc_imag, /// Leading dimension of imaginary part of C matrix
void * ptr_D_real, /// Pointer to real part of D matrix
void * ptr_D_imag, /// Pointer to imaginary part of D matrix
int ldd_real, /// Leading dimension of real part of D matrix
int ldd_imag, /// Leading dimension of imaginary part of D matrix
int batch_count = 1, /// Number of batched GEMMs to execute
int64_t batch_stride_A_real = 0, /// Stride between batches of real part of A matrix
int64_t batch_stride_A_imag = 0, /// Stride between batches of imaginary part of A matrix
int64_t batch_stride_B_real = 0, /// Stride between batches of real part of B matrix
int64_t batch_stride_B_imag = 0, /// Stride between batches of imaginary part of B matrix
int64_t batch_stride_C_real = 0, /// Stride between batches of real part of C matrix
int64_t batch_stride_C_imag = 0, /// Stride between batches of imaginary part of C matrix
int64_t batch_stride_D_real = 0, /// Stride between batches of real part of D matrix
int64_t batch_stride_D_imag = 0 /// Stride between batches of imaginary part of D matrix
);
/// Planar complex GEMM loading pointers from arrays in global memory
Status gemm_planar_complex_array(
int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid)
int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid)
int expected_K, /// Expected GEMM K dimension
int batch_count, /// Number of independent GEMM computations to execute
int const *M, /// Array containing the GEMM M dimension for each batch index
int const *N, /// Array containing the GEMM N dimension for each batch index
int const *K, /// Array containing the GEMM K dimension for each batch index
NumericTypeID element_compute, /// Data type of internal accumulation
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
void const *alpha, /// Pointer to alpha scalar
NumericTypeID element_A, /// Data type of A matrix elements
LayoutTypeID layout_A, /// Layout of A matrix
ComplexTransform transform_A, /// Complex transformation applied to A matrix
void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices
void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices
int lda_real, /// Leading dimension of real part of A matrix
int lda_imag, /// Leading dimension of imaginary part of A matrix
NumericTypeID element_B, /// Data type of B matrix elements
LayoutTypeID layout_B, /// Layout of B matrix
ComplexTransform transform_B, /// Complex transformation applied to B matrix
void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices
void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices
int ldb_real, /// Leading dimension of real part of B matrix
int ldb_imag, /// Leading dimension of imaginary part of B matrix
void const * beta, /// Pointer to beta scalar
NumericTypeID element_C, /// Data type of C and D matrix
void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices
void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices
int ldc_real, /// Leading dimension of real part of C matrix
int ldc_imag, /// Leading dimension of imaginary part of C matrix
void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices
void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices
int ldd_real, /// Leading dimension of real part of D matrix
int ldd_imag /// Leading dimension of imaginary part of D matrix
);
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Unique pointer storing the handle
using HandlePtr = std::unique_ptr<Handle>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -68,6 +68,10 @@ enum class LayoutTypeID {
kRowMajorInterleavedK4,
kColumnMajorInterleavedK16,
kRowMajorInterleavedK16,
kColumnMajorInterleavedK32,
kRowMajorInterleavedK32,
kColumnMajorInterleavedK64,
kRowMajorInterleavedK64,
kTensorNCHW,
kTensorNHWC,
kInvalid
@ -110,9 +114,21 @@ enum class NumericTypeID {
/// Enumerated type describing a transformation on a complex value.
enum class ComplexTransform {
kNone,
kConjugate
kConjugate,
kInvalid
};
/// Providers
enum class Provider {
kCUTLASS,
kReferenceHost,
kReferenceDevice,
kCUBLAS,
kInvalid
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Enumeration indicating the kind of operation
enum class OperationKind {
kGemm,
@ -143,6 +159,14 @@ enum class OpcodeClassID {
kInvalid
};
enum class MathOperationID {
kMultiplyAdd,
kMultiplyAddSaturate,
kMultiplyAddComplex,
kXorPopc,
kInvalid
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Enumeration indicating what kind of GEMM operation to perform
@ -150,88 +174,20 @@ enum class GemmKind {
kGemm,
kBatched,
kArray,
kUniversal,
kPlanarComplex,
kPlanarComplexBatched,
kPlanarComplexArray,
kInvalid
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Lexical cast from string
template <typename T> T from_string(std::string const &);
/// Converts a NumericType enumerant to a string
char const *to_string(OperationKind type, bool pretty = false);
/// Parses a NumericType enumerant from a string
template <> OperationKind from_string<OperationKind>(std::string const &str);
/// Converts a NumericType enumerant to a string
char const *to_string(NumericTypeID type, bool pretty = false);
/// Parses a NumericType enumerant from a string
template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);
/// Returns the size of a data type in bits
int sizeof_bits(NumericTypeID type);
/// Returns true if the numeric type is a complex data type or false if real-valued.
bool is_complex_type(NumericTypeID type);
/// Returns the real-valued type underlying a type (only different from 'type' if complex)
NumericTypeID get_real_type(NumericTypeID type);
/// Returns true if numeric type is integer
bool is_integer_type(NumericTypeID type);
/// Returns true if numeric type is signed
bool is_signed_type(NumericTypeID type);
/// Returns true if numeric type is a signed integer
bool is_signed_integer(NumericTypeID type);
/// returns true if numeric type is an unsigned integer
bool is_unsigned_integer(NumericTypeID type);
/// Returns true if numeric type is floating-point type
bool is_float_type(NumericTypeID type);
/// To string method for cutlass::Status
char const *to_string(Status status, bool pretty = false);
/// Converts a LayoutTypeID enumerant to a string
char const *to_string(LayoutTypeID layout, bool pretty = false);
/// Parses a LayoutType enumerant from a string
template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);
/// Returns the rank of a layout's stride base on the LayoutTypeID
int get_layout_stride_rank(LayoutTypeID layout_id);
/// Converts a OpcodeClassID enumerant to a string
char const *to_string(OpcodeClassID type, bool pretty = false);
/// Converts a OpcodeClassID enumerant from a string
template <>
OpcodeClassID from_string<OpcodeClassID>(std::string const &str);
/// Lexical cast from int64_t to string
std::string lexical_cast(int64_t int_value);
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);
/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid.
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);
/// Casts from a signed int64 to the destination type. Returns true if successful.
bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);
/// Casts from an unsigned int64 to the destination type. Returns true if successful.
bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);
/// Casts from a real value represented as a double to the destination type. Returns true if successful.
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
/// Mode of GEMM
enum class GemmUniversalMode {
kGemm,
kGemmSplitKParallel,
kBatched,
kArray,
kInvalid
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -246,6 +202,9 @@ struct MathInstructionDescription {
/// Classification of math instruction
OpcodeClassID opcode_class;
/// Type of math operation performed
MathOperationID math_operation;
//
// Methods
//
@ -253,9 +212,13 @@ struct MathInstructionDescription {
MathInstructionDescription(
cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(),
NumericTypeID element_accumulator = NumericTypeID::kInvalid,
OpcodeClassID opcode_class = OpcodeClassID::kInvalid
OpcodeClassID opcode_class = OpcodeClassID::kInvalid,
MathOperationID math_operation = MathOperationID::kMultiplyAdd
):
instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}
instruction_shape(instruction_shape),
element_accumulator(element_accumulator),
opcode_class(opcode_class),
math_operation(math_operation) {}
};
@ -306,6 +269,9 @@ struct OperationDescription {
/// Unique identifier describing the operation
char const * name;
/// Operation provider
Provider provider;
/// Kind of operation
OperationKind kind;
@ -317,6 +283,7 @@ struct OperationDescription {
//
OperationDescription(
char const * name = "unknown",
Provider Provider = Provider::kInvalid,
OperationKind kind = OperationKind::kInvalid,
TileDescription const & tile_description = TileDescription()
):
@ -340,10 +307,11 @@ struct TensorDescription {
/// log2() of the maximum value each relevant stride may have
int log_stride_range;
//
// Methods
//
TensorDescription(
NumericTypeID element = NumericTypeID::kInvalid,
LayoutTypeID layout = LayoutTypeID::kInvalid,
@ -355,7 +323,7 @@ struct TensorDescription {
layout(layout),
alignment(alignment),
log_extent_range(log_extent_range),
log_stride_range(log_stride_range) { }
log_stride_range(log_stride_range) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -414,7 +382,7 @@ struct GemmDescription : public OperationDescription {
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Base class for all device-wide operations
/// Base class for all operations
class Operation {
public:
@ -435,7 +403,7 @@ public:
virtual Status initialize(
void const *configuration,
void *host_workspace,
void *device_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const = 0;
virtual Status run(
@ -443,6 +411,7 @@ public:
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const = 0;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -551,11 +520,18 @@ using GemmBatchedArguments = GemmArguments;
struct GemmArrayConfiguration {
gemm::GemmCoord problem_size;
/// Leading dimension of A matrix
int64_t lda;
int64_t const *lda;
int64_t const *ldb;
int64_t const *ldc;
int64_t const *ldd;
/// Leading dimension of B matrix
int64_t ldb;
/// Leading dimension of C matrix
int64_t ldc;
/// Leading dimension of D matrix
int64_t ldd;
int batch_count;
};
@ -580,49 +556,98 @@ struct GemmArrayArguments {
struct GemmPlanarComplexConfiguration {
GemmUniversalMode mode;
gemm::GemmCoord problem_size;
int batch_count;
int64_t lda;
int64_t ldb;
int64_t ldc;
int64_t ldd;
int64_t lda_real;
int64_t lda_imag;
int64_t imag_stride_A;
int64_t imag_stride_B;
int64_t imag_stride_C;
int64_t imag_stride_D;
int64_t ldb_real;
int64_t ldb_imag;
int64_t ldc_real;
int64_t ldc_imag;
int64_t ldd_real;
int64_t ldd_imag;
};
using GemmPlanarComplexArgments = GemmArguments;
/// Arguments for planar complex GEMMs
struct GemmPlanarComplexArguments {
void const *A_real;
void const *A_imag;
void const *B_real;
void const *B_imag;
void const *C_real;
void const *C_imag;
void *D_real;
void *D_imag;
void const *alpha;
void const *beta;
ScalarPointerMode pointer_mode;
int64_t batch_stride_A_real;
int64_t batch_stride_A_imag;
int64_t batch_stride_B_real;
int64_t batch_stride_B_imag;
int64_t batch_stride_C_real;
int64_t batch_stride_C_imag;
int64_t batch_stride_D_real;
int64_t batch_stride_D_imag;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Batched complex valued GEMM in which real and imaginary parts are separated by a stride
//
// OperationKind: Gemm
// GemmKind: Planar complex batched
//
struct GemmPlanarComplexBatchedConfiguration {
/// This is a special form of planar complex which loads pointers and problem size
/// from memory.
struct GemmPlanarComplexArrayConfiguration {
gemm::GemmCoord problem_size;
int batch_count;
int64_t lda;
int64_t ldb;
int64_t ldc;
int64_t ldd;
int64_t lda_real;
int64_t lda_imag;
int64_t imag_stride_A;
int64_t imag_stride_B;
int64_t imag_stride_C;
int64_t imag_stride_D;
int64_t ldb_real;
int64_t ldb_imag;
int64_t batched_stride_A;
int64_t batched_stride_B;
int64_t batched_stride_C;
int64_t batched_stride_D;
int64_t ldc_real;
int64_t ldc_imag;
int64_t ldd_real;
int64_t ldd_imag;
};
/// Arguments for planar complex GEMMs
struct GemmPlanarComplexArrayArguments {
int const *M;
int const *N;
int const *K;
void const * const * A_real;
void const * const * A_imag;
void const * const * B_real;
void const * const * B_imag;
void const * const * C_real;
void const * const * C_imag;
void * const * D_real;
void * const * D_imag;
void const * alpha;
void const * beta;
ScalarPointerMode pointer_mode;
};
using GemmPlanarComplexBatchedArguments = GemmArguments;
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -55,10 +55,14 @@ using OperationVector = std::vector<std::unique_ptr<Operation>>;
class Manifest {
private:
/// Operation provider
Provider provider_;
/// Global list of operations
OperationVector operations_;
public:
Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { }
/// Top-level initialization
Status initialize();

View File

@ -0,0 +1,205 @@
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
\file
\brief Defines a data structure in which a set of functionally equivalent library::Operation
instances may be queried.
*/
#pragma once
#include <iosfwd>
#include <algorithm>
#include <functional>
#include <map>
#include <unordered_map>
#include <vector>
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tuple uniquely identifying functional behavior
struct GemmFunctionalKey {

  NumericTypeID element_compute;
  NumericTypeID element_scalar;
  NumericTypeID element_A;
  LayoutTypeID layout_A;
  ComplexTransform transform_A;
  NumericTypeID element_B;
  LayoutTypeID layout_B;
  ComplexTransform transform_B;
  NumericTypeID element_C;

  //
  // Methods
  //

  /// Constructs a key; defaults describe a column-major F16 GEMM accumulated in F32.
  inline
  GemmFunctionalKey(
    NumericTypeID element_compute = NumericTypeID::kF32,
    NumericTypeID element_scalar = NumericTypeID::kF32,
    NumericTypeID element_A = NumericTypeID::kF16,
    LayoutTypeID layout_A = LayoutTypeID::kColumnMajor,
    ComplexTransform transform_A = ComplexTransform::kNone,
    NumericTypeID element_B = NumericTypeID::kF16,
    LayoutTypeID layout_B = LayoutTypeID::kColumnMajor,
    ComplexTransform transform_B = ComplexTransform::kNone,
    NumericTypeID element_C = NumericTypeID::kF16
  ):
    element_compute(element_compute),
    element_scalar(element_scalar),
    element_A(element_A),
    layout_A(layout_A),
    transform_A(transform_A),
    element_B(element_B),
    layout_B(layout_B),
    transform_B(transform_B),
    element_C(element_C) { }

  /// Field-wise equality with early exit on the first mismatch
  inline
  bool operator==(GemmFunctionalKey const &rhs) const {
    if (element_compute != rhs.element_compute) return false;
    if (element_scalar != rhs.element_scalar) return false;
    if (element_A != rhs.element_A) return false;
    if (layout_A != rhs.layout_A) return false;
    if (transform_A != rhs.transform_A) return false;
    if (element_B != rhs.element_B) return false;
    if (layout_B != rhs.layout_B) return false;
    if (transform_B != rhs.transform_B) return false;
    return element_C == rhs.element_C;
  }

  /// Inequality is the negation of equality
  inline
  bool operator!=(GemmFunctionalKey const &rhs) const {
    return !(*this == rhs);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Hash function for GemmFunctionalKey
struct GemmFunctionalKeyHasher {
  using IntHash = std::hash<int>;

  /// Left-rotates 'key' by 'shl' bits; 'shl' must lie strictly between 0 and the bit width
  inline
  static size_t rotl(size_t key, int shl) {
    return (key << shl) | (key >> (sizeof(key) * 8 - shl));
  }

  /// Combines the per-field hashes, XOR-ing each under a distinct rotation
  /// so that permuted field values produce different digests
  inline
  size_t operator()(GemmFunctionalKey const &key) const {
    IntHash hash;

    int const fields[] = {
      int(key.element_compute),
      int(key.element_scalar),
      int(key.element_A),
      int(key.layout_A),
      int(key.transform_A),
      int(key.element_B),
      int(key.layout_B),
      int(key.transform_B),
      int(key.element_C)
    };

    size_t digest = 0;
    int shift = 2;  // rotations 2..10 match one field each
    for (int field : fields) {
      digest ^= rotl(hash(field), shift++);
    }

    return digest;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Establishes a partial ordering to search for GEMM operators
struct GemmPreferenceKey {

  int compute_capability;
  int alignment;

  //
  // Methods
  //

  GemmPreferenceKey(): compute_capability(), alignment() { }

  GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { }

  /// Orders first by compute capability, then by alignment
  bool operator<(GemmPreferenceKey const &rhs) const {
    if (compute_capability != rhs.compute_capability) {
      return compute_capability < rhs.compute_capability;
    }
    return alignment < rhs.alignment;
  }

  /// NOTE(review): equality considers only compute capability and ignores
  /// alignment, which is asymmetric with operator< — presumably intentional
  /// for matching a compute-capability tier; confirm against callers.
  bool operator==(GemmPreferenceKey const &rhs) const {
    return compute_capability == rhs.compute_capability;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Maps minimum compute capability onto a vector of possible operations
using GemmOperationVectorMap = std::map<
GemmPreferenceKey,
std::vector<Operation const *>
>;
/// Maps a GemmFunctionalKey onto a preference-ordered GemmOperationVectorMap of Operation pointers expected to be of kind kGemm
using GemmOperationFunctionalMap = std::unordered_map<
GemmFunctionalKey,
GemmOperationVectorMap,
GemmFunctionalKeyHasher
>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Table of cutlass::library::Operation instances
class OperationTable {
public:
/// Map of all operations of type kGemm and gemm_kind of type kGemm
GemmOperationFunctionalMap gemm_operations;
/// Map of all operations of type kGemm and gemm_kind of type kPlanarComplex
GemmOperationFunctionalMap gemm_planar_complex_operations;
/// Map of all operations of type kGemm and gemm_kind of type kPlanarComplexArray
GemmOperationFunctionalMap gemm_planar_complex_array_operations;
public:
/// Inserts the operations contained in 'manifest' into the maps above
void append(Manifest const &manifest);
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k);

View File

@ -0,0 +1,62 @@
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/operation_table.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Singleton instance stores a Manifest and Operation table
class Singleton {
public:

  /// Manifest object holding the registered operations
  Manifest manifest;

  /// Operation table referencing the Manifest
  OperationTable operation_table;

public:

  /// Constructor — definition lives in singleton.cu; presumably initializes the
  /// manifest and builds operation_table from it (confirm against the .cu file)
  Singleton();

  /// Returns a const reference to the process-wide singleton instance
  static Singleton const &get();
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,138 @@
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Utilities accompanying the CUTLASS library for interacting with Library types.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Lexical cast from string
template <typename T> T from_string(std::string const &);

/// Converts a Provider enumerant to a string
char const *to_string(Provider provider, bool pretty = false);

/// Parses a Provider enumerant from a string
template <> Provider from_string<Provider>(std::string const &str);

/// Converts an OperationKind enumerant to a string
char const *to_string(OperationKind type, bool pretty = false);

/// Parses an OperationKind enumerant from a string
template <> OperationKind from_string<OperationKind>(std::string const &str);

/// Converts a NumericTypeID enumerant to a string
char const *to_string(NumericTypeID type, bool pretty = false);

/// Parses a NumericTypeID enumerant from a string
template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);

/// Returns the size of a data type in bits
int sizeof_bits(NumericTypeID type);

/// Returns true if the numeric type is a complex data type or false if real-valued.
bool is_complex_type(NumericTypeID type);

/// Returns the real-valued type underlying a type (only different from 'type' if complex)
NumericTypeID get_real_type(NumericTypeID type);

/// Returns true if numeric type is integer
bool is_integer_type(NumericTypeID type);

/// Returns true if numeric type is signed
bool is_signed_type(NumericTypeID type);

/// Returns true if numeric type is a signed integer
bool is_signed_integer(NumericTypeID type);

/// Returns true if numeric type is an unsigned integer
bool is_unsigned_integer(NumericTypeID type);

/// Returns true if numeric type is a floating-point type
bool is_float_type(NumericTypeID type);

/// Converts a cutlass::Status enumerant to a string
char const *to_string(Status status, bool pretty = false);

/// Converts a LayoutTypeID enumerant to a string
char const *to_string(LayoutTypeID layout, bool pretty = false);

/// Parses a LayoutTypeID enumerant from a string
template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);

/// Returns the rank of a layout's stride based on the LayoutTypeID
int get_layout_stride_rank(LayoutTypeID layout_id);

/// Converts an OpcodeClassID enumerant to a string
char const *to_string(OpcodeClassID type, bool pretty = false);

/// Parses an OpcodeClassID enumerant from a string
template <>
OpcodeClassID from_string<OpcodeClassID>(std::string const &str);

/// Converts a ComplexTransform enumerant to a string
char const *to_string(ComplexTransform type, bool pretty = false);

/// Parses a ComplexTransform enumerant from a string
template <>
ComplexTransform from_string<ComplexTransform>(std::string const &str);

/// Lexical cast from int64_t to string
std::string lexical_cast(int64_t int_value);

/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);

/// Lexical cast TO a string FROM a byte array.
/// NOTE(review): 'bytes' is taken by non-const reference — confirm whether mutation is intended.
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);

/// Casts from a signed int64 to the destination type. Returns true if successful.
bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);

/// Casts from an unsigned int64 to the destination type. Returns true if successful.
bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);

/// Casts from a real value represented as a double to the destination type. Returns true if successful.
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -22,7 +22,9 @@ from library import *
#
class GemmOperation:
#
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue):
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Cohort):
self.operation_kind = OperationKind.Gemm
self.arch = arch
self.tile_description = tile_description
@ -31,29 +33,75 @@ class GemmOperation:
self.B = B
self.C = C
self.element_epilogue = element_epilogue
self.epilogue_functor = epilogue_functor
self.swizzling_functor = swizzling_functor
#
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
]
return self.tile_description.math_instruction.math_operation in complex_operators
#
def is_planar_complex(self):
return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
#
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator
if self.is_complex():
return get_complex_from_real(accum)
return accum
#
def short_math_name(self):
return ShortDataTypeNames[self.accumulator_type()]
#
def core_name(self):
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
inst_shape = ''
inst_operation = ''
intermediate_type = ''
math_operations_map = {
MathOperation.xor_popc: 'xor',
}
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
else:
inst_shape = ''
return "%s%s%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], inst_shape, GemmKindNames[self.gemm_kind])
math_op = self.tile_description.math_instruction.math_operation
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
inst_shape += math_op_string
if self.tile_description.math_instruction.element_a != self.A.element and \
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
#
def extended_name(self):
''' Append data types if they differ from compute type. '''
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
else:
if self.is_complex():
extended_name = "${core_name}"
else:
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
else:
extended_name = "${core_name}"
extended_name = SubstituteTemplate(extended_name, {
'element_a': DataTypeNames[self.A.element],
@ -63,28 +111,32 @@ class GemmOperation:
return extended_name
#
def layout_name(self):
if self.is_complex() or self.is_planar_complex():
return "%s%s" % (
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
)
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
#
def procedural_name(self):
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
if self.tile_description.stages > 2:
threadblock = "%dx%d_%dx%d" % (
self.tile_description.threadblock_shape[0],
self.tile_description.threadblock_shape[1],
self.tile_description.threadblock_shape[2],
self.tile_description.stages
)
else:
threadblock = "%dx%d" % (self.tile_description.threadblock_shape[0], self.tile_description.threadblock_shape[1])
threadblock = self.tile_description.procedural_name()
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
return SubstituteTemplate(
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}",
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
{
'opcode_class': opcode_class_name,
'extended_name': self.extended_name(),
'threadblock': threadblock,
'layout': "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]),
'layout': self.layout_name(),
'alignment': "%d" % self.A.alignment,
}
)
@ -104,7 +156,7 @@ class EmitGemmInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self):
self.template = """
self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
${element_a}, ${layout_a},
@ -116,14 +168,45 @@ class EmitGemmInstance:
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
cutlass::epilogue::thread::LinearCombination<
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
${stages}
${swizzling_functor},
${stages},
${align_a},
${align_b},
false,
${math_operation}
${residual}
>;
"""
self.gemm_complex_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${transform_a},
${transform_b},
${math_operation}
${residual}
>;
"""
@ -135,6 +218,8 @@ class EmitGemmInstance:
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = ''
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
@ -143,7 +228,7 @@ class EmitGemmInstance:
'layout_b': LayoutTag[operation.B.layout],
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
@ -157,57 +242,72 @@ class EmitGemmInstance:
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'stages': str(operation.tile_description.stages)
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
'stages': str(operation.tile_description.stages),
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
'transform_a': ComplexTransformTag[operation.A.complex_transform],
'transform_b': ComplexTransformTag[operation.B.complex_transform],
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
'residual': residual
}
return SubstituteTemplate(self.template, values)
template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
return SubstituteTemplate(template, values)
###################################################################################################
#
class EmitGemmBatchedInstance:
class EmitGemmPlanarComplexInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self):
self.template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::GemmBatched<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
${element_c}, cutlass::layout::RowMajor,
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
cutlass::epilogue::thread::LinearCombination<
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
${element_c},
${epilogue_vector_length},
${alignment_c},
${element_accumulator},
${element_epilogue}
>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
${stages},
${align_a},
${align_b}
>;
${math_operator}
>::GemmKernel;
struct ${operation_name} : public Operation_${operation_name} { };
"""
def emit(self, operation):
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
#warp_shape[2] = operation.tile_description.math_instruction.instruction_shape[2]
warp_shape[2] = operation.tile_description.threadblock_shape[2]
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
transposed_layout_A = TransposedLayout[operation.A.layout]
transposed_layout_B = TransposedLayout[operation.B.layout]
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
'layout_a': LayoutTag[operation.A.layout],
'element_b': DataTypeTag[operation.B.element],
'layout_b': LayoutTag[operation.B.layout],
'element_a': DataTypeTag[operation.B.element],
'layout_a': LayoutTag[transposed_layout_B],
'transform_a': ComplexTransformTag[operation.B.complex_transform],
'alignment_a': str(operation.B.alignment),
'element_b': DataTypeTag[operation.A.element],
'layout_b': LayoutTag[transposed_layout_A],
'transform_b': ComplexTransformTag[operation.A.complex_transform],
'alignment_b': str(operation.A.alignment),
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
@ -222,139 +322,89 @@ class EmitGemmBatchedInstance:
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'epilogue_vector_length': str(epilogue_vector_length),
'alignment_c': str(operation.C.alignment),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'stages': str(operation.tile_description.stages),
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
'math_operator': 'cutlass::arch::OpMultiplyAdd'
}
return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
#
class EmitGemmPlanarComplexArrayInstance:
''' Responsible for emitting a CUTLASS template definition'''
def __init__(self):
self.template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
${element_c}, cutlass::layout::RowMajor,
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
${element_c},
${alignment_c},
${element_accumulator},
${element_epilogue}
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
${stages},
${math_operator}
>::GemmArrayKernel;
struct ${operation_name} : public Operation_${operation_name} { };
"""
def emit(self, operation):
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
transposed_layout_A = TransposedLayout[operation.A.layout]
transposed_layout_B = TransposedLayout[operation.B.layout]
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.B.element],
'layout_a': LayoutTag[transposed_layout_B],
'transform_a': ComplexTransformTag[operation.B.complex_transform],
'alignment_a': str(operation.B.alignment),
'element_b': DataTypeTag[operation.A.element],
'layout_b': LayoutTag[transposed_layout_A],
'transform_b': ComplexTransformTag[operation.A.complex_transform],
'alignment_b': str(operation.A.alignment),
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[operation.C.layout],
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
'warp_shape_m': str(warp_shape[0]),
'warp_shape_n': str(warp_shape[1]),
'warp_shape_k': str(warp_shape[2]),
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
'alignment_c': str(operation.C.alignment),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'stages': str(operation.tile_description.stages),
'math_operator': 'cutlass::arch::OpMultiplyAdd'
}
return SubstituteTemplate(self.template, values)
###################################################################################################
#
def GenerateGemmSimt(gemm_kind, manifest, tile_descriptions, min_cc):
  """Emits a SIMT (CUDA-core) GEMM operation for each tile description and canonical layout.

  Args:
    gemm_kind: GemmKind enumerant applied to every emitted operation
    manifest: Manifest object receiving the GemmOperation instances
    tile_descriptions: iterable of tile description objects
    min_cc: minimum compute capability — NOTE(review): currently unused (see below)
  """
  # Canonical column/row-major layout combinations with column-major output
  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  # for each tile configuration, emit a GEMM
  for tile in tile_descriptions:
    for layout in layouts:
      # SIMT kernels use alignment 1 for all operands
      A = TensorDescription(tile.math_instruction.element_a, layout[0], 1)
      B = TensorDescription(tile.math_instruction.element_b, layout[1], 1)
      C = TensorDescription(tile.math_instruction.element_accumulator, layout[2], 1)

      # NOTE(review): compute capability is hard-coded to 50 and 'min_cc' is ignored — confirm intentional
      manifest.append(GemmOperation(gemm_kind, 50, tile, A, B, C, tile.math_instruction.element_accumulator))
#
def GenerateGemmTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]):
  """Emits Tensor Core GEMM operations for each alignment/tile/layout combination.

  Args:
    gemm_kind: GemmKind enumerant applied to every emitted operation
    manifest: Manifest object receiving the GemmOperation instances
    tile_descriptions: iterable of tile description objects
    min_cc: minimum compute capability passed through to each GemmOperation
    minimum_alignment: list of memory access sizes in bits to generate for
      (shared default list — never mutated here)
  """
  # Canonical matrix layouts
  canonical_layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  # Interleaved matrix layouts, keyed by the operand width in bits
  interleaved_layouts = {
    8: [
      #(LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32),
      (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ],
    4: [
      #(LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64),
      (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    ]
  }

  # for each tile configuration, emit a GEMM
  for align in minimum_alignment:
    for tile in tile_descriptions:

      # BUGFIX: previously compared element_a's size with itself; the narrower of the
      # two input operands must select between canonical and interleaved layouts.
      min_input_size = min(DataTypeSize[tile.math_instruction.element_a], DataTypeSize[tile.math_instruction.element_b])

      # If the data type is large enough, use canonical layouts.
      if min_input_size >= 16:
        layouts = canonical_layouts
      else:
        layouts = interleaved_layouts[min_input_size]

      for layout in layouts:

        # 32-bit accumulation may also be emitted with the (narrower) input type as output
        output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
          if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
          else [tile.math_instruction.element_accumulator,]

        # Operand alignments in elements, derived from the access size in bits
        align_a = align // DataTypeSize[tile.math_instruction.element_a]
        align_b = align // DataTypeSize[tile.math_instruction.element_b]

        for output_type in output_types:
          rows_per_warp = 8 // tile.warp_count[1]
          align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32)

          A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a)
          B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b)
          C = TensorDescription(output_type, layout[2], max(1, align_c))

          # Integer accumulators scale the epilogue in f32; otherwise use the accumulator type
          element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \
            else tile.math_instruction.element_accumulator

          manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue))
#
def GenerateGemmWmmaTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]):
  """Emits WMMA Tensor Core GEMM operations for each alignment/tile/layout combination.

  Args:
    gemm_kind: GemmKind enumerant applied to every emitted operation
    manifest: Manifest object receiving the GemmOperation instances
    tile_descriptions: iterable of tile description objects
    min_cc: minimum compute capability passed through to each GemmOperation
    minimum_alignment: list of memory access sizes in bits to generate for
      (shared default list — never mutated here)
  """
  # Wmma supported matrix layouts
  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
  ]

  # for each tile configuration, emit a GEMM
  for align in minimum_alignment:
    for tile in tile_descriptions:
      for layout in layouts:

        # 32-bit accumulation may also be emitted with the (narrower) input type as output
        output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
          if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
          else [tile.math_instruction.element_accumulator,]

        # Operand alignments in elements, derived from the access size in bits
        align_a = align // DataTypeSize[tile.math_instruction.element_a]
        align_b = align // DataTypeSize[tile.math_instruction.element_b]

        for output_type in output_types:
          rows_per_warp = 8 // tile.warp_count[1]
          align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32)

          A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a)
          B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b)
          C = TensorDescription(output_type, layout[2], max(1, align_c))

          # Integer accumulators scale the epilogue in f32; otherwise use the accumulator type
          element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \
            else tile.math_instruction.element_accumulator

          manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue))
###################################################################################################
#
@ -369,21 +419,40 @@ class EmitGemmConfigurationLibrary:
self.instance_emitter = {
GemmKind.Gemm: EmitGemmInstance,
GemmKind.Batched: EmitGemmBatchedInstance
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
}
self.gemm_kind_wrappers = {
GemmKind.Gemm: 'GemmOperation',
GemmKind.Batched: 'GemmBatchedOperation',
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
}
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
self.instance_template = """
self.instance_template = {
GemmKind.Gemm: """
${compile_guard_start}
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
""",
GemmKind.PlanarComplex: """
${compile_guard_start}
manifest.append(new ${gemm_kind}<
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
>("${operation_name}"));
${compile_guard_end}
""",
GemmKind.PlanarComplexArray: """
${compile_guard_start}
manifest.append(new ${gemm_kind}<
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
>("${operation_name}"));
${compile_guard_end}
"""
}
self.header_template = """
/*
Generated by gemm_operation.py - Do not edit.
@ -398,6 +467,14 @@ ${compile_guard_end}
#include "library_internal.h"
#include "gemm_operation.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
"""
self.initialize_function_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
@ -421,9 +498,11 @@ void initialize_${configuration_name}(Manifest &manifest) {
def __enter__(self):
self.configuration_file = open(self.configuration_path, "w")
self.configuration_file.write(SubstituteTemplate(self.header_template, {
'configuration_name': self.configuration_name
}))
self.configuration_file.write(self.header_template)
self.instance_definitions = []
self.instance_wrappers = []
self.operations = []
return self
@ -431,8 +510,10 @@ void initialize_${configuration_name}(Manifest &manifest) {
emitter = self.instance_emitter[operation.gemm_kind]()
self.operations.append(operation)
self.configuration_file.write(emitter.emit(operation))
self.configuration_file.write(SubstituteTemplate(self.instance_template, {
self.instance_definitions.append(emitter.emit(operation))
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
'configuration_name': self.configuration_name,
'operation_name': operation.procedural_name(),
'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
@ -443,6 +524,19 @@ void initialize_${configuration_name}(Manifest &manifest) {
}))
def __exit__(self, exception_type, exception_value, traceback):
  """
  Flushes all buffered emission state to the configuration file and closes it.
  Instance definitions go at file scope; wrapper registrations go inside the
  generated initialize_<configuration>() function.
  """
  # Emit instance definitions in the top-level namespace.
  self.configuration_file.writelines(self.instance_definitions)

  # Open the initialize() function for this configuration.
  self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
    'configuration_name': self.configuration_name
  }))

  # Register one wrapper object per operation inside initialize().
  self.configuration_file.writelines(self.instance_wrappers)

  self.configuration_file.write(self.epilogue_template)
  self.configuration_file.close()

File diff suppressed because it is too large Load Diff

View File

@ -153,6 +153,68 @@ DataTypeSize = {
###################################################################################################
#
# Transformation applied to an operand of a complex-valued GEMM
# (identity or conjugation); mirrors cutlass::ComplexTransform.
class ComplexTransform(enum.Enum):
none = enum.auto()
conj = enum.auto()
#
# Maps ComplexTransform enumerants to the C++ spellings emitted in generated code.
ComplexTransformTag = {
ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}
#
# (real type, complex type) pairs used by the is_complex / get_*_from_* helpers below.
RealComplexBijection = [
(DataType.f16, DataType.cf16),
(DataType.f32, DataType.cf32),
(DataType.f64, DataType.cf64),
]
#
def is_complex(data_type):
  """Returns True if data_type is one of the complex types in RealComplexBijection."""
  return any(data_type == complex_type for _, complex_type in RealComplexBijection)
#
def get_complex_from_real(real_type):
  """Maps a real type to its complex counterpart; DataType.invalid if no pairing exists."""
  return next(
    (complex_type for candidate, complex_type in RealComplexBijection if candidate == real_type),
    DataType.invalid)
#
def get_real_from_complex(complex_type):
  """Maps a complex type back to its real counterpart; DataType.invalid if no pairing exists."""
  return next(
    (real_type for real_type, candidate in RealComplexBijection if candidate == complex_type),
    DataType.invalid)
#
# Strategy for complex multiplication: conventional (4-real-multiply)
# multiply-add or Gaussian (3-real-multiply) reduction.
class ComplexMultiplyOp(enum.Enum):
multiply_add = enum.auto()
gaussian = enum.auto()
###################################################################################################
#
# Inner-product math operation performed by the MMA instruction.
class MathOperation(enum.Enum):
multiply_add = enum.auto()
multiply_add_saturate = enum.auto()
xor_popc = enum.auto()
multiply_add_complex = enum.auto()
#
# Maps MathOperation enumerants to the cutlass::arch operator tags emitted in generated code.
MathOperationTag = {
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
}
###################################################################################################
#
class LayoutType(enum.Enum):
ColumnMajor = enum.auto()
@ -182,6 +244,17 @@ LayoutTag = {
LayoutType.TensorNCxHW64: 'cutlass::layout::TensorNCxHW64'
}
#
# Maps each matrix layout to the layout of its transpose.
# TensorNHWC maps to itself (no matrix transpose applies).
TransposedLayout = {
LayoutType.ColumnMajor: LayoutType.RowMajor,
LayoutType.RowMajor: LayoutType.ColumnMajor,
LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
LayoutType.TensorNHWC: LayoutType.TensorNHWC
}
#
ShortLayoutTypeNames = {
LayoutType.ColumnMajor: 'n',
@ -197,6 +270,14 @@ ShortLayoutTypeNames = {
LayoutType.TensorNCxHW64: 'ncxhw64'
}
#
# BLAS-style one-letter codes for (layout, transform) pairs used in procedural
# kernel names: n = no-transpose, c = conjugate, t = transpose, h = conjugate-transpose.
ShortComplexLayoutNames = {
(LayoutType.ColumnMajor, ComplexTransform.none): 'n',
(LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
(LayoutType.RowMajor, ComplexTransform.none): 't',
(LayoutType.RowMajor, ComplexTransform.conj): 'h'
}
###################################################################################################
#
@ -244,9 +325,15 @@ ArchitectureNames = {
#
def SubstituteTemplate(template, values):
  """
  Substitutes ${key} placeholders in `template` with the strings in `values`.

  Iterates to a fixed point so that a substituted value may itself contain
  ${...} placeholders that refer to other keys. (The stale single-pass loop
  left over from diff concatenation is removed; only the fixed-point loop
  remains.) Caution: a value that expands to its own placeholder would loop
  forever — values must not be self-referential.
  """
  text = template
  changed = True
  while changed:
    changed = False
    for key, value in values.items():
      # Escape ${ } so only the literal placeholder matches.
      regex = "\\$\\{%s\\}" % key
      newtext = re.sub(regex, value, text)
      if newtext != text:
        changed = True
        text = newtext
  return text
###################################################################################################
@ -256,28 +343,52 @@ class GemmKind(enum.Enum):
Gemm = enum.auto()
Batched = enum.auto()
Array = enum.auto()
Universal = enum.auto()
PlanarComplex = enum.auto()
PlanarComplexBatched = enum.auto()
PlanarComplexArray = enum.auto()
#
# Short names for each GemmKind, used to build procedural kernel names.
GemmKindNames = {
GemmKind.Gemm: "gemm",
GemmKind.Batched: "gemm_batched",
GemmKind.Array: "gemm_array",
GemmKind.Universal: "gemm_universal",
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexBatched: "gemm_planar_complex_batched",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
}
#
# Epilogue functor applied to the accumulator (plain or clamped linear combination).
class EpilogueFunctor(enum.Enum):
LinearCombination = enum.auto()
LinearCombinationClamp = enum.auto()
#
# Maps EpilogueFunctor enumerants to the C++ epilogue thread-functor types.
EpilogueFunctorTag = {
EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}
#
# Threadblock swizzling strategy used to map threadblocks onto the output tile grid.
class SwizzlingFunctor(enum.Enum):
Cohort = enum.auto()
Identity = enum.auto()
#
# Maps SwizzlingFunctor enumerants to C++ types. Note the Cohort entry is itself
# a template that still contains ${layout_a}/${layout_b} placeholders, resolved
# by a later SubstituteTemplate pass.
SwizzlingFunctorTag = {
SwizzlingFunctor.Cohort: 'cutlass::gemm::threadblock::GemmCohortThreadblockSwizzle<${layout_a}, ${layout_b}>',
SwizzlingFunctor.Identity: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle',
}
###################################################################################################
#
class MathInstruction:
  """
  Describes a single math (MMA) instruction: its shape, operand/accumulator
  element types, opcode class, and the math operation it performs.

  math_operation defaults to plain multiply-add so existing call sites that
  pass five positional arguments remain valid. (The stale pre-change
  __init__ signature, duplicated by diff concatenation, is removed.)
  """
  def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add):
    self.instruction_shape = instruction_shape
    self.element_a = element_a
    self.element_b = element_b
    self.element_accumulator = element_accumulator
    self.opcode_class = opcode_class
    self.math_operation = math_operation
#
@ -292,16 +403,14 @@ class TileDescription:
self.maximum_compute_capability = max_compute
def procedural_name(self):
  """
  Returns the tile's procedural name: MxN_KxStages (e.g. "128x256_32x3").
  The stage count is always included now; the stale stages==2 special case
  and the duplicated return line left by diff concatenation are removed.
  """
  return "%dx%d_%dx%d" % (
    self.threadblock_shape[0],
    self.threadblock_shape[1],
    self.threadblock_shape[2],
    self.stages)
#
class TensorDescription:
  """
  Describes one GEMM operand tensor: element type, layout, memory alignment
  (in elements), and the complex transform applied to it.

  complex_transform defaults to ComplexTransform.none so real-valued call
  sites are unaffected. (The stale pre-change __init__ signature, duplicated
  by diff concatenation, is removed.)
  """
  def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.alignment = alignment
    self.complex_transform = complex_transform
###################################################################################################

View File

@ -114,6 +114,16 @@ class Manifest:
self.args = args
self.compute_capabilities = [int(x) for x in args.architectures.split(';')]
if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.kernels == 'all':
self.kernel_names = []
else:
@ -142,6 +152,16 @@ void initialize_all(Manifest &manifest) {
} // namespace cutlass
'''
#
def _filter_string_matches(self, filter_string, haystack):
''' Returns true if all substrings appear in the haystack in order'''
substrings = filter_string.split('*')
for sub in substrings:
idx = haystack.find(sub)
if idx < 0:
return False
haystack = haystack[idx + len(sub):]
return True
#
def filter(self, operation):
@ -159,6 +179,9 @@ void initialize_all(Manifest &manifest) {
if not enabled:
return False
if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
return False
# eliminate duplicates
if operation.procedural_name() in self.operations_by_name.keys():
return False
@ -168,11 +191,10 @@ void initialize_all(Manifest &manifest) {
name = operation.procedural_name()
enabled = False
for name_substr in self.kernel_names:
if name_substr in name:
if self._filter_string_matches(name_substr, name):
enabled = True
break
# todo: filter based on operation kind
# todo: filter based on compute data type
return enabled
#
@ -255,10 +277,11 @@ void initialize_all(Manifest &manifest) {
manifest_path = os.path.join(generated_path, "manifest.cmake")
with open(manifest_path, "w") as manifest_file:
target_name = 'cutlass_lib'
target_name = 'cutlass_library_objs'
target_text = SubstituteTemplate("""cutlass_target_sources(
${target_name}
BATCH_SOURCES ON
PRIVATE
""", { 'target_name': target_name})

View File

@ -29,8 +29,13 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/gemm/device/gemm_batched.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/library/library.h"
#include "library_internal.h"
@ -68,8 +73,10 @@ public:
GemmOperationBase(char const *name = "unknown_gemm") {
description_.name = name;
description_.provider = Provider::kCUTLASS;
description_.kind = OperationKind::kGemm;
description_.gemm_kind = GemmKind::kGemm;
description_.tile_description.threadblock_shape = make_Coord(
Operator::ThreadblockShape::kM,
Operator::ThreadblockShape::kN,
@ -93,22 +100,23 @@ public:
description_.tile_description.math_instruction.opcode_class =
OpcodeClassMap<typename Operator::OperatorClass>::kId;
description_.tile_description.math_instruction.math_operation =
MathOperationMap<typename Operator::Operator>::kId;
description_.tile_description.minimum_compute_capability =
ArchMap<typename Operator::ArchTag>::kMin;
description_.tile_description.maximum_compute_capability =
ArchMap<typename Operator::ArchTag>::kMax;
description_.gemm_kind = GemmKind::kGemm;
description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.split_k_mode = Operator::kSplitKSerial ? SplitKMode::kSerial : SplitKMode::kNone;
description_.transform_A = ComplexTransform::kNone;
description_.transform_B = ComplexTransform::kNone;
description_.split_k_mode = SplitKMode::kNone;
description_.transform_A = ComplexTransformMap<Operator::kTransformA>::kId;
description_.transform_B = ComplexTransformMap<Operator::kTransformB>::kId;
}
/// Returns the description of the GEMM operation
@ -294,8 +302,24 @@ public:
return op->run(stream);
}
};
// Debug aid: dumps the operator arguments to stdout. The body is compiled
// out via '#if 0'; flip it to '#if 1' locally when diagnosing argument
// construction. Takes a non-const reference only because the accessors it
// calls are shown non-const here; it performs no mutation.
void print_operator_args(OperatorArguments &operator_args) const {
#if 0
std::cout << "GemmOperation::OperatorArguments" << std::endl;
std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl;
std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl;
std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl;
std::cout << " beta: " << operator_args.epilogue.beta << std::endl;
std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl;
std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl;
std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl;
std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl;
std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl;
std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl;
std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl;
#endif
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -360,6 +384,7 @@ protected:
*static_cast<ElementCompute const *>(arguments->alpha),
*static_cast<ElementCompute const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
@ -491,6 +516,593 @@ public:
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Library wrapper exposing a batched-array GEMM device operator through the
// type-erased Operation interface. Host workspace holds the placement-new'd
// Operator; device workspace is sized by the operator itself.
template <typename Operator_>
class GemmArrayOperation : public GemmOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
protected:
// NOTE(review): this member shadows GemmOperationBase<Operator_>::description_.
// description() below returns this derived copy, whose only populated field is
// gemm_kind — the fields filled in by the base-class constructor live in the
// shadowed base member. Verify the shadowing is intentional.
GemmDescription description_;
public:
/// Constructor
GemmArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase<Operator_>(name) {
description_.gemm_kind = GemmKind::kArray;
}
protected:
/// Constructs the arguments structure given the configuration and arguments
// Only problem size and batch count come from the configuration.
static Status construct_arguments_(
OperatorArguments &operator_args,
GemmArrayConfiguration const *configuration) {
operator_args.problem_size = configuration->problem_size;
operator_args.batch_count = configuration->batch_count;
return Status::kSuccess;
}
/// Constructs the arguments structure given the configuration and arguments
// Scalars are dereferenced on host or passed through as device pointers,
// depending on the caller's pointer mode; any other mode is rejected.
static Status update_arguments_(
OperatorArguments &operator_args,
GemmArrayArguments const *arguments) {
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
typename Operator::EpilogueOutputOp::Params params(
*static_cast<ElementCompute const *>(arguments->alpha),
*static_cast<ElementCompute const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
typename Operator::EpilogueOutputOp::Params params(
static_cast<ElementCompute const *>(arguments->alpha),
static_cast<ElementCompute const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
public:
/// Returns the description of the GEMM operation
virtual OperationDescription const & description() const {
return description_;
}
/// Returns success if the operation can proceed
// Builds full arguments from both config and runtime arguments, then defers
// to the underlying operator's can_implement().
virtual Status can_implement(
void const *configuration_ptr,
void const *arguments_ptr) const {
GemmArrayConfiguration const *configuration =
static_cast<GemmArrayConfiguration const *>(configuration_ptr);
GemmArrayArguments const *arguments =
static_cast<GemmArrayArguments const *>(arguments_ptr);
OperatorArguments args;
Status status = construct_arguments_(args, configuration);
if (status != Status::kSuccess) {
return status;
}
status = update_arguments_(args, arguments);
if (status != Status::kSuccess) {
return status;
}
return Operator::can_implement(args);
}
/// Gets the host-side workspace
// Host workspace stores exactly one Operator instance (see initialize()).
virtual uint64_t get_host_workspace_size(
void const *configuration) const {
return sizeof(Operator);
}
/// Gets the device-side workspace
// Returns 0 (no workspace) if argument construction fails.
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
OperatorArguments args;
Status status = construct_arguments_(
args,
static_cast<GemmArrayConfiguration const *>(configuration_ptr));
if (status != Status::kSuccess) {
return 0;
}
return Operator::get_workspace_size(args);
}
/// Initializes the workspace
// Placement-news the Operator into host_workspace; the caller owns both
// workspaces and their lifetimes.
virtual Status initialize(
void const *configuration_ptr,
void *host_workspace,
void *device_workspace,
cudaStream_t stream = nullptr) const {
OperatorArguments args;
Status status = construct_arguments_(
args,
static_cast<GemmArrayConfiguration const *>(configuration_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = new (host_workspace) Operator;
return op->initialize(args, device_workspace, stream);
}
/// Runs the kernel
// Requires a prior successful initialize() with the same host_workspace.
virtual Status run(
void const *arguments_ptr,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
OperatorArguments args;
Status status = update_arguments_(
args,
static_cast<GemmArrayArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = static_cast<Operator *>(host_workspace);
status = op->update(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
return op->run(stream);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Library wrapper exposing a planar-complex GEMM (separate real/imaginary
// planes for each operand) through the type-erased Operation interface.
// Runs the universal operator in kBatched mode.
template <typename Operator_>
class GemmPlanarComplexOperation : public GemmOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
public:
/// Constructor
GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase<Operator_>(name) {
this->description_.gemm_kind = GemmKind::kPlanarComplex;
}
protected:
/// Constructs the arguments structure given the configuration and arguments
// Fixes the universal-GEMM mode to kBatched and copies per-plane leading
// dimensions (real and imaginary planes may have distinct strides).
static Status construct_arguments_(
OperatorArguments &operator_args,
GemmPlanarComplexConfiguration const *configuration) {
operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched;
operator_args.problem_size = configuration->problem_size;
operator_args.batch_count = configuration->batch_count;
operator_args.lda_real = int(configuration->lda_real);
operator_args.lda_imag = int(configuration->lda_imag);
operator_args.ldb_real = int(configuration->ldb_real);
operator_args.ldb_imag = int(configuration->ldb_imag);
operator_args.ldc_real = int(configuration->ldc_real);
operator_args.ldc_imag = int(configuration->ldc_imag);
operator_args.ldd_real = int(configuration->ldd_real);
operator_args.ldd_imag = int(configuration->ldd_imag);
return Status::kSuccess;
}
/// Constructs the arguments structure given the configuration and arguments
// Scalars are complex-valued here; dereferenced on host or passed through as
// device pointers depending on pointer mode.
static Status update_arguments_(
OperatorArguments &operator_args,
GemmPlanarComplexArguments const *arguments) {
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
typename Operator::EpilogueOutputOp::Params params(
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
typename Operator::EpilogueOutputOp::Params params(
static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else {
return Status::kErrorInvalidProblem;
}
// update arguments
operator_args.ptr_A_real = arguments->A_real;
operator_args.ptr_A_imag = arguments->A_imag;
operator_args.ptr_B_real = arguments->B_real;
operator_args.ptr_B_imag = arguments->B_imag;
operator_args.ptr_C_real = arguments->C_real;
operator_args.ptr_C_imag = arguments->C_imag;
operator_args.ptr_D_real = arguments->D_real;
operator_args.ptr_D_imag = arguments->D_imag;
// Note the asymmetric naming on the operator side: the real-plane batch
// stride uses the un-suffixed field (batch_stride_A), the imaginary plane
// uses the _imag suffix.
operator_args.batch_stride_A = arguments->batch_stride_A_real;
operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag;
operator_args.batch_stride_B = arguments->batch_stride_B_real;
operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag;
operator_args.batch_stride_C = arguments->batch_stride_C_real;
operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag;
operator_args.batch_stride_D = arguments->batch_stride_D_real;
operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag;
return Status::kSuccess;
}
public:
/// Returns success if the operation can proceed
virtual Status can_implement(
void const *configuration_ptr,
void const *arguments_ptr) const {
GemmPlanarComplexConfiguration const *configuration =
static_cast<GemmPlanarComplexConfiguration const *>(configuration_ptr);
GemmPlanarComplexArguments const *arguments =
static_cast<GemmPlanarComplexArguments const *>(arguments_ptr);
OperatorArguments args;
Status status = construct_arguments_(args, configuration);
if (status != Status::kSuccess) {
return status;
}
status = update_arguments_(args, arguments);
if (status != Status::kSuccess) {
return status;
}
return Operator::can_implement(args);
}
/// Gets the host-side workspace
// Host workspace stores exactly one Operator instance (see initialize()).
virtual uint64_t get_host_workspace_size(
void const *configuration) const {
return sizeof(Operator);
}
/// Gets the device-side workspace
// Returns 0 (no workspace) if argument construction fails.
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
OperatorArguments args;
Status status = construct_arguments_(
args,
static_cast<GemmPlanarComplexConfiguration const *>(configuration_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
}
/// Initializes the workspace
// Placement-news the Operator into host_workspace.
virtual Status initialize(
void const *configuration_ptr,
void *host_workspace,
void *device_workspace,
cudaStream_t stream = nullptr) const {
OperatorArguments args;
Status status = construct_arguments_(
args,
static_cast<GemmPlanarComplexConfiguration const *>(configuration_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = new (host_workspace) Operator;
status = op->initialize(args, device_workspace, stream);
return status;
}
/// Runs the kernel
// Requires a prior successful initialize() with the same host_workspace.
virtual Status run(
void const *arguments_ptr,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
OperatorArguments args;
Status status = update_arguments_(
args,
static_cast<GemmPlanarComplexArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = static_cast<Operator *>(host_workspace);
status = op->update(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
status = op->run(stream);
return status;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Library wrapper exposing the array (pointer-per-batch) variant of the
// planar-complex GEMM. Runs the universal operator in kArray mode; per-batch
// problem dimensions are supplied indirectly via the M/N/K pointers below.
template <typename Operator_>
class GemmPlanarComplexArrayOperation : public GemmOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
public:
/// Constructor
GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase<Operator_>(name) {
this->description_.gemm_kind = GemmKind::kPlanarComplexArray;
}
protected:
/// Constructs the arguments structure given the configuration and arguments
// Fixes the universal-GEMM mode to kArray and copies per-plane leading
// dimensions (real and imaginary planes may have distinct strides).
static Status construct_arguments_(
OperatorArguments &operator_args,
GemmPlanarComplexArrayConfiguration const *configuration) {
operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray;
operator_args.problem_size = configuration->problem_size;
operator_args.batch_count = configuration->batch_count;
operator_args.lda_real = int(configuration->lda_real);
operator_args.lda_imag = int(configuration->lda_imag);
operator_args.ldb_real = int(configuration->ldb_real);
operator_args.ldb_imag = int(configuration->ldb_imag);
operator_args.ldc_real = int(configuration->ldc_real);
operator_args.ldc_imag = int(configuration->ldc_imag);
operator_args.ldd_real = int(configuration->ldd_real);
operator_args.ldd_imag = int(configuration->ldd_imag);
return Status::kSuccess;
}
/// Constructs the arguments structure given the configuration and arguments
// Complex-valued scalars, host- or device-resident by pointer mode; unlike
// the batched variant, no batch strides are set — instead per-batch M/N/K
// pointers are forwarded (assumed device-resident arrays; confirm with the
// kernel's argument definition).
static Status update_arguments_(
OperatorArguments &operator_args,
GemmPlanarComplexArrayArguments const *arguments) {
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
typename Operator::EpilogueOutputOp::Params params(
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
typename Operator::EpilogueOutputOp::Params params(
static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
);
operator_args.epilogue = params;
}
else {
return Status::kErrorInvalidProblem;
}
// update arguments
operator_args.ptr_A_real = arguments->A_real;
operator_args.ptr_A_imag = arguments->A_imag;
operator_args.ptr_B_real = arguments->B_real;
operator_args.ptr_B_imag = arguments->B_imag;
operator_args.ptr_C_real = arguments->C_real;
operator_args.ptr_C_imag = arguments->C_imag;
operator_args.ptr_D_real = arguments->D_real;
operator_args.ptr_D_imag = arguments->D_imag;
operator_args.ptr_M = arguments->M;
operator_args.ptr_N = arguments->N;
operator_args.ptr_K = arguments->K;
return Status::kSuccess;
}
public:
/// Returns success if the operation can proceed
virtual Status can_implement(
void const *configuration_ptr,
void const *arguments_ptr) const {
GemmPlanarComplexArrayConfiguration const *configuration =
static_cast<GemmPlanarComplexArrayConfiguration const *>(configuration_ptr);
GemmPlanarComplexArrayArguments const *arguments =
static_cast<GemmPlanarComplexArrayArguments const *>(arguments_ptr);
OperatorArguments args;
Status status = construct_arguments_(args, configuration);
if (status != Status::kSuccess) {
return status;
}
status = update_arguments_(args, arguments);
if (status != Status::kSuccess) {
return status;
}
return Operator::can_implement(args);
}
/// Gets the host-side workspace
// Host workspace stores exactly one Operator instance (see initialize()).
virtual uint64_t get_host_workspace_size(
void const *configuration) const {
return sizeof(Operator);
}
/// Gets the device-side workspace
// Returns 0 (no workspace) if argument construction fails.
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
OperatorArguments args;
Status status = construct_arguments_(
args,
static_cast<GemmPlanarComplexArrayConfiguration const *>(configuration_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
}
/// Initializes the workspace
// Placement-news the Operator into host_workspace.
virtual Status initialize(
void const *configuration_ptr,
void *host_workspace,
void *device_workspace,
cudaStream_t stream = nullptr) const {
OperatorArguments args;
Status status = construct_arguments_(
args,
static_cast<GemmPlanarComplexArrayConfiguration const *>(configuration_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = new (host_workspace) Operator;
status = op->initialize(args, device_workspace, stream);
return status;
}
/// Runs the kernel
// Requires a prior successful initialize() with the same host_workspace.
virtual Status run(
void const *arguments_ptr,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
OperatorArguments args;
Status status = update_arguments_(
args,
static_cast<GemmPlanarComplexArrayArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = static_cast<Operator *>(host_workspace);
status = op->update(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
status = op->run(stream);
return status;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library

845
tools/library/src/handle.cu Normal file
View File

@ -0,0 +1,845 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief CUTLASS Library handle.
*/
#include <stdexcept>
#include <cstdint>
#include "cutlass/library/handle.h"
#include "cutlass/library/singleton.h"
#include "cutlass/library/util.h"
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructor
// Constructs a library handle bound to the *current* CUDA device:
// caches its properties, allocates and zeroes the device workspace
// (set_workspace_size may throw on allocation failure), and forces
// initialization of the operation-table singleton so later dispatch
// does not pay the cost.
Handle::Handle(
cudaStream_t stream,
size_t workspace_size
):
stream_(stream),
workspace_(nullptr),
workspace_size_(0),
scalar_pointer_mode_(ScalarPointerMode::kHost),
last_operation_(nullptr) {
int device_idx = -1;
cudaError_t error = cudaGetDevice(&device_idx);
if (error != cudaSuccess) {
throw std::runtime_error("cudaGetDevice() failed");
}
error = cudaGetDeviceProperties(&device_, device_idx);
if (error != cudaSuccess) {
throw std::runtime_error("cudaGetDeviceProperties() failed");
}
// Allocates (and memsets) the workspace; throws on failure.
set_workspace_size(workspace_size);
// Eagerly initialize the global operation tables.
Singleton::get();
}
/// Destructor
Handle::~Handle() {
  // Release the device workspace allocation, if any. The original code
  // tested `workspace_` twice in a row; the redundant inner check is removed.
  if (workspace_) {
    cudaFree(workspace_);
    workspace_ = nullptr;
    workspace_size_ = 0;
  }
}
/// Move constructor
Handle::Handle(Handle && handle) {
  // Steal the workspace from the source handle; the source keeps its stream
  // and device info but no longer owns any device memory.
  device_ = handle.device_;
  workspace_size_ = handle.workspace_size_;
  workspace_ = handle.workspace_;
  stream_ = handle.stream_;
  scalar_pointer_mode_ = handle.scalar_pointer_mode_;
  // Bug fix: last_operation_ was previously left uninitialized by the move
  // constructor (the primary constructor initializes it to nullptr).
  last_operation_ = handle.last_operation_;

  handle.workspace_ = nullptr;
  handle.workspace_size_ = 0;
}
/// Move assignment operator
Handle & Handle::operator=(Handle && handle) {
  if (this != &handle) {
    // Bug fix: free our existing workspace before taking ownership of the
    // source's — previously the old allocation was overwritten and leaked
    // (the destructor frees via cudaFree, matching this).
    if (workspace_) {
      cudaFree(workspace_);
    }

    device_ = handle.device_;
    workspace_size_ = handle.workspace_size_;
    workspace_ = handle.workspace_;
    stream_ = handle.stream_;
    scalar_pointer_mode_ = handle.scalar_pointer_mode_;
    // Bug fix: last_operation_ was previously not transferred.
    last_operation_ = handle.last_operation_;

    handle.workspace_ = nullptr;
    handle.workspace_size_ = 0;
  }
  return *this;
}
// Returns the SM version (major * 10 + minor) of the device whose
// properties were captured at construction time.
int Handle::compute_capability() const {
return device_.major * 10 + device_.minor;
}
/// Sets the current CUDA stream
// The handle does not own the stream; the caller manages its lifetime.
void Handle::set_stream(cudaStream_t stream) {
stream_ = stream;
}
/// Gets the current CUDA stream
cudaStream_t Handle::get_stream() const {
return stream_;
}
/// Gets the device workspace size
size_t Handle::get_workspace_size() const {
return workspace_size_;
}
/// Gets a pointer to the device workspace allocation in Global Memory
// Non-owning view: the handle frees this memory in its destructor or on
// the next set_workspace_size() call.
void *Handle::get_workspace() const {
return workspace_;
}
/// Sets the size of device workspace, invalidating previous calls to get_device_workspace()
// Resizes the device workspace. Reallocates only when the requested size
// differs from the current one; throws std::runtime_error on allocation
// failure. Note: the workspace is re-zeroed on every call, even when the
// size is unchanged — callers may rely on this to clear it.
void Handle::set_workspace_size(size_t bytes) {
if (bytes != workspace_size_) {
if (workspace_) {
cudaFree(workspace_);
}
// Record the new size before allocating so a throw leaves a consistent
// (null workspace) state for the destructor.
workspace_ = nullptr;
workspace_size_ = bytes;
if (workspace_size_) {
cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_);
if (error != cudaSuccess) {
throw std::runtime_error("Failed to allocate workspace");
}
}
}
// Unconditionally clear whatever workspace exists.
if (workspace_) {
cudaError_t error = cudaMemset(workspace_, 0, workspace_size_);
if (error != cudaSuccess) {
throw std::runtime_error("Failed to clear workspace");
}
}
}
/// Gets the scalar pointer mode
ScalarPointerMode Handle::get_scalar_pointer_mode() const {
return scalar_pointer_mode_;
}
/// Sets the scalar pointer mode governing how alpha/beta scalars are interpreted
/// by subsequent operations launched through this handle.
void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) {
  scalar_pointer_mode_ = mode;
}
/// Gets the last operation selected and run by this handle (nullptr if none has run).
Operation const *Handle::get_last_operation() const {
  return last_operation_;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns the maximum required alignment for each operator, i.e. the strictest
/// alignment among the A, B, and C operand descriptions.
static int maximum_alignment_requirement(GemmDescription const &desc) {
  return std::max({desc.A.alignment, desc.B.alignment, desc.C.alignment});
}
/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a
/// given upper limit.
static int gemm_problem_alignment(
int M,
int N,
int K,
NumericTypeID element_A,
void const *ptr_A,
int lda,
int64_t batch_stride_A,
NumericTypeID element_B,
void const *ptr_B,
int ldb,
int64_t batch_stride_B,
NumericTypeID element_C,
void const * ptr_C,
int ldc,
int64_t batch_stride_C,
void const * ptr_D,
int ldd,
int64_t batch_stride_D,
int max_alignment_in_bytes = 16
) {
void const *pointers[] = {
ptr_A, ptr_B, ptr_C, ptr_D
};
int64_t extents[] = {
M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D
};
NumericTypeID elements[] = {
element_A, element_B, element_C
};
for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) {
bool satisfied = true;
// Can pointers satisfy this?
for (void const *ptr : pointers) {
std::uintptr_t int_ptr = reinterpret_cast<std::uintptr_t>(ptr);
if (int_ptr % max_alignment_in_bytes) {
satisfied = false;
break;
}
}
if (!satisfied) {
continue;
}
// Compute the maximum alignment based on element data types
int max_element_alignment = 0;
for (NumericTypeID type_id : elements) {
int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id);
max_element_alignment = std::max(max_element_alignment, element_alignment);
}
// Can the problem size and leading dimensions satisfy this?
for (int64_t extent : extents) {
if (extent % max_element_alignment) {
satisfied = false;
break;
}
}
if (!satisfied) {
continue;
}
// Yes
return max_element_alignment;
}
// No alignment satisfies this problem
return 0;
}
/// Find the best kernel in descending order of preference.
///
/// `operators_it` points at the set of functionally equivalent operations, grouped by
/// GemmPreferenceKey (compute capability, alignment) in an ordered map. Returns the
/// first operation whose compute-capability window contains the requested capability and
/// whose alignment requirement does not exceed the problem's achievable alignment,
/// scanning preference keys from the requested key downward. Returns nullptr if no
/// operation qualifies.
static Operation const * find_gemm_operation(
  GemmOperationFunctionalMap::const_iterator operators_it,
  GemmPreferenceKey const preference_key) {

  // upper_bound yields the first key strictly greater than preference_key; every entry
  // before it is a candidate. If that is begin(), there are no candidates at all.
  auto cc_it = operators_it->second.upper_bound(preference_key);
  if (cc_it == operators_it->second.begin()) {
    return nullptr;
  }

  Operation const *operation = nullptr;

  // Search in descending order of compute capability
  do {
    --cc_it;

    // Search tile sizes in order, for now.
    for (auto const * op : cc_it->second) {

      GemmDescription const &desc = static_cast<GemmDescription const &>(op->description());

      int min_cc = desc.tile_description.minimum_compute_capability;
      int max_cc = desc.tile_description.maximum_compute_capability;

      int op_alignment = maximum_alignment_requirement(desc);

      // Accept the first operation whose CC window contains the target device and
      // whose required operand alignment is satisfied by the problem.
      if ((min_cc <= preference_key.compute_capability) &&
        (preference_key.compute_capability <= max_cc) &&
        (op_alignment <= preference_key.alignment)) {

        operation = op;
        break;
      }
    }
  } while (!operation && cc_it != operators_it->second.begin());

  return operation;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Executes a GEMM computation: D <= alpha * A*B + beta * C
///
/// Looks up the set of functionally matching CUTLASS kernels, selects the best candidate
/// for this device's compute capability and the problem's achievable alignment, then
/// initializes host/device workspaces and runs the kernel on the handle's stream.
/// Returns kErrorNotSupported when no kernel matches or a workspace is too small.
Status Handle::gemm(

  int M,                                    /// GEMM M dimension
  int N,                                    /// GEMM N dimension
  int K,                                    /// GEMM K dimension

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix - ignored for real-valued matrices

  void const * ptr_A,                       /// Pointer to A matrix in Global Memory
  int lda,                                  /// Leading dimension of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix - ignored for real-valued matrices

  void const * ptr_B,                       /// Pointer to B matrix in Global Memory
  int ldb,                                  /// Leading dimension of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrices

  void const * ptr_C,                       /// Pointer to C matrix
  int ldc,                                  /// Leading dimension of C matrix

  void * ptr_D,                             /// Pointer to D matrix
  int ldd                                   /// Leading dimension of D matrix
) {

  //
  // Find the operation
  //

  // Functional key identifies the mathematical operation (types, layouts, transforms)
  // independent of tiling or compute capability.
  GemmFunctionalKey key(
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  // Batch strides are 0 here since this entry point is non-batched.
  int alignment = gemm_problem_alignment(
    M, N, K,
    element_A, ptr_A, lda, 0,
    element_B, ptr_B, ldb, 0,
    element_C, ptr_C, ldc, 0,
    ptr_D, ldd, 0, kMaximumAlignmentSize
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmConfiguration configuration{
    {M, N, K},
    lda,
    ldb,
    ldc,
    ldd,
    1           // single split-K slice / non-batched configuration
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    // Caller must enlarge the handle's workspace via set_workspace_size() first.
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmArguments arguments{
    ptr_A,
    ptr_B,
    ptr_C,
    ptr_D,
    alpha,
    beta,
    scalar_pointer_mode_
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
///////////////////////////////////////////////////////////////////////////////////////////////////

/// Planar complex GEMM
///
/// Complex operands are stored as separate real and imaginary planes, each with its own
/// pointer, leading dimension, and batch stride. Kernel selection uses the worse (larger)
/// of the alignments achievable by the real and imaginary planes. Executes `batch_count`
/// batched GEMMs on the handle's stream.
Status Handle::gemm_planar_complex(

  int M,                                    /// GEMM M dimension
  int N,                                    /// GEMM N dimension
  int K,                                    /// GEMM K dimension

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix

  void const * ptr_A_real,                  /// Pointer to real part of A matrix
  void const * ptr_A_imag,                  /// Pointer to imaginary part of A matrix

  int lda_real,                             /// Leading dimension of real part of A matrix
  int lda_imag,                             /// Leading dimension of imaginary part of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix

  void const * ptr_B_real,                  /// Pointer to real part of B matrix
  void const * ptr_B_imag,                  /// Pointer to imaginary part of B matrix

  int ldb_real,                             /// Leading dimension of real part of B matrix
  int ldb_imag,                             /// Leading dimension of imaginary part of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrix

  void const * ptr_C_real,                  /// Pointer to real part of C matrix
  void const * ptr_C_imag,                  /// Pointer to imaginary part of C matrix

  int ldc_real,                             /// Leading dimension of real part of C matrix
  int ldc_imag,                             /// Leading dimension of imaginary part of C matrix

  void * ptr_D_real,                        /// Pointer to real part of D matrix
  void * ptr_D_imag,                        /// Pointer to imaginary part of D matrix

  int ldd_real,                             /// Leading dimension of real part of D matrix
  int ldd_imag,                             /// Leading dimension of imaginary part of D matrix

  int batch_count,                          /// Number of batched GEMMs to execute

  int64_t batch_stride_A_real,
  int64_t batch_stride_A_imag,

  int64_t batch_stride_B_real,
  int64_t batch_stride_B_imag,

  int64_t batch_stride_C_real,
  int64_t batch_stride_C_imag,

  int64_t batch_stride_D_real,
  int64_t batch_stride_D_imag
) {

  //
  // Find the operation
  //

  GemmFunctionalKey key(
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_planar_complex_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_planar_complex_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  // Both planes must be consulted; std::max picks the larger achievable alignment.
  // NOTE(review): using max means the selected kernel's alignment requirement may
  // exceed what the weaker plane satisfies — confirm against kernel dispatch rules.
  int alignment = std::max(
    gemm_problem_alignment(
      M, N, K,
      element_A, ptr_A_real, lda_real, batch_stride_A_real,
      element_B, ptr_B_real, ldb_real, batch_stride_B_real,
      element_C, ptr_C_real, ldc_real, batch_stride_C_real,
      ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize
    ),
    gemm_problem_alignment(
      M, N, K,
      element_A, ptr_A_imag, lda_imag, batch_stride_A_imag,
      element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag,
      element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag,
      ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize
    )
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmPlanarComplexConfiguration configuration{
    GemmUniversalMode::kBatched,
    {M, N, K},
    batch_count,
    lda_real,
    lda_imag,
    ldb_real,
    ldb_imag,
    ldc_real,
    ldc_imag,
    ldd_real,
    ldd_imag
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmPlanarComplexArguments arguments{
    ptr_A_real,
    ptr_A_imag,
    ptr_B_real,
    ptr_B_imag,
    ptr_C_real,
    ptr_C_imag,
    ptr_D_real,
    ptr_D_imag,
    alpha,
    beta,
    scalar_pointer_mode_,
    batch_stride_A_real,
    batch_stride_A_imag,
    batch_stride_B_real,
    batch_stride_B_imag,
    batch_stride_C_real,
    batch_stride_C_imag,
    batch_stride_D_real,
    batch_stride_D_imag
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Planar complex batched GEMM loading pointers from arrays in global memory
///
/// Each batch entry has its own problem size and operand pointers, read from
/// device-resident arrays at kernel execution time. Since operand pointers are not
/// known on the host, alignment is computed from leading dimensions and expected
/// problem sizes only (null pointers trivially satisfy any pointer-alignment test).
Status Handle::gemm_planar_complex_array(

  int expected_M,                           /// Expected GEMM M dimension (used for sizing CUDA grid)
  int expected_N,                           /// Expected GEMM N dimension (used for sizing CUDA grid)
  int expected_K,                           /// Expected GEMM K dimension

  int batch_count,                          /// Number of independent GEMM computations to execute

  int const *M,                             /// Array containing the GEMM M dimension for each batch index
  int const *N,                             /// Array containing the GEMM N dimension for each batch index
  int const *K,                             /// Array containing the GEMM K dimension for each batch index

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix

  void const * const * ptr_A_real,          /// Pointer to array containing pointers to real part of A matrices
  void const * const * ptr_A_imag,          /// Pointer to array containing pointers to imaginary part of A matrices

  int lda_real,                             /// Leading dimension of real part of A matrix
  int lda_imag,                             /// Leading dimension of imaginary part of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix

  void const * const * ptr_B_real,          /// Pointer to array containing pointers to real part of B matrices
  void const * const * ptr_B_imag,          /// Pointer to array containing pointers to imaginary part of B matrices

  int ldb_real,                             /// Leading dimension of real part of B matrix
  int ldb_imag,                             /// Leading dimension of imaginary part of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrix

  void const * const * ptr_C_real,          /// Pointer to array containing pointers to real part of C matrices
  void const * const * ptr_C_imag,          /// Pointer to array containing pointers to imaginary part of C matrices

  int ldc_real,                             /// Leading dimension of real part of C matrix
  int ldc_imag,                             /// Leading dimension of imaginary part of C matrix

  void * const * ptr_D_real,                /// Pointer to array containing pointers to real part of D matrices
  void * const * ptr_D_imag,                /// Pointer to array containing pointers to imaginary part of D matrices

  int ldd_real,                             /// Leading dimension of real part of D matrix
  int ldd_imag                              /// Leading dimension of imaginary part of D matrix
) {

  //
  // Find the operation
  //

  GemmFunctionalKey key(
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_planar_complex_array_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_planar_complex_array_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  // Pointers are nullptr because the actual operand addresses live in device arrays;
  // alignment is therefore judged on problem sizes and leading dimensions alone.
  int alignment = std::max(
    gemm_problem_alignment(
      expected_M, expected_N, expected_K,
      element_A, nullptr, lda_real, 0,
      element_B, nullptr, ldb_real, 0,
      element_C, nullptr, ldc_real, 0,
      nullptr, ldd_real, 0, kMaximumAlignmentSize
    ),
    gemm_problem_alignment(
      expected_M, expected_N, expected_K,
      element_A, nullptr, lda_imag, 0,
      element_B, nullptr, ldb_imag, 0,
      element_C, nullptr, ldc_imag, 0,
      nullptr, ldd_imag, 0, kMaximumAlignmentSize
    )
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmPlanarComplexArrayConfiguration configuration{
    {expected_M, expected_N, expected_K},
    batch_count,
    lda_real,
    lda_imag,
    ldb_real,
    ldb_imag,
    ldc_real,
    ldc_imag,
    ldd_real,
    ldd_imag
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmPlanarComplexArrayArguments arguments{
    M, N, K,
    ptr_A_real,
    ptr_A_imag,
    ptr_B_real,
    ptr_B_imag,
    ptr_C_real,
    ptr_C_imag,
    ptr_D_real,
    ptr_D_imag,
    alpha,
    beta,
    scalar_pointer_mode_
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -57,6 +57,10 @@ namespace library {
template <typename T> struct NumericTypeMap;
template <> struct NumericTypeMap<cutlass::uint1b_t> {
static NumericTypeID const kId = NumericTypeID::kB1;
};
template <> struct NumericTypeMap<cutlass::int4b_t> {
static NumericTypeID const kId = NumericTypeID::kS4;
};
@ -123,6 +127,28 @@ template <> struct NumericTypeMap<cutlass::complex<double> > {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> struct MathOperationMap {
static MathOperationID const kId = MathOperationID::kInvalid;
};
template <> struct MathOperationMap<cutlass::arch::OpMultiplyAdd> {
static MathOperationID const kId = MathOperationID::kMultiplyAdd;
};
template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddSaturate> {
static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate;
};
template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddComplex> {
static MathOperationID const kId = MathOperationID::kMultiplyAddComplex;
};
template <> struct MathOperationMap<cutlass::arch::OpXorPopc> {
static MathOperationID const kId = MathOperationID::kXorPopc;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> struct LayoutMap;
template <> struct LayoutMap<cutlass::layout::ColumnMajor> {
@ -133,6 +159,34 @@ template <> struct LayoutMap<cutlass::layout::RowMajor> {
static LayoutTypeID const kId = LayoutTypeID::kRowMajor;
};
template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<16>> {
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16;
};
template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<16>> {
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16;
};
template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<32>> {
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32;
};
template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<32>> {
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32;
};
template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<64>> {
static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64;
};
template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<64>> {
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64;
};
template <> struct LayoutMap<cutlass::layout::TensorNHWC> {
static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> struct OpcodeClassMap;
@ -148,6 +202,19 @@ template <> struct OpcodeClassMap<arch::OpClassTensorOp> {
template <> struct OpcodeClassMap<arch::OpClassWmmaTensorOp> {
static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <cutlass::ComplexTransform Transform> struct ComplexTransformMap;
template <> struct ComplexTransformMap<cutlass::ComplexTransform::kNone> {
static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone;
};
template <> struct ComplexTransformMap<cutlass::ComplexTransform::kConjugate> {
static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> struct ArchMap;

View File

@ -1,6 +1,4 @@
/*!
*//***************************************************************************************************
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
@ -37,11 +35,12 @@
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////////////
void initialize_all(Manifest &manifest);
// init and insert all cutlass op in manifest object (procedurally generated using generator.py)
void initialize_all(Manifest &manifest);
///////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////////
/// Top-level initialization
Status Manifest::initialize() {
@ -50,7 +49,13 @@ Status Manifest::initialize() {
operations_.clear();
}
initialize_all(*this);
switch(provider_) {
case Provider::kCUTLASS:
initialize_all(*this); break;
default:
break;
}
return Status::kSuccess;
}

View File

@ -0,0 +1,159 @@
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
\file
\brief Defines a data structure in which a set of functionally equivalent library::Operation
instances may be queried.
*/
#include <fstream>
#include "cutlass/library/library.h"
#include "cutlass/library/operation_table.h"
#include "cutlass/library/util.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Streams a human-readable, brace-delimited rendering of a GemmFunctionalKey,
/// one field per line. Intended for diagnostics/logging.
std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) {
  out << "{\n"
    << " element_compute: " << to_string(k.element_compute) << "\n"
    << " element_scalar: " << to_string(k.element_scalar) << "\n"
    << " element_A: " << to_string(k.element_A) << "\n"
    << " layout_A: " << to_string(k.layout_A) << "\n"
    << " transform_A: " << to_string(k.transform_A) << "\n"
    << " element_B: " << to_string(k.element_B) << "\n"
    << " layout_B: " << to_string(k.layout_B) << "\n"
    << " transform_B: " << to_string(k.transform_B) << "\n"
    << " element_C: " << to_string(k.element_C) << "\n"
    << "}";

  return out;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Inserts all GEMM operations from `manifest` into this table.
///
/// Each operation is keyed first by GemmFunctionalKey (data types, layouts, complex
/// transforms) and then by GemmPreferenceKey (minimum compute capability, strictest
/// operand alignment). The key construction is identical for all GEMM kinds, so it is
/// hoisted out of the per-kind branches that previously triplicated it.
void OperationTable::append(Manifest const &manifest) {

  // Insert operations into appropriate data structure
  for (auto const & operation : manifest) {

    OperationDescription const &desc = operation->description();

    if (desc.kind != OperationKind::kGemm) {
      continue;
    }

    GemmDescription const &gemm_desc = static_cast<GemmDescription const &>(desc);

    // Functional key: identifies the mathematical operation independent of tiling.
    GemmFunctionalKey functional_key(
      gemm_desc.tile_description.math_instruction.element_accumulator,
      gemm_desc.element_epilogue,
      gemm_desc.A.element,
      gemm_desc.A.layout,
      gemm_desc.transform_A,
      gemm_desc.B.element,
      gemm_desc.B.layout,
      gemm_desc.transform_B,
      gemm_desc.C.element
    );

    Operation const *op = operation.get();

    // Preference key: minimum compute capability and the strictest operand alignment.
    int cc = gemm_desc.tile_description.minimum_compute_capability;

    int alignment = std::max(std::max(
      gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);

    GemmPreferenceKey preference_key(cc, alignment);

    // Route the operation to the container for its GEMM kind.
    if (gemm_desc.gemm_kind == GemmKind::kGemm) {
      gemm_operations[functional_key][preference_key].push_back(op);
    }
    else if (gemm_desc.gemm_kind == GemmKind::kPlanarComplex) {
      gemm_planar_complex_operations[functional_key][preference_key].push_back(op);
    }
    else if (gemm_desc.gemm_kind == GemmKind::kPlanarComplexArray) {
      gemm_planar_complex_array_operations[functional_key][preference_key].push_back(op);
    }
  }
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,63 @@
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <memory>
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/operation_table.h"
#include "cutlass/library/singleton.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
static std::unique_ptr<Singleton> instance;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs the singleton: initializes the manifest of generated CUTLASS operations
/// and builds the operation lookup table over it.
Singleton::Singleton() {
  manifest.initialize();
  operation_table.append(manifest);
}
/// Returns the process-wide Singleton instance, constructing it on first use.
///
/// Uses a function-local static so first-use initialization is thread-safe
/// (guaranteed since C++11). The previous check-then-reset of a namespace-scope
/// unique_ptr could construct the singleton twice under concurrent first calls.
Singleton const & Singleton::get() {
  static Singleton const singleton_instance;
  return singleton_instance;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -25,17 +25,65 @@
#include <iosfwd>
#include <complex>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/complex.h"
#include "cutlass/library/library.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
static struct {
char const *text;
char const *pretty;
Provider enumerant;
}
Provider_enumerants[] = {
{"cutlass", "CUTLASS", Provider::kCUTLASS},
{"host", "reference_host", Provider::kReferenceHost},
{"device", "reference_device", Provider::kReferenceDevice},
{"cublas", "cuBLAS", Provider::kCUBLAS},
};
/// Converts a Provider enumerant to a string
///
/// Looks the enumerant up in Provider_enumerants and returns either its short
/// text form or its pretty display form; unrecognized values yield "invalid"/"Invalid".
char const *to_string(Provider provider, bool pretty) {

  for (auto const & entry : Provider_enumerants) {
    if (entry.enumerant == provider) {
      return pretty ? entry.pretty : entry.text;
    }
  }

  return pretty ? "Invalid" : "invalid";
}
/// Parses a Provider enumerant from a string
///
/// Accepts either the short text form or the pretty display form; any other
/// input maps to Provider::kInvalid.
template <>
Provider from_string<Provider>(std::string const &str) {

  for (auto const & entry : Provider_enumerants) {
    bool matches_text = (str.compare(entry.text) == 0);
    bool matches_pretty = (str.compare(entry.pretty) == 0);

    if (matches_text || matches_pretty) {
      return entry.enumerant;
    }
  }

  return Provider::kInvalid;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
static struct {
@ -44,7 +92,7 @@ static struct {
OperationKind enumerant;
}
OperationKind_enumerants[] = {
{"gemm", "Gemm", OperationKind::kGemm},
{"gemm", "Gemm", OperationKind::kGemm},
};
/// Converts a Status enumerant to a string
@ -203,6 +251,9 @@ int sizeof_bits(NumericTypeID type) {
case NumericTypeID::kF16: return 16;
case NumericTypeID::kF32: return 32;
case NumericTypeID::kF64: return 64;
case NumericTypeID::kCF16: return 32;
case NumericTypeID::kCF32: return 64;
case NumericTypeID::kCF64: return 128;
case NumericTypeID::kS4: return 4;
case NumericTypeID::kS8: return 8;
case NumericTypeID::kS16: return 16;
@ -291,6 +342,9 @@ bool is_float_type(NumericTypeID type) {
case NumericTypeID::kF16: return true;
case NumericTypeID::kF32: return true;
case NumericTypeID::kF64: return true;
case NumericTypeID::kCF16: return true;
case NumericTypeID::kCF32: return true;
case NumericTypeID::kCF64: return true;
default: break;
}
return false;
@ -309,8 +363,18 @@ layout_aliases[] = {
{LayoutTypeID::kColumnMajor, "column"},
{LayoutTypeID::kColumnMajor, "col"},
{LayoutTypeID::kColumnMajor, "n"},
{LayoutTypeID::kColumnMajorInterleavedK16, "nk16"},
{LayoutTypeID::kRowMajorInterleavedK16, "tk16"},
{LayoutTypeID::kColumnMajorInterleavedK32, "nk32"},
{LayoutTypeID::kRowMajorInterleavedK32, "tk32"},
{LayoutTypeID::kColumnMajorInterleavedK64, "nk64"},
{LayoutTypeID::kRowMajorInterleavedK64, "tk64"},
{LayoutTypeID::kTensorNCHW, "nchw"},
{LayoutTypeID::kTensorNHWC, "packed_nhwc"},
{LayoutTypeID::kTensorNHWC, "nhwc"},
{LayoutTypeID::kUnknown, "*"},
{LayoutTypeID::kInvalid, nullptr}
};
@ -344,7 +408,12 @@ int get_layout_stride_rank(LayoutTypeID layout_id) {
case LayoutTypeID::kColumnMajorInterleavedK4:
case LayoutTypeID::kRowMajorInterleavedK4:
case LayoutTypeID::kColumnMajorInterleavedK16:
case LayoutTypeID::kRowMajorInterleavedK16: return 1;
case LayoutTypeID::kRowMajorInterleavedK16:
case LayoutTypeID::kColumnMajorInterleavedK32:
case LayoutTypeID::kRowMajorInterleavedK32:
case LayoutTypeID::kColumnMajorInterleavedK64:
case LayoutTypeID::kRowMajorInterleavedK64:
return 1;
case LayoutTypeID::kTensorNCHW:
case LayoutTypeID::kTensorNHWC: return 3;
default : throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank");
@ -396,8 +465,51 @@ OpcodeClassID from_string<OpcodeClassID>(std::string const &str) {
return OpcodeClassID::kInvalid;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Table mapping ComplexTransform enumerants to their terse ("text") and
/// human-readable ("pretty") string forms.
static struct {
char const *text;
char const *pretty;
ComplexTransform enumerant;
}
ComplexTransform_enumerants[] = {
{"n", "none", ComplexTransform::kNone},
{"c", "conj", ComplexTransform::kConjugate}
};
/// Converts a ComplexTransform enumerant to a string. Returns the pretty
/// form when requested, otherwise the terse form; unmatched values yield
/// "Invalid"/"invalid".
char const *to_string(ComplexTransform type, bool pretty) {
  for (auto const &entry : ComplexTransform_enumerants) {
    if (entry.enumerant == type) {
      return pretty ? entry.pretty : entry.text;
    }
  }
  return pretty ? "Invalid" : "invalid";
}
/// Parses a ComplexTransform enumerant from a string. Accepts either the
/// terse or the pretty spelling; returns kInvalid when nothing matches.
template <>
ComplexTransform from_string<ComplexTransform>(std::string const &str) {
  for (auto const &entry : ComplexTransform_enumerants) {
    if (str == entry.text || str == entry.pretty) {
      return entry.enumerant;
    }
  }
  return ComplexTransform::kInvalid;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str) {
int size_bytes = sizeof_bits(type) / 8;
@ -574,25 +686,36 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
break;
case NumericTypeID::kCF16:
{
std::complex<float> tmp;
cutlass::complex<half_t> const *x =
reinterpret_cast<cutlass::complex<half_t> const *>(bytes.data());
tmp.real(x->real());
tmp.imag(x->imag());
ss << float(x->real());
ss << tmp;
if (x->imag() != cutlass::half_t()) {
ss << "+i" << float(x->imag());
}
}
break;
case NumericTypeID::kCF32:
{
ss << *reinterpret_cast<std::complex<float>*>(bytes.data());
cutlass::complex<float> const * x = reinterpret_cast<cutlass::complex<float> const *>(bytes.data());
ss << x->real();
if (x->imag() != float()) {
ss << "+i" << x->imag();
}
}
break;
case NumericTypeID::kCF64:
{
ss << *reinterpret_cast<std::complex<double>*>(bytes.data());
cutlass::complex<double> const * x = reinterpret_cast<cutlass::complex<double> const *>(bytes.data());
ss << x->real();
if (x->imag() != double()) {
ss << "+i" << x->imag();
}
}
break;
default:

View File

@ -33,7 +33,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
src/gpu_timer.cpp
src/device_allocation.cu
src/device_context.cu
src/cublas_helpers.cpp
src/cublas_helpers.cpp
src/problem_space.cpp
src/operation_profiler.cu
src/gemm_operation_profiler.cu
@ -54,11 +54,11 @@ set_target_properties(cutlass_profiler PROPERTIES EXPORT_NAME profiler)
# Include paths
#
target_include_directories(cutlass_profiler
target_include_directories(
cutlass_profiler
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src # Source directory
../../tools/util/include
)
)
#
# Library dependencies
@ -68,8 +68,8 @@ target_link_libraries(
cutlass_profiler
PRIVATE
cutlass_lib
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:cublas>
gtest
cutlass_tools_util_includes
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
cudart
)

View File

@ -39,14 +39,14 @@ namespace profiler {
/// Converts a cuBLAS status to cutlass::Status
Status get_cutlass_status(cublasStatus_t cublas) {
if (cublas == CUBLAS_STATUS_SUCCESS) {
return Status::kSuccess;
}
else if (cublas == CUBLAS_STATUS_INVALID_VALUE) {
return Status::kErrorInvalidProblem;
}
if (cublas == CUBLAS_STATUS_NOT_SUPPORTED) {
return Status::kErrorNotSupported;
switch (cublas) {
case CUBLAS_STATUS_SUCCESS:
return Status::kSuccess;
case CUBLAS_STATUS_INVALID_VALUE:
return Status::kErrorInvalidProblem;
case CUBLAS_STATUS_NOT_SUPPORTED:
return Status::kErrorNotSupported;
default: break;
}
return Status::kErrorInternal;
}
@ -145,6 +145,13 @@ Status cublas_satisfies(library::GemmDescription const &desc) {
return Status::kErrorNotSupported;
}
// output type S4 and S8 not supported in cuBLAS
if (desc.C.element == library::NumericTypeID::kS4 ||
desc.C.element == library::NumericTypeID::kS8) {
return Status::kErrorNotSupported;
}
return Status::kSuccess;
}

View File

@ -33,7 +33,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "options.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -86,6 +86,161 @@ public:
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Selects one or more cuBLAS algorithms to evaluate.
///
/// @param algorithms  output list of cuBLAS GEMM algorithms to run
/// @param options     profiler options (algorithm mode and any explicit algorithm list)
/// @param op_desc     description of the CUTLASS GEMM operation being profiled
static void select_cublas_algorithms(
  std::vector<cublasGemmAlgo_t> &algorithms,
  Options const &options,
  library::GemmDescription const &op_desc) {

  library::OpcodeClassID const & opcode_class =
    op_desc.tile_description.math_instruction.opcode_class;

  switch (options.library.algorithm_mode) {
    case AlgorithmMode::kMatching:
    {
      // Single algorithm chosen to match the CUTLASS threadblock shape
      algorithms.push_back(get_cublas_gemm_algo(
        op_desc.tile_description.threadblock_shape.m(),
        op_desc.tile_description.threadblock_shape.n(),
        op_desc.tile_description.threadblock_shape.k(),
        opcode_class));
      break;
    }
    case AlgorithmMode::kBest:
    {
      // Choose first enumerated mode. If none are enumerated, choose based on opcode class
      // and evaluate all of them.
      if (options.library.algorithms.empty()) {
        // Enumerate all algorithms for the opcode class; SIMT and Tensor Op
        // algorithms occupy distinct contiguous enumerant ranges.
        if (opcode_class == library::OpcodeClassID::kSimt) {
          for (int algo = CUBLAS_GEMM_DEFAULT;
            algo <= CUBLAS_GEMM_ALGO23;
            ++algo) {
            algorithms.push_back(static_cast<cublasGemmAlgo_t>(algo));
          }
        }
        else {
          for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
            algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP;
            ++algo) {
            algorithms.push_back(static_cast<cublasGemmAlgo_t>(algo));
          }
        }
      }
      else {
        // Use the algorithms listed explicitly by the user.
        // static_cast is the correct int -> enum conversion here; the original
        // reinterpret_cast<cublasGemmAlgo_t const &>(algo) punned through a
        // reference, relying on identical object representation unnecessarily.
        algorithms.reserve(options.library.algorithms.size());
        for (int algo : options.library.algorithms) {
          algorithms.push_back(static_cast<cublasGemmAlgo_t>(algo));
        }
      }
      break;
    }
    case AlgorithmMode::kDefault:
    {
      // Use the library's default algorithm for the opcode class
      algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ?
        CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      break;
    }
    default:
    {
      break;
    }
  }
}
/// Dispatcher to cublasGemmEx()
///
/// Translates a CUTLASS library GEMM description/configuration/arguments into
/// the cuBLAS data types and transpose operations needed to invoke
/// cublasGemmEx(). Construction records any unsupported-type condition in
/// `status`; operator() performs the actual cuBLAS call.
struct cublasGemmExDispatcher {
//
// Data members
//
library::GemmConfiguration configuration;
library::GemmArguments arguments;
// cuBLAS-specific data structures to fill cuBLAS API call arguments
cublasOperation_t trans_A;
cublasOperation_t trans_B;
cudaDataType_t data_type_A;
cudaDataType_t data_type_B;
cudaDataType_t data_type_C;
cudaDataType_t compute_type;
cublasGemmAlgo_t algo;
// Set to kErrorNotSupported if any operand type has no cuBLAS equivalent
Status status;
//
// Methods
//
/// Captures the configuration/arguments and maps layouts and element types
/// to their cuBLAS equivalents. `status` reflects whether all mappings
/// succeeded.
cublasGemmExDispatcher(
library::GemmDescription const &op_desc,
library::GemmConfiguration configuration_,
library::GemmArguments arguments_,
cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT
):
configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) {
trans_A = get_cublas_transpose_operation(op_desc.A.layout);
trans_B = get_cublas_transpose_operation(op_desc.B.layout);
// Each conversion must succeed; short-circuits on the first failure
bool good = true;
good = (good && get_cublas_datatype(data_type_A, op_desc.A.element));
good = (good && get_cublas_datatype(data_type_B, op_desc.B.element));
good = (good && get_cublas_datatype(data_type_C, op_desc.C.element));
good = (good && get_cublas_datatype(
compute_type,
op_desc.tile_description.math_instruction.element_accumulator));
if (!good) {
status = Status::kErrorNotSupported;
}
}
/// Executes GEMM using these arguments
/// Note: writes to arguments.D (the output tensor) with C's data type,
/// matching cublasGemmEx's single C/D operand.
cublasStatus_t operator()(cublasHandle_t handle) {
return cublasGemmEx(
handle,
trans_A,
trans_B,
configuration.problem_size.m(),
configuration.problem_size.n(),
configuration.problem_size.k(),
arguments.alpha,
arguments.A,
data_type_A,
int(configuration.lda),
arguments.B,
data_type_B,
int(configuration.ldb),
arguments.beta,
arguments.D,
data_type_C,
int(configuration.ldc),
compute_type,
algo
);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
} // namespace profiler
} // namespace cutlass

View File

@ -29,14 +29,9 @@
#include <iostream>
#include <stdexcept>
// CUTLASS Library includes
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
// Profiler includes
#include "cutlass_profiler.h"
#include "gemm_operation_profiler.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -49,7 +44,8 @@ CutlassProfiler::CutlassProfiler(
):
options_(options) {
operation_profilers_.emplace_back(new GemmOperationProfiler);
operation_profilers_.emplace_back(new GemmOperationProfiler);
}
CutlassProfiler::~CutlassProfiler() {
@ -112,7 +108,7 @@ void CutlassProfiler::enumerate_() {
/// Profiles all operations
int CutlassProfiler::profile_() {
library::Manifest manifest;
library::Manifest manifest(library::Provider::kCUTLASS);
Status status = manifest.initialize();
if (status != Status::kSuccess) {
@ -165,7 +161,8 @@ void CutlassProfiler::print_usage_(std::ostream &out) {
}
out << "\n\nFor details about a particular function, specify the function name with --help.\n\nExample:\n\n"
<< " $ cutlass_profiler --operation=Gemm --help\n\n";
<< " $ cutlass_profiler --operation=Gemm --help\n\n"
;
}
/// Prints usage

View File

@ -27,6 +27,9 @@
*/
#pragma once
// CUTLASS Library includes
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "options.h"
#include "operation_profiler.h"

View File

@ -30,11 +30,11 @@
#include <iostream>
#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; }
//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; }
//#define report(x) {}
// Enable/Disble Profiler debug prints
#define DEBUG_PROFILER
//#define DEBUG_PROFILER
//RED 31m // profiler prints debug messages in red
//YELLOW 33m // ir prints debug messages in yellow
@ -43,7 +43,7 @@
#define debugprof(...)
#else
#define debugprof(...) do { \
printf("\033[31m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \
printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \
printf(__VA_ARGS__); \
printf("\033[0m\n"); \
} while (0)

View File

@ -34,12 +34,12 @@
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/library/util.h"
#include "device_allocation.h"
namespace cutlass {
@ -106,6 +106,18 @@ std::vector<int> DeviceAllocation::get_packed_layout(
case library::LayoutTypeID::kRowMajorInterleavedK16:
stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<16>>(extent);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK32:
stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<32>>(extent);
break;
case library::LayoutTypeID::kRowMajorInterleavedK32:
stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<32>>(extent);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK64:
stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<64>>(extent);
break;
case library::LayoutTypeID::kRowMajorInterleavedK64:
stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<64>>(extent);
break;
case library::LayoutTypeID::kTensorNCHW:
stride = get_packed_layout_stride<cutlass::layout::TensorNCHW>(extent);
break;
@ -200,6 +212,18 @@ size_t DeviceAllocation::construct_layout(
case library::LayoutTypeID::kRowMajorInterleavedK16:
return construct_layout_<cutlass::layout::RowMajorInterleaved<16>>(bytes, layout_id, extent, stride);
case library::LayoutTypeID::kColumnMajorInterleavedK32:
return construct_layout_<cutlass::layout::ColumnMajorInterleaved<32>>(bytes, layout_id, extent, stride);
case library::LayoutTypeID::kRowMajorInterleavedK32:
return construct_layout_<cutlass::layout::RowMajorInterleaved<32>>(bytes, layout_id, extent, stride);
case library::LayoutTypeID::kColumnMajorInterleavedK64:
return construct_layout_<cutlass::layout::ColumnMajorInterleaved<64>>(bytes, layout_id, extent, stride);
case library::LayoutTypeID::kRowMajorInterleavedK64:
return construct_layout_<cutlass::layout::RowMajorInterleaved<64>>(bytes, layout_id, extent, stride);
case library::LayoutTypeID::kTensorNCHW:
return construct_layout_<cutlass::layout::TensorNHWC>(bytes, layout_id, extent, stride);
@ -415,6 +439,14 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kCF64:
cutlass::reference::device::BlockFillRandom<complex<double>>(
reinterpret_cast<complex<double> *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS8:
cutlass::reference::device::BlockFillRandom<int8_t>(
reinterpret_cast<int8_t *>(pointer_),
@ -508,6 +540,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kCF16:
cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::half_t>>(
reinterpret_cast<cutlass::complex<cutlass::half_t> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF64:
cutlass::reference::host::BlockFillRandom<double>(
reinterpret_cast<double *>(host_data.data()),
@ -516,6 +556,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kCF64:
cutlass::reference::host::BlockFillRandom<cutlass::complex<double>>(
reinterpret_cast<cutlass::complex<double> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS8:
cutlass::reference::host::BlockFillRandom<int8_t>(
reinterpret_cast<int8_t *>(host_data.data()),
@ -607,13 +655,25 @@ bool DeviceAllocation::block_compare_equal(
reinterpret_cast<float const *>(ptr_A),
reinterpret_cast<float const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF16:
return reference::device::BlockCompareEqual<complex<half_t>>(
reinterpret_cast<complex<half_t> const *>(ptr_A),
reinterpret_cast<complex<half_t> const *>(ptr_B),
capacity);
case library::NumericTypeID::kF64:
return reference::device::BlockCompareEqual<double>(
reinterpret_cast<double const *>(ptr_A),
reinterpret_cast<double const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF64:
return reference::device::BlockCompareEqual<complex<double>>(
reinterpret_cast<complex<double> const *>(ptr_A),
reinterpret_cast<complex<double> const *>(ptr_B),
capacity);
case library::NumericTypeID::kS8:
return reference::device::BlockCompareEqual<int8_t>(
reinterpret_cast<int8_t const *>(ptr_A),

View File

@ -74,16 +74,34 @@ DeviceAllocation *DeviceContext::allocate_tensor(
allocate_tensor(name, type, layout_id, extent, stride);
if (options.initialization.enabled) {
Distribution data_distribution = options.initialization.data_distribution;
if (options.initialization.provider == Provider::kReferenceDevice) {
// check if data distribution is allowed to change
if(!options.initialization.fix_data_distribution) {
// change data distribution based on bit width
switch(type) {
case library::NumericTypeID::kB1:
data_distribution.set_uniform(0, 2, 0);
break;
case library::NumericTypeID::kS8:
data_distribution.set_uniform(-2, 2, 0);
break;
case library::NumericTypeID::kU8:
data_distribution.set_uniform(0, 4, 0);
break;
default: break;
}
}
if (options.initialization.provider == library::Provider::kReferenceDevice) {
allocation->initialize_random_device(
options.initialization.seed,
options.initialization.data_distribution);
data_distribution);
}
else if (options.initialization.provider == Provider::kReferenceHost) {
else if (options.initialization.provider == library::Provider::kReferenceHost) {
allocation->initialize_random_host(
options.initialization.seed,
options.initialization.data_distribution);
data_distribution);
}
}

View File

@ -123,53 +123,6 @@ AlgorithmMode from_string<AlgorithmMode>(std::string const &str) {
/////////////////////////////////////////////////////////////////////////////////////////////////
static struct {
char const *text;
char const *pretty;
Provider enumerant;
}
Provider_enumerants[] = {
{"cutlass", "CUTLASS", Provider::kCUTLASS},
{"host", "reference_host", Provider::kReferenceHost},
{"device", "reference_device", Provider::kReferenceDevice},
{"cublas", "cuBLAS", Provider::kCUBLAS},
};
/// Converts a Provider enumerant to a string
char const *to_string(Provider provider, bool pretty) {
for (auto const & possible : Provider_enumerants) {
if (provider == possible.enumerant) {
if (pretty) {
return possible.pretty;
}
else {
return possible.text;
}
}
}
return pretty ? "Invalid" : "invalid";
}
/// Parses a Provider enumerant from a string
template <>
Provider from_string<Provider>(std::string const &str) {
for (auto const & possible : Provider_enumerants) {
if ((str.compare(possible.text) == 0) ||
(str.compare(possible.pretty) == 0)) {
return possible.enumerant;
}
}
return Provider::kInvalid;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
static struct {
char const *text;
char const *pretty;
@ -180,6 +133,7 @@ Disposition_enumerants[] = {
{"failed", "Failed", Disposition::kFailed},
{"not_run", "Not run", Disposition::kNotRun},
{"not_verified", "Not verified", Disposition::kNotVerified},
{"invalid_problem", "Invalid problem", Disposition::kInvalidProblem},
{"not_supported", "Not supported", Disposition::kNotSupported},
{"incorrect", "Incorrect", Disposition::kIncorrect}
};

View File

@ -30,7 +30,9 @@
#include <string>
#include <vector>
#include <map>
#include <iostream>
#include "cutlass/library/library.h"
#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; }
@ -79,26 +81,6 @@ AlgorithmMode from_string<AlgorithmMode>(std::string const &str);
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Providers
enum class Provider {
kCUTLASS,
kReferenceHost,
kReferenceDevice,
kCUBLAS,
kInvalid
};
using ProviderVector = std::vector<Provider>;
/// Converts a Provider enumerant to a string
char const *to_string(Provider provider, bool pretty = false);
/// Parses a Provider enumerant from a string
template <>
Provider from_string<Provider>(std::string const &str);
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Outcome of a performance test
enum class Disposition {
kPassed,
@ -106,12 +88,13 @@ enum class Disposition {
kNotRun,
kIncorrect,
kNotVerified,
kInvalidProblem,
kNotSupported,
kInvalid
};
/// Converts a Disposition enumerant to a string
char const *to_string(Disposition provider, bool pretty = false);
char const *to_string(Disposition disposition, bool pretty = false);
/// Parses a Disposition enumerant from a string
template <>
@ -159,6 +142,21 @@ char const *to_string(ArgumentTypeID type, bool pretty = false);
template <>
ArgumentTypeID from_string<ArgumentTypeID>(std::string const &str);
/////////////////////////////////////////////////////////////////////////////////////////////////
// Profiler typedefs
using ProviderVector = std::vector<library::Provider>;
using DispositionMap = std::map<library::Provider, Disposition>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Print vector for the report
//
// Streams the elements of a vector as a comma-separated list, using the
// pretty form of each element's to_string() overload. No trailing comma is
// emitted after the final element.
template <typename T>
std::ostream& operator<< (std::ostream& out, const std::vector<T>& v) {
  // size_t index avoids the signed/unsigned comparison of `int i < v.size()`
  for (size_t i = 0; i < v.size(); ++i) {
    out << to_string(v[i], true) << (i + 1 != v.size() ? "," : "");
  }
  return out;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler

View File

@ -140,7 +140,6 @@ Status GemmOperationProfiler::initialize_configuration(
return Status::kErrorInvalidProblem;
}
if (!arg_as_int(problem_.m, "m", problem_space, problem)) {
// default value
problem_.m = 1024;
@ -201,7 +200,7 @@ Status GemmOperationProfiler::initialize_configuration(
return Status::kErrorInternal;
}
}
problem_.lda = DeviceAllocation::get_packed_layout(
operation_desc.A.layout, {int(problem_.m), int(problem_.k)}).front();
@ -240,7 +239,7 @@ void GemmOperationProfiler::initialize_result_(
library::GemmDescription const &operation_desc,
ProblemSpace const &problem_space) {
result.provider = Provider::kCUTLASS;
result.provider = library::Provider::kCUTLASS;
result.disposition = Disposition::kNotRun;
result.status = Status::kSuccess;
result.operation_name = operation_desc.name;
@ -277,9 +276,17 @@ void GemmOperationProfiler::initialize_result_(
int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n * 2;
result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n);
result.runtime = 0;
// complex-valued support
switch (operation_desc.tile_description.math_instruction.math_operation) {
case library::MathOperationID::kMultiplyAddComplex:
result.flops *= 4;
break;
default: break;
}
}
/// Initializes workspace
@ -290,7 +297,7 @@ Status GemmOperationProfiler::initialize_workspace(
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
library::GemmDescription const &operation_desc =
static_cast<library::GemmDescription const &>(operation->description());
@ -348,7 +355,7 @@ Status GemmOperationProfiler::initialize_workspace(
//
Status status = Status::kSuccess;
if (options.profiling.provider_enabled(Provider::kCUTLASS)) {
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
if (options.execution_mode != ExecutionMode::kDryRun) {
@ -368,8 +375,12 @@ Status GemmOperationProfiler::initialize_workspace(
// If CUTLASS is enabled, generate a result for it
//
results_.push_back(model_result_);
results_.back().provider = Provider::kCUTLASS;
results_.back().provider = library::Provider::kCUTLASS;
results_.back().op_kind = library::OperationKind::kGemm;
results_.back().disposition = Disposition::kNotRun;
for(auto &verification_provider : options.verification.providers) {
results_.back().verification_map[verification_provider] = Disposition::kNotRun;
}
}
return status;
@ -386,7 +397,7 @@ bool GemmOperationProfiler::verify_cutlass(
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
if (!options.profiling.provider_enabled(Provider::kCUTLASS)) {
if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
return true;
}
@ -423,198 +434,62 @@ bool GemmOperationProfiler::verify_cutlass(
return false;
}
// CUTLASS op ran the but not yet verified against any verification provider
results_.back().disposition = Disposition::kNotVerified;
//
// Run verification providers
//
if (options.verification.enabled) {
#if CUTLASS_ENABLE_CUBLAS
if (options.verification.provider_enabled(Provider::kCUBLAS)) {
if (options.verification.provider_enabled(library::Provider::kCUBLAS)) {
// Guard against unsupported cases
auto const & gemm_desc = static_cast<library::GemmDescription const &>(operation->description());
if (cublas_satisfies(gemm_desc) != Status::kSuccess) {
return true;
}
if (cublas_satisfies(gemm_desc) == Status::kSuccess) {
return verify_with_cublas_(
options,
report,
device_context,
operation,
problem_space,
problem);
// call cublas verification if supported
verify_with_cublas_(
options,
report,
device_context,
operation,
problem_space,
problem);
}
else {
// set verification map for cublas to not supported
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported;
}
}
#endif // #if CUTLASS_ENABLE_CUBLAS
// Update disposition to worst case verification outcome among all
// verification providers which are supported
bool is_any_verification_run_passed = false;
for(auto &m : results_.back().verification_map) {
if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) {
results_.back().disposition = m.second;
return true;
}
if(!is_any_verification_run_passed && m.second == Disposition::kPassed) {
is_any_verification_run_passed = true;
}
}
if(is_any_verification_run_passed) {
results_.back().disposition = Disposition::kPassed;
}
}
// Return true means continue profiling
return true;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#if CUTLASS_ENABLE_CUBLAS
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Selects one or more cuBLAS algorithms.
static void select_cublas_algorithms(
std::vector<cublasGemmAlgo_t> &algorithms,
Options const &options,
library::GemmDescription const &op_desc) {
library::OpcodeClassID const & opcode_class =
op_desc.tile_description.math_instruction.opcode_class;
switch (options.library.algorithm_mode) {
case AlgorithmMode::kMatching:
{
algorithms.push_back(get_cublas_gemm_algo(
op_desc.tile_description.threadblock_shape.m(),
op_desc.tile_description.threadblock_shape.n(),
op_desc.tile_description.threadblock_shape.k(),
opcode_class));
break;
}
case AlgorithmMode::kBest:
{
// Choose first enumerated mode. If none are enumerated, choose based on opcode class
// and evaluate all of them.
if (options.library.algorithms.empty()) {
// Enumerate all algorithms
if (opcode_class == library::OpcodeClassID::kSimt) {
for (int algo = CUBLAS_GEMM_DEFAULT;
algo <= CUBLAS_GEMM_ALGO23;
++algo) {
algorithms.push_back(cublasGemmAlgo_t(algo));
}
}
else {
for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP;
++algo) {
algorithms.push_back(cublasGemmAlgo_t(algo));
}
}
}
else {
// Use the listed algorithms
algorithms.reserve(options.library.algorithms.size());
for (int algo : options.library.algorithms) {
algorithms.push_back(reinterpret_cast<cublasGemmAlgo_t const &>(algo));
}
}
break;
}
case AlgorithmMode::kDefault:
{
// Use the library's default algorithm
algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ?
CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
break;
}
default:
{
break;
}
}
}
/// Dispatcher to cublasGemmEx()
struct cublasGemmExDispatcher {
//
// Data members
//
library::GemmConfiguration configuration;
library::GemmArguments arguments;
cublasOperation_t trans_A;
cublasOperation_t trans_B;
cudaDataType_t data_type_A;
cudaDataType_t data_type_B;
cudaDataType_t data_type_C;
cudaDataType_t compute_type;
cublasGemmAlgo_t algo;
Status status;
//
// Methods
//
cublasGemmExDispatcher(
library::GemmDescription const &op_desc,
library::GemmConfiguration configuration_,
library::GemmArguments arguments_,
cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT
):
configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) {
trans_A = get_cublas_transpose_operation(op_desc.A.layout);
trans_B = get_cublas_transpose_operation(op_desc.B.layout);
bool good = true;
good = (good && get_cublas_datatype(data_type_A, op_desc.A.element));
good = (good && get_cublas_datatype(data_type_B, op_desc.B.element));
good = (good && get_cublas_datatype(data_type_C, op_desc.C.element));
good = (good && get_cublas_datatype(
compute_type,
op_desc.tile_description.math_instruction.element_accumulator));
if (!good) {
status = Status::kErrorNotSupported;
}
}
/// Executes GEMM using these arguments
cublasStatus_t operator()(cublasHandle_t handle) {
return cublasGemmEx(
handle,
trans_A,
trans_B,
configuration.problem_size.m(),
configuration.problem_size.n(),
configuration.problem_size.k(),
arguments.alpha,
arguments.A,
data_type_A,
int(configuration.lda),
arguments.B,
data_type_B,
int(configuration.ldb),
arguments.beta,
arguments.D,
data_type_C,
int(configuration.ldc),
compute_type,
algo
);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
#endif // CUTLASS_ENABLE_CUBLAS
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Verifies CUTLASS against references
@ -632,14 +507,16 @@ bool GemmOperationProfiler::verify_with_cublas_(
library::GemmDescription const &gemm_desc =
static_cast<library::GemmDescription const &>(operation->description());
//
// Construct cuBLAS operators
//
CublasCreate handle;
cublasStatus_t status = handle.get_cublas_create_status();
if (status != CUBLAS_STATUS_SUCCESS) {
results_.back().status = get_cutlass_status(status);
results_.back().disposition = Disposition::kFailed;
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
return true;
}
@ -682,7 +559,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
);
if (gemm_op.status != Status::kSuccess) {
results_.back().disposition = Disposition::kNotVerified;
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
return true;
}
@ -692,8 +570,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
// Handle errors
if (status != CUBLAS_STATUS_SUCCESS) {
results_.back().status = get_cutlass_status(status);
results_.back().disposition = Disposition::kNotVerified;
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
return true;
}
@ -701,7 +579,7 @@ bool GemmOperationProfiler::verify_with_cublas_(
// Verify results
//
results_.back().disposition = compare_tensors(
results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors(
options,
*gemm_workspace_.Computed,
*gemm_workspace_.Reference
@ -709,19 +587,18 @@ bool GemmOperationProfiler::verify_with_cublas_(
// Save workspace if incorrect
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
results_.back().disposition == Disposition::kIncorrect) {
results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) {
save_workspace(
device_context,
options,
gemm_desc,
Provider::kCUTLASS,
Provider::kCUBLAS);
library::Provider::kCUTLASS,
library::Provider::kCUBLAS);
}
}
catch (...) {
results_.back().disposition = Disposition::kFailed;
results_.back().status = Status::kErrorNotSupported;
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
}
#endif
@ -741,7 +618,7 @@ bool GemmOperationProfiler::profile(
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
if (options.profiling.provider_enabled(Provider::kCUTLASS)) {
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
// Initialize structure containing GEMM arguments
gemm_workspace_.arguments.A = gemm_workspace_.A->data();

View File

@ -35,6 +35,7 @@
// CUTLASS Library includes
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
#include "cutlass/library/manifest.h"
// Profiler includes

View File

@ -225,7 +225,7 @@ int OperationProfiler::profile_all(
ProblemSpace problem_space(arguments_, options.cmdline);
// 1. Construct performance report
PerformanceReport report(options, problem_space.argument_names());
PerformanceReport report(options, problem_space.argument_names(), kind_);
// 2. For each problem in problem space
ProblemSpace::Iterator problem_it = problem_space.begin();
@ -269,7 +269,7 @@ int OperationProfiler::profile_all(
if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) {
continue;
}
// A. Initialize configuration
Status status = this->initialize_configuration(
options,
@ -278,7 +278,7 @@ int OperationProfiler::profile_all(
operation,
problem_space,
problem);
if (status == Status::kErrorInternal) {
// Stop profiling if there was an internal error
return false;
@ -341,7 +341,7 @@ int OperationProfiler::profile_all(
device_context,
options,
operation->description(),
Provider::kCUTLASS);
library::Provider::kCUTLASS);
}
//
@ -434,8 +434,8 @@ void OperationProfiler::save_workspace(
DeviceContext &device_context,
Options const &options,
library::OperationDescription const &desc,
Provider provider,
Provider verification_provider) {
library::Provider provider,
library::Provider verification_provider) {
for (auto const & named_allocation : device_context) {
@ -443,10 +443,10 @@ void OperationProfiler::save_workspace(
std::stringstream filename;
filename << desc.name << "_" << to_string(provider) << "_";
filename << desc.name << "_" << library::to_string(provider) << "_";
if (verification_provider != Provider::kInvalid) {
filename << "verified_by_" << to_string(verification_provider) << "_";
if (verification_provider != library::Provider::kInvalid) {
filename << "verified_by_" << library::to_string(verification_provider) << "_";
}
filename << named_allocation.first + ".mat";
@ -454,6 +454,7 @@ void OperationProfiler::save_workspace(
std::ofstream out(filename.str());
allocation->write_tensor_csv(out);
out << "\n";
if (options.report.verbose) {
std::cout << "wrote '" << filename.str() << "'" << std::endl;

View File

@ -35,6 +35,7 @@
// CUTLASS Library includes
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
#include "cutlass/library/manifest.h"
// Profiler includes
@ -43,6 +44,7 @@
#include "performance_result.h"
#include "performance_report.h"
#include "problem_space.h"
#include "debug.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -192,8 +194,8 @@ public:
DeviceContext &device_context,
Options const &options,
library::OperationDescription const &desc,
Provider provider,
Provider verification_provider = Provider::kInvalid);
library::Provider provider,
library::Provider verification_provider = library::Provider::kInvalid);
protected:

View File

@ -31,6 +31,8 @@
#include "cutlass/cutlass.h"
#include "cutlass/version.h"
#include "cutlass/library/util.h"
#include "options.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -161,24 +163,30 @@ Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) {
if (cmdline.check_cmd_line_flag("initialization-provider")) {
std::string str;
cmdline.get_cmd_line_argument("initialization-provider", str);
provider = from_string<Provider>(str);
if (provider == Provider::kInvalid) {
provider = library::from_string<library::Provider>(str);
if (provider == library::Provider::kInvalid) {
enabled = false;
}
else if (provider != Provider::kReferenceHost && provider != Provider::kReferenceDevice) {
else if (provider != library::Provider::kReferenceHost && provider != library::Provider::kReferenceDevice) {
throw std::runtime_error("Unsupported intialization provider specified.");
}
}
else {
provider = Provider::kReferenceDevice;
provider = library::Provider::kReferenceDevice;
}
cmdline.get_cmd_line_argument("seed", seed, 2019);
if (cmdline.check_cmd_line_flag("dist")) {
// user has set the data distribution (fix data distribution once set)
fix_data_distribution = true;
// set user provided data distribution
get_distribution(cmdline, "dist", data_distribution);
}
else {
// profiler choosen data distribution (allowed to change based on numeric types)
fix_data_distribution = false;
// set uniform data distribution with range [-4, 4]
data_distribution.set_uniform(-4, 4, 0);
}
@ -372,12 +380,12 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) {
providers.clear();
for (auto const &token : tokens) {
providers.push_back(from_string<Provider>(token));
providers.push_back(library::from_string<library::Provider>(token));
}
}
else {
providers.push_back(Provider::kCUTLASS);
providers.push_back(Provider::kCUBLAS);
providers.push_back(library::Provider::kCUTLASS);
providers.push_back(library::Provider::kCUBLAS);
}
}
@ -412,18 +420,18 @@ void Options::Profiling::print_options(std::ostream &out, int indent) const {
int j = 0;
for (auto const & provider : providers) {
out << (j++ ? ", " : "") << to_string(provider);
out << (j++ ? ", " : "") << library::to_string(provider);
}
out << "]\n";
}
/// Returns true if a provider is enabled
bool Options::Profiling::provider_enabled(Provider provider) const {
bool Options::Profiling::provider_enabled(library::Provider provider) const {
return std::find(providers.begin(), providers.end(), provider) != providers.end();
}
/// Returns the index of a provider if its enabled
size_t Options::Profiling::index(Provider provider) const {
size_t Options::Profiling::index(library::Provider provider) const {
size_t idx = 0;
for (auto const & x : providers) {
if (x == provider) {
@ -461,14 +469,14 @@ Options::Verification::Verification(cutlass::CommandLine const &cmdline) {
providers.clear();
for (auto const &token : tokens) {
Provider provider = from_string<Provider>(token);
if (provider != Provider::kInvalid) {
library::Provider provider = library::from_string<library::Provider>(token);
if (provider != library::Provider::kInvalid) {
providers.push_back(provider);
}
}
}
else {
providers.push_back(Provider::kCUBLAS);
providers.push_back(library::Provider::kCUBLAS);
}
}
@ -504,18 +512,18 @@ void Options::Verification::print_options(std::ostream &out, int indent) const {
int j = 0;
for (auto const & provider : providers) {
out << (j++ ? ", " : "") << to_string(provider);
out << (j++ ? ", " : "") << library::to_string(provider);
}
out << "]\n";
}
/// Returns true if a provider is enabled
bool Options::Verification::provider_enabled(Provider provider) const {
bool Options::Verification::provider_enabled(library::Provider provider) const {
return std::find(providers.begin(), providers.end(), provider) != providers.end();
}
/// Returns the index of a provider if its enabled
size_t Options::Verification::index(Provider provider) const {
size_t Options::Verification::index(library::Provider provider) const {
size_t idx = 0;
for (auto const & x : providers) {
if (x == provider) {
@ -658,7 +666,7 @@ Options::Options(cutlass::CommandLine const &cmdline):
// Prevent launches on the device for anything other than CUTLASS operation
if (execution_mode == ExecutionMode::kTrace) {
initialization.provider = Provider::kReferenceHost;
initialization.provider = library::Provider::kReferenceHost;
verification.enabled = false;
profiling.enabled = false;
}

View File

@ -105,11 +105,15 @@ public:
/// allocating tensors.
bool enabled;
/// If true, data distribution is set by the user and is not allowed to change
/// If false, data distribution is allowed to change based on element_type (library::NumericTypeID)
bool fix_data_distribution;
/// Data distribution for input tensors
Distribution data_distribution;
/// Source of random tensor elements
Provider provider;
library::Provider provider;
/// Random number generator seed.
int seed;
@ -162,10 +166,10 @@ public:
void print_options(std::ostream &out, int indent = 0) const;
/// Returns true if a provider is enabled
bool provider_enabled(Provider provider) const;
bool provider_enabled(library::Provider provider) const;
/// Returns the index of a provider if its enabled
size_t index(Provider provider) const;
size_t index(library::Provider provider) const;
};
/// Options related to profiling
@ -196,10 +200,10 @@ public:
void print_options(std::ostream &out, int indent = 0) const;
/// Returns true if a provider is enabled
bool provider_enabled(Provider provider) const;
bool provider_enabled(library::Provider provider) const;
/// Returns the index of a provider if its enabled
size_t index(Provider provider) const;
size_t index(library::Provider provider) const;
};
/// Options related to reporting

View File

@ -29,9 +29,15 @@
#include <iostream>
#include <stdexcept>
#include <iomanip>
#include <algorithm>
#include <cstring>
#include "cutlass/library/util.h"
#include "cutlass/library/util.h"
#include "performance_report.h"
#include "debug.h"
namespace cutlass {
namespace profiler {
@ -57,12 +63,17 @@ namespace profiler {
PerformanceReport::PerformanceReport(
Options const &options,
std::vector<std::string> const &argument_names
std::vector<std::string> const &argument_names,
library::OperationKind const &op_kind
):
options_(options), argument_names_(argument_names), problem_index_(0), good_(true) {
options_(options), argument_names_(argument_names), problem_index_(0), good_(true), op_kind_(op_kind) {
std::string file_name = options_.report.output_path.substr(0, options_.report.output_path.rfind("."));
std::string file_extension = options_.report.output_path.substr(options_.report.output_path.rfind(".") + 1);
op_file_name_ = file_name + "." + to_string(op_kind_) + "." + file_extension;
//
// Open output file
// Open output file for operation of PerformanceReport::op_kind
//
if (!options_.report.output_path.empty()) {
@ -70,17 +81,17 @@ PerformanceReport::PerformanceReport(
if (options_.report.append) {
std::ifstream test_output_file(options_.report.output_path.c_str());
std::ifstream test_output_file(op_file_name_);
if (test_output_file.is_open()) {
print_header = false;
test_output_file.close();
}
output_file_.open(options_.report.output_path.c_str(), std::ios::app);
output_file_.open(op_file_name_, std::ios::app);
}
else {
output_file_.open(options_.report.output_path.c_str());
output_file_.open(op_file_name_);
}
if (!output_file_.good()) {
@ -148,7 +159,7 @@ void PerformanceReport::close() {
}
}
else if (output_file_.is_open() && options_.report.verbose) {
std::cout << "\n\nWrote results to '" << options_.report.output_path << "'" << std::endl;
std::cout << "\n\nWrote results to '" << op_file_name_ << "'" << std::endl;
}
}
@ -184,19 +195,30 @@ std::ostream & PerformanceReport::print_result_pretty_(
out
<< "\n"
<< " Provider: " << SHELL_COLOR_BRIGHT() << to_string(result.provider, true) << SHELL_COLOR_END() << "\n"
<< " Operation: " << result.operation_name << "\n\n"
<< " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n"
<< " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n";
<< " Provider: " << SHELL_COLOR_BRIGHT() << library::to_string(result.provider, true) << SHELL_COLOR_END() << "\n"
<< " Operation: " << result.operation_name << "\n\n"
<< " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n"
<< " Verification: " << SHELL_COLOR_BRIGHT() << (options_.verification.enabled ? "ON":"OFF") << SHELL_COLOR_END() << "\n"
<< " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n\n";
// Display individual verification results for each verification-provider
if (options_.verification.enabled) {
static int const indent_spaces = 22;
for(auto & m : result.verification_map) {
out << std::right << std::setw(indent_spaces) << library::to_string(m.first, true) << ": " << to_string(m.second, true) << "\n";
}
}
out
<< "\n Arguments: ";
<< "\n Arguments: ";
int column_idx = 0;
for (auto const &arg : result.arguments) {
if (!arg.second.empty()) {
out << " --" << arg.first << "=" << arg.second;
column_idx += 4 + arg.first.size() + arg.second.size();
column_idx += int(4 + arg.first.size() + arg.second.size());
if (column_idx > 90) {
out << " \\\n ";
column_idx = 0;
@ -206,15 +228,15 @@ std::ostream & PerformanceReport::print_result_pretty_(
out << "\n\n";
out
<< " Bytes: " << result.bytes << " bytes\n"
<< " FLOPs: " << result.flops << " flops\n\n";
<< " Bytes: " << result.bytes << " bytes\n"
<< " FLOPs: " << result.flops << " flops\n\n";
if (result.good()) {
out
<< " Runtime: " << result.runtime << " ms\n"
<< " Memory: " << result.gbytes_per_sec() << " GiB/s\n"
<< "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n";
<< " Runtime: " << result.runtime << " ms\n"
<< " Memory: " << result.gbytes_per_sec() << " GiB/s\n"
<< "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n";
}

View File

@ -31,10 +31,14 @@
#include <vector>
#include <fstream>
// CUTLASS Profiler includes
#include "options.h"
#include "enumerated_types.h"
#include "performance_result.h"
// CUTLASS Library includes
#include "cutlass/library/library.h"
namespace cutlass {
namespace profiler {
@ -46,6 +50,12 @@ private:
/// Reference to options
Options const &options_;
/// Operation kind
library::OperationKind op_kind_;
/// Operation file name containing performance report of op_kind
std::string op_file_name_;
/// Output file containing results
std::ofstream output_file_;
@ -63,7 +73,7 @@ private:
public:
PerformanceReport(Options const &options, std::vector<std::string> const &argument_names);
PerformanceReport(Options const &options, std::vector<std::string> const &argument_names, library::OperationKind const &op_kind);
bool good() const { return good_; }

View File

@ -32,8 +32,12 @@
#include "cutlass/cutlass.h"
// CUTLASS Profiler includes
#include "enumerated_types.h"
// CUTLASS Library includes
#include "cutlass/library/library.h"
namespace cutlass {
namespace profiler {
@ -45,15 +49,22 @@ struct PerformanceResult {
/// Index of problem
size_t problem_index;
/// Provider
Provider provider;
/// library::Provider
library::Provider provider;
/// Outcome of test
Disposition disposition;
/// Operation kind
library::OperationKind op_kind;
/// CUTLASS status result from kernels
/// CUTLASS status result from kernels (success or failure)
// Status does information on verification
Status status;
/// Outcome of verification (worst case verification result)
Disposition disposition;
/// Outcome of verification (all verification results)
DispositionMap verification_map;
/// Operation object
std::string operation_name;
@ -76,7 +87,8 @@ struct PerformanceResult {
/// Ctor
PerformanceResult():
problem_index(0),
provider(Provider::kInvalid),
op_kind(library::OperationKind::kInvalid),
provider(library::Provider::kInvalid),
disposition(Disposition::kNotRun),
status(Status::kInvalid),
bytes(0),

View File

@ -27,10 +27,11 @@
*/
#include <string>
#include <iostream>
#include <stdexcept>
#include <sstream>
#include "cutlass/library/util.h"
#include "problem_space.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -849,17 +850,16 @@ bool arg_as_OpcodeClassID(
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null.
bool arg_as_scalar(
std::vector<uint8_t> &bytes,
library::NumericTypeID numeric_type,
KernelArgument::Value const *value_ptr) {
if (value_ptr->not_null) {
if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) {
int64_t int_value = static_cast<IntegerArgument::IntegerValue const *>(value_ptr)->value;
// TODO - convert int64_t => destination type
}
else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) {

View File

@ -31,6 +31,12 @@ target_include_directories(
$<BUILD_INTERFACE:${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}>
)
target_link_libraries(
cutlass_tools_util_includes
INTERFACE
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:cublas>
)
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/
@ -40,3 +46,4 @@ install(
TARGETS cutlass_tools_util_includes
EXPORT NvidiaCutlass
)

View File

@ -119,6 +119,16 @@ struct CommandLine {
val = !(value == "0" || value == "false");
}
}
/**
* Obtains the value specified for a given commandline parameter --<flag>=<value>
*/
template <typename value_t>
void get_cmd_line_argument(const char* arg_name,
value_t& val) const {
get_cmd_line_argument(arg_name, val, val);
}
/**
* Obtains the value specified for a given commandline parameter --<flag>=<value>
@ -126,7 +136,7 @@ struct CommandLine {
template <typename value_t>
void get_cmd_line_argument(const char* arg_name,
value_t& val,
value_t const& _default = value_t()) const {
value_t const& _default) const {
using namespace std;
val = _default;

View File

@ -40,10 +40,14 @@ namespace device_memory {
/// Allocate a buffer of \p count elements of type \p T on the current CUDA device
template <typename T>
T* allocate(size_t count = 1) {
T* ptr = 0;
size_t bytes = sizeof(T) * count;
size_t bytes = 0;
bytes = count * sizeof(T);
cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes);
if (cuda_error != cudaSuccess) {
throw cuda_exception("Failed to allocate memory", cuda_error);
}
@ -111,13 +115,16 @@ void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) {
copy_to_device(device_begin, &*begin, elements);
}
/******************************************************************************
* "Smart" device memory allocation
******************************************************************************/
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device_memory
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Device allocation abstraction that tracks size and capacity
template <typename T>
struct allocation {
class DeviceAllocation {
public:
/// Delete functor for CUDA device memory
struct deleter {
void operator()(T* ptr) {
@ -130,6 +137,7 @@ struct allocation {
}
};
public:
//
// Data members
//
@ -140,23 +148,55 @@ struct allocation {
/// Smart pointer
platform::unique_ptr<T, deleter> smart_ptr;
public:
//
// Static methods
//
/// Static member to compute the number of bytes needed for a given number of elements
static size_t bytes(size_t elements) {
if (sizeof_bits<T>::value < 8) {
size_t const kElementsPerByte = 8 / sizeof_bits<T>::value;
return elements / kElementsPerByte;
}
else {
size_t const kBytesPerElement = sizeof_bits<T>::value / 8;
return elements * kBytesPerElement;
}
}
public:
//
// Methods
//
/// Constructor: allocates no memory
allocation() : capacity(0) {}
DeviceAllocation() : capacity(0) {}
/// Constructor: allocates \p capacity elements on the current CUDA device
allocation(size_t _capacity) : smart_ptr(allocate<T>(_capacity)), capacity(_capacity) {}
DeviceAllocation(size_t _capacity) :
smart_ptr(device_memory::allocate<T>(_capacity)), capacity(_capacity) {}
/// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation
DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {}
/// Copy constructor
allocation(allocation const &p): smart_ptr(allocate<T>(p.capacity)), capacity(p.capacity) {
copy_device_to_device(smart_ptr.get(), p.get(), capacity);
DeviceAllocation(DeviceAllocation const &p):
smart_ptr(device_memory::allocate<T>(p.capacity)), capacity(p.capacity) {
device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity);
}
/// Move constructor
DeviceAllocation(DeviceAllocation &&p): capacity(0) {
std::swap(smart_ptr, p.smart_ptr);
std::swap(capacity, p.capacity);
}
/// Destructor
~allocation() { reset(); }
~DeviceAllocation() { reset(); }
/// Returns a pointer to the managed object
T* get() const { return smart_ptr.get(); }
@ -173,12 +213,41 @@ struct allocation {
smart_ptr.reset();
}
/// Deletes managed object, if owned, and allocates a new object
void reset(size_t _capacity) {
reset(device_memory::allocate<T>(_capacity), _capacity);
}
/// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity
void reset(T* _ptr, size_t _capacity) {
smart_ptr.reset(_ptr);
capacity = _capacity;
}
/// Allocates a new buffer and copies the old buffer into it. The old buffer is then released.
void reallocate(size_t new_capacity) {
platform::unique_ptr<T, deleter> new_allocation(device_memory::allocate<T>(new_capacity));
device_memory::copy_device_to_device(
new_allocation.get(),
smart_ptr.get(),
std::min(new_capacity, capacity));
std::swap(smart_ptr, new_allocation);
std::swap(new_capacity, capacity);
}
/// Returns the number of elements
size_t size() const {
return capacity;
}
/// Returns the number of bytes needed to store the allocation
size_t bytes() const {
return bytes(capacity);
}
/// Returns a pointer to the object owned by *this
T* operator->() const { return smart_ptr.get(); }
@ -189,15 +258,69 @@ struct allocation {
const deleter& get_deleter() const { return smart_ptr.get_deleter(); }
/// Copies a device-side memory allocation
allocation & operator=(allocation const &p) {
DeviceAllocation & operator=(DeviceAllocation const &p) {
if (capacity != p.capacity) {
smart_ptr.reset(allocate<T>(p.capacity));
smart_ptr.reset(device_memory::allocate<T>(p.capacity));
capacity = p.capacity;
}
copy_device_to_device(smart_ptr.get(), p.get(), capacity);
return *this;
}
/// Move assignment
DeviceAllocation & operator=(DeviceAllocation && p) {
std::swap(smart_ptr, p.smart_ptr);
std::swap(capacity, p.capacity);
return *this;
}
/// Copies the entire allocation from another location in device memory.
void copy_from_device(T const *ptr) const {
copy_from_device(ptr, capacity);
}
/// Copies a given number of elements from device memory
void copy_from_device(T const *ptr, size_t elements) const {
device_memory::copy_device_to_device(get(), ptr, elements);
}
void copy_to_device(T *ptr) const {
copy_to_device(ptr, capacity);
}
void copy_to_device(T *ptr, size_t elements) const {
device_memory::copy_device_to_device(ptr, get(), elements);
}
void copy_from_host(T const *ptr) const {
copy_from_host(ptr, capacity);
}
void copy_from_host(T const *ptr, size_t elements) const {
device_memory::copy_to_device(get(), ptr, elements);
}
void copy_to_host(T *ptr) const {
copy_to_host(ptr, capacity);
}
void copy_to_host(T *ptr, size_t elements) const {
device_memory::copy_to_host(ptr, get(), elements);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace device_memory {
/// Device allocation abstraction that tracks size and capacity
template <typename T>
using allocation = cutlass::DeviceAllocation<T>;
} // namespace device_memory
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -99,7 +99,7 @@ public:
using ConstReference = typename ConstTensorRef::Reference;
/// Used to handle packing of subbyte elements
static int const kElementsPerStoredItem = (sizeof_bits<Element>::value < 8 ? sizeof(Element) * 8 / sizeof_bits<Element>::value : 1);
static int const kElementsPerStoredItem = (sizeof_bits<Element>::value < 8 ? (8 / sizeof_bits<Element>::value) : 1);
private:
@ -232,7 +232,7 @@ public:
/// Returns the logical capacity based on extent and layout. May differ from size().
LongIndex capacity() const {
return layout_.capacity(extent_) * kElementsPerStoredItem;
return layout_.capacity(extent_);
}
/// Gets pointer to host data

View File

@ -0,0 +1,423 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/*! \file
\brief HostTensor contributes management for both host and device memory.
HostTensor allocates host and device memory upon construction. Basic element-wise operations on
host memory synchronize device memory automatically. Explicit copy operations provide abstractions
for CUDA memcpy operations.
Call {host, device}_{data, ref, view}() for accessing host or device memory.
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
*/
#include <vector>
#include "cutlass/cutlass.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/tensor_ref_planar_complex.h"
#include "cutlass/tensor_view_planar_complex.h"
#include "device_memory.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Host tensor
template <
/// Data type of element stored within tensor (concept: NumericType)
typename Element_,
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
typename Layout_
>
class HostTensorPlanarComplex {
public:
/// Data type of individual access
using Element = Element_;
/// Mapping function from logical coordinate to linear memory
using Layout = Layout_;
/// Logical rank of tensor index space
static int const kRank = Layout::kRank;
/// Index type
using Index = typename Layout::Index;
/// Long index used for pointer offsets
using LongIndex = typename Layout::LongIndex;
/// Coordinate in logical tensor space
using TensorCoord = typename Layout::TensorCoord;
/// Layout's stride vector
using Stride = typename Layout::Stride;
/// Tensor reference to device memory
using TensorRef = TensorRefPlanarComplex<Element, Layout>;
/// Tensor reference to constant device memory
using ConstTensorRef = typename TensorRef::ConstTensorRef;
/// Tensor reference to device memory
using TensorView = TensorViewPlanarComplex<Element, Layout>;
/// Tensor reference to constant device memory
using ConstTensorView = typename TensorView::ConstTensorView;
/// Reference to element in tensor
using Reference = typename TensorRef::Reference;
/// Constant reference to element in tensor
using ConstReference = typename ConstTensorRef::Reference;
private:
//
// Data members
//
/// Extent of tensor in logical dimensions
TensorCoord extent_;
/// Layout object
Layout layout_;
/// Host-side memory allocation
std::vector<Element> host_;
/// Device-side memory
device_memory::allocation<Element> device_;
public:
//
// Device and Host Methods
//
/// Default constructor
HostTensorPlanarComplex() {}
/// Constructs a tensor given an extent. Assumes a packed layout
HostTensorPlanarComplex(
TensorCoord const &extent,
bool device_backed = true
) {
this->reset(extent, Layout::packed(extent), device_backed);
}
/// Constructs a tensor given an extent and layout
HostTensorPlanarComplex(
TensorCoord const &extent,
Layout const &layout,
bool device_backed = true
) {
this->reset(extent, layout, device_backed);
}
~HostTensorPlanarComplex() { }
/// Clears the HostTensor allocation to size/capacity = 0
void reset() {
extent_ = TensorCoord();
layout_ = Layout::packed(extent_);
host_.clear();
device_.reset();
}
/// Resizes internal memory allocations without affecting layout or extent
void reserve(
size_t count, ///< size of tensor in elements
bool device_backed_ = true) { ///< if true, device memory is also allocated
device_.reset();
host_.clear();
host_.resize(count * 2);
// Allocate memory
Element* device_memory = nullptr;
if (device_backed_) {
device_memory = device_memory::allocate<Element>(count * 2);
}
device_.reset(device_memory, device_backed_ ? count * 2 : 0);
}
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
/// extent and layout.
void reset(
TensorCoord const &extent, ///< extent of logical tensor
Layout const &layout, ///< layout object of tensor
bool device_backed_ = true) { ///< if true, device memory is also allocated.
extent_ = extent;
layout_ = layout;
reserve(size_t(layout_.capacity(extent_)), device_backed_);
}
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
/// extent and layout. Assumes a packed tensor configuration.
void reset(
TensorCoord const &extent, ///< extent of logical tensor
bool device_backed_ = true) { ///< if true, device memory is also allocated.
reset(extent, Layout::packed(extent), device_backed_);
}
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
/// To force allocation, call reset().
void resize(
TensorCoord const &extent, ///< extent of logical tensor
Layout const &layout, ///< layout object of tensor
bool device_backed_ = true) { ///< if true, device memory is also allocated.
extent_ = extent;
layout_ = layout;
LongIndex new_size = size_t(layout_.capacity(extent_));
if (static_cast<decltype(host_.size())>(new_size * 2) > host_.size()) {
reserve(new_size);
}
}
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
/// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
void resize(
TensorCoord const &extent, ///< extent of logical tensor
bool device_backed_ = true) { ///< if true, device memory is also allocated.
resize(extent, Layout::packed(extent), device_backed_);
}
/// Returns the number of elements stored per plane; the host buffer holds
/// 2x this count (real plane + imaginary plane).
size_t size() const {
return host_.size() / 2;
}
/// Returns the logical capacity based on extent and layout. May differ from size().
LongIndex capacity() const {
return layout_.capacity(extent_);
}
/// Element stride between the real and imaginary planes (equals size()).
LongIndex imaginary_stride() const {
return host_.size() / 2;
}
/// Gets pointer to host data (start of the real-valued plane)
Element * host_data() { return host_.data(); }
/// Gets pointer to host data, imaginary plane (real plane + imaginary_stride())
Element * host_data_imag() { return host_.data() + imaginary_stride(); }
/// Gets pointer into the real plane of host data with an element offset
Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; }
/// Gets pointer into the imaginary plane of host data with an element offset
Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; }
/// Gets a planar-complex reference (paired real/imag pointers) to element idx in host memory
Reference host_data(LongIndex idx) {
return PlanarComplexReference<Element>(host_data() + idx, host_data_imag() + idx);
}
/// Gets const pointer to host data (start of the real-valued plane)
Element const * host_data() const { return host_.data(); }
/// Gets const pointer to host data, imaginary plane
Element const * host_data_imag() const { return host_.data() + imaginary_stride(); }
/// Gets a constant planar-complex reference to element idx in host memory
ConstReference host_data(LongIndex idx) const {
return PlanarComplexReference<Element const>(host_data() + idx, host_data_imag() + idx);
}
/// Gets pointer to device data (nullptr if not device-backed)
Element * device_data() { return device_.get(); }
/// Gets pointer to device data with an element offset
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; }
/// Gets const pointer to device data (nullptr if not device-backed)
Element const * device_data() const { return device_.get(); }
/// Gets const pointer to device data with an element offset
Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }
/// Accesses a planar-complex tensor reference to host data, optionally offset by elements
TensorRef host_ref(LongIndex ptr_element_offset=0) {
return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
}
/// Returns a tensor reference to the real part of the tensor in host memory
cutlass::TensorRef<Element, Layout> host_ref_real() {
return cutlass::TensorRef<Element, Layout>(host_data(), layout_);
}
/// Returns a tensor reference to the imaginary part of the tensor in host memory
cutlass::TensorRef<Element, Layout> host_ref_imag() {
return cutlass::TensorRef<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_);
}
/// Accesses a const planar-complex tensor reference to host data
ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const {
return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
}
/// Accesses a planar-complex tensor reference to device data, optionally offset by elements
TensorRef device_ref(LongIndex ptr_element_offset=0) {
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
}
/// Accesses a const planar-complex tensor reference to device data
ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
// Fix: construct ConstTensorRef (was TensorRef, which cannot be constructed
// from the pointer-to-const returned by the const device_data_ptr_offset()).
return ConstTensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
}
/// Returns a tensor reference to the real part of the tensor in device memory
cutlass::TensorRef<Element, Layout> device_ref_real() {
return cutlass::TensorRef<Element, Layout>(device_data(), layout_);
}
/// Returns a tensor reference to the imaginary part of the tensor in device memory
cutlass::TensorRef<Element, Layout> device_ref_imag() {
return cutlass::TensorRef<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_);
}
/// Accesses a planar-complex tensor view of host data (reference + extent)
TensorView host_view(LongIndex ptr_element_offset=0) {
return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
}
/// Accesses a const planar-complex tensor view of host data
ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
}
/// Returns a tensor view of the real part of the tensor in host memory
cutlass::TensorView<Element, Layout> host_view_real() {
return cutlass::TensorView<Element, Layout>(host_data(), layout_, extent_);
}
/// Returns a tensor view of the imaginary part of the tensor in host memory
cutlass::TensorView<Element, Layout> host_view_imag() {
return cutlass::TensorView<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_, extent_);
}
/// Accesses a planar-complex tensor view of device data
TensorView device_view(LongIndex ptr_element_offset=0) {
return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
}
/// Accesses a const planar-complex tensor view of device data
ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
}
/// Returns a tensor view of the real part of the tensor in device memory
cutlass::TensorView<Element, Layout> device_view_real() {
return cutlass::TensorView<Element, Layout>(device_data(), layout_, extent_);
}
/// Returns a tensor view of the imaginary part of the tensor in device memory
cutlass::TensorView<Element, Layout> device_view_imag() {
return cutlass::TensorView<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_, extent_);
}
/// Returns true if device memory is allocated
bool device_backed() const {
return device_.get() != nullptr;
}
/// Returns a copy of the layout object
Layout layout() const {
return layout_;
}
/// Returns the layout object's stride vector
Stride stride() const {
return layout_.stride();
}
/// Returns the layout object's stride in a given physical dimension
Index stride(int dim) const {
return layout_.stride().at(dim);
}
/// Computes the linear element offset of a logical coordinate from the tensor's origin
LongIndex offset(TensorCoord const& coord) const {
return layout_(coord);
}
/// Returns a planar-complex reference to the element at the logical Coord in host memory
Reference at(TensorCoord const& coord) {
return host_data(offset(coord));
}
/// Returns a const planar-complex reference to the element at the logical Coord in host memory
ConstReference at(TensorCoord const& coord) const {
return host_data(offset(coord));
}
/// Returns the extent of the logical tensor
TensorCoord extent() const {
return extent_;
}
/// Returns a mutable reference to the extent; mutating it does NOT reallocate storage
TensorCoord & extent() {
return extent_;
}
/// Copies data from device to host. No-op if not device-backed.
void sync_host() {
if (device_backed()) {
// imaginary_stride() * 2 covers both the real and imaginary planes
device_memory::copy_to_host(
host_data(), device_data(), imaginary_stride() * 2);
}
}
/// Copies data from host to device. No-op if not device-backed.
void sync_device() {
if (device_backed()) {
// imaginary_stride() * 2 covers both the real and imaginary planes
device_memory::copy_to_device(
device_data(), host_data(), imaginary_stride() * 2);
}
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -0,0 +1,306 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for complex-valued GEMM in device code.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/complex.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_ref_planar_complex.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm.h"
namespace cutlass {
namespace reference {
namespace device {
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace kernel {
////////////////////////////////////////////////////////////////////////////////////////////////////
static int const kGemmPlanarComplexBlockSize = 4;
/// Reference GEMM kernel for planar-complex operands.
///
/// Each thread computes a kMblock x kNblock sub-tile of
///   D = alpha * op(A) * op(B) + beta * C
/// where op(X) optionally conjugates the operand. C is only read when beta is
/// nonzero. The batch index of problem_size is ignored.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
  typename InnerProductOp = multiply_add<complex<ComputeType>>
>
__global__ void GemmPlanarComplex(
  gemm::GemmCoord problem_size,
  complex<ScalarType> alpha,
  TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
  ComplexTransform transform_a,
  TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
  ComplexTransform transform_b,
  complex<ScalarType> beta,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
  complex<ComputeType> initial_accum) {

  int const kMblock = kGemmPlanarComplexBlockSize;
  int const kNblock = kGemmPlanarComplexBlockSize;

  using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
  using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
  using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  complex<ComputeType> accum[kMblock][kNblock];

  // Upper-left corner of this thread's sub-tile
  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;

  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kNblock; j++) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kMblock; i++) {
      accum[i][j] = initial_accum;
    }
  }

  // Accumulate over the K dimension
  CUTLASS_PRAGMA_NO_UNROLL
  for (int k_block = 0; k_block < K; ++k_block) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {
        int row = row_block + i;
        int col = col_block + j;
        if (row < M && col < N) {
          ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
          ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
          complex<ComputeType> a = complex<ComputeType>{
            ComputeType(a_ik.real()),
            ComputeType(a_ik.imag())
          };
          complex<ComputeType> b = complex<ComputeType>{
            ComputeType(b_kj.real()),
            ComputeType(b_kj.imag())
          };
          if (transform_a == ComplexTransform::kConjugate) {
            a = conj(a);
          }
          if (transform_b == ComplexTransform::kConjugate) {
            b = conj(b);
          }
          accum[i][j] = inner_product_op(a, b, accum[i][j]);
        }
      }
    }
  }

  // Epilogue: scale, add source (if beta != 0), convert, and store
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kNblock; j++) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kMblock; i++) {
      int row = row_block + i;
      int col = col_block + j;
      MatrixCoord coord = MatrixCoord(row, col);
      if (row < M && col < N) {
        complex<ScalarType> acc{
          ScalarType(accum[i][j].real()),
          ScalarType(accum[i][j].imag())
        };
        // Avoid reading C entirely when beta == 0
        ComplexC c_ij = ComplexC();
        if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
          c_ij = tensor_c.at(coord);
        }
        complex<ScalarType> src{
          ScalarType(c_ij.real()),
          ScalarType(c_ij.imag())
        };
        complex<ScalarType> result = alpha * acc + beta * src;
        ComplexC d_ij;
        d_ij.real() = convert_op(result.real());
        d_ij.imag() = convert_op(result.imag());  // fixed stray double semicolon
        tensor_d.at(coord) = d_ij;
      }
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects. Launches the reference kernel on the default CUDA stream.
///
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
/// AccumulatorType(0) as the last function argument can be easier than naming all template
/// arguments explicitly.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ScalarType,
typename ComputeType,
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
typename InnerProductOp = multiply_add<complex<ComputeType>>
>
void GemmPlanarComplex(
gemm::GemmCoord problem_size,
complex<ScalarType> alpha,
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
ComplexTransform transform_a,
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
ComplexTransform transform_b,
complex<ScalarType> beta,
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
complex<ComputeType> initial_accum) {
static_assert(
LayoutA::kRank == 2 &&
LayoutB::kRank == 2 &&
LayoutC::kRank == 2, "Tensors must be of rank 2");
int const kMblock = kernel::kGemmPlanarComplexBlockSize;
int const kNblock = kernel::kGemmPlanarComplexBlockSize;
// Each thread covers a kMblock x kNblock tile; round the grid up to cover M x N
dim3 block(16, 8);
dim3 grid(
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
1);
kernel::GemmPlanarComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ScalarType,
ComputeType,
ConvertOp,
InnerProductOp
><<< grid, block >>>(
problem_size,
alpha,
tensor_a,
transform_a,
tensor_b,
transform_b,
beta,
tensor_c,
tensor_d,
initial_accum
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// This convenience overload assumes the accumulator type is the same type as the scalars
/// and starts accumulation from zero.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ScalarType
>
void GemmPlanarComplex(
gemm::GemmCoord problem_size,
complex<ScalarType> alpha,
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
ComplexTransform transform_a,
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
ComplexTransform transform_b,
complex<ScalarType> beta,
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
// Delegate with a zero-valued initial accumulator
GemmPlanarComplex(
problem_size,
alpha,
tensor_a, transform_a,
tensor_b, transform_b,
beta,
tensor_c,
tensor_d,
complex<ScalarType>());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace reference
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -43,12 +43,12 @@
#endif
// CUDA includes
#include <cublas_v2.h>
#include <curand_kernel.h>
// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/complex.h"
#include "cutlass/tensor_view.h"
#include "cutlass/util/reference/device/tensor_foreach.h"
@ -169,6 +169,95 @@ struct RandomGaussianFunc {
}
};
/// Partial specialization generating complex-valued Gaussian random numbers on device.
/// Real and imaginary parts are drawn independently with the same mean/stddev.
template <typename Real>
struct RandomGaussianFunc<complex<Real>> {
using Element = complex<Real>;
// Use double-precision intermediates for Real types wider than 4 bytes
using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type;
using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type;
/// Parameters structure
struct Params {
//
// Data members
//
uint64_t seed;
FloatType mean;
FloatType stddev;
int int_scale; ///< if non-negative, number of fractional bits preserved (values are quantized)
//
// Methods
//
/// Construction of Gaussian RNG functor.
Params(
uint64_t seed_ = 0,
Real mean_ = 0,
Real stddev_ = 1,
int int_scale_ = -1
):
seed(seed_),
mean(static_cast<FloatType>(mean_)),
stddev(static_cast<FloatType>(stddev_)),
int_scale(int_scale_) {
}
};
//
// Data members
//
/// Parameters object
Params params;
/// RNG state object
curandState_t rng_state;
//
// Methods
//
/// Device-side initialization of RNG; each thread gets an independent sequence
CUTLASS_DEVICE
RandomGaussianFunc(Params const &params): params(params) {
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
curand_init(params.seed, gtid, 0, &rng_state);
}
/// Compute random value and update RNG state
CUTLASS_DEVICE
Element operator()() {
FloatType rnd_r = random_normal_float<FloatType>(&rng_state);
FloatType rnd_i = random_normal_float<FloatType>(&rng_state);
rnd_r = params.mean + params.stddev * rnd_r;
rnd_i = params.mean + params.stddev * rnd_i;
Element result;
if (params.int_scale >= 0) {
// Quantize to int_scale fractional bits (round-trip through an integer)
// to produce exactly representable values, easing exact comparisons.
rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale)));
rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale)));
result = {
Real(rnd_r / FloatType(IntType(1) << params.int_scale)),
Real(rnd_i / FloatType(IntType(1) << params.int_scale))
};
}
else {
result = Element(Real(rnd_r), Real(rnd_i));
}
return result;
}
};
/// Computes a random Gaussian distribution
template <
typename Element, ///< Element type
@ -269,12 +358,12 @@ template <typename Element> ///< Element type
void BlockFillRandomGaussian(
Element *ptr,
size_t capacity,
uint64_t seed, ///< seed for RNG
Element mean = Element(0), ///< Gaussian distribution's mean
Element stddev = Element(1), ///< Gaussian distribution's standard deviation
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
uint64_t seed, ///< seed for RNG
typename RealType<Element>::Type mean, ///< Gaussian distribution's mean
typename RealType<Element>::Type stddev, ///< Gaussian distribution's standard deviation
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
using RandomFunc = detail::RandomGaussianFunc<Element>;
@ -383,6 +472,111 @@ struct RandomUniformFunc {
}
};
/// Partial specialization generating complex-valued uniform random numbers on device.
/// Real and imaginary parts are drawn independently from [min, min + range).
template <typename Real> ///< underlying real-valued element type
struct RandomUniformFunc<complex<Real>> {
using Element = complex<Real>;
// Use double-precision intermediates for Real types wider than 4 bytes
using FloatType = typename std::conditional<
(sizeof(Real) > 4),
double,
float>::type;
using IntType = typename std::conditional<
(sizeof(Real) > 4),
int64_t,
int>::type;
/// Parameters structure
struct Params {
//
// Data members
//
uint64_t seed;
FloatType range;
FloatType min;
int int_scale; ///< if non-negative, number of fractional bits preserved (values are quantized)
/// Default ctor
CUTLASS_HOST_DEVICE
Params() { }
//
// Methods
//
/// Construction of uniform RNG functor.
Params(
uint64_t seed_ = 0,
FloatType max = 1,
FloatType min_ = 0,
int int_scale_ = -1
):
seed(seed_),
range(static_cast<FloatType>(max - min_)),
min(static_cast<FloatType>(min_)),
int_scale(int_scale_) {
}
};
//
// Data members
//
/// Parameters object
Params params;
/// RNG state object
curandState_t rng_state;
//
// Methods
//
/// Device-side initialization of RNG; each thread gets an independent sequence
CUTLASS_DEVICE
RandomUniformFunc(Params const &params): params(params) {
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
curand_init(params.seed, gtid, 0, &rng_state);
}
/// Compute random value and update RNG state
CUTLASS_DEVICE
Element operator()() {
FloatType rnd_r = random_uniform_float<FloatType>(&rng_state);
FloatType rnd_i = random_uniform_float<FloatType>(&rng_state);
rnd_r = params.min + params.range * rnd_r;
rnd_i = params.min + params.range * rnd_i;
// Random values are cast to integer after scaling by a power of two to facilitate error
// testing
Element result;
if (params.int_scale >= 0) {
rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale)));
rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale)));
result = {
Real(rnd_r / FloatType(IntType(1) << params.int_scale)),
Real(rnd_i / FloatType(IntType(1) << params.int_scale))
};
}
else {
result = Element(Real(rnd_r), Real(rnd_i));
}
return result;
}
};
/// Fills a tensor with random values drawn from a uniform distribution
template <
typename Element, ///< Element type
@ -489,8 +683,8 @@ void BlockFillRandomUniform(
Element *ptr,
size_t capacity,
uint64_t seed, ///< seed for RNG
Element max = Element(1), ///< upper bound of distribution
Element min = Element(0), ///< lower bound for distribution
typename RealType<Element>::Type max, ///< upper bound of distribution
typename RealType<Element>::Type min, ///< lower bound for distribution
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
@ -976,13 +1170,15 @@ void BlockFillRandom(
uint64_t seed,
Distribution dist) {
using Real = typename RealType<Element>::Type;
if (dist.kind == Distribution::Gaussian) {
BlockFillRandomGaussian<Element>(
ptr,
capacity,
seed,
static_cast<Element>(dist.gaussian.mean),
static_cast<Element>(dist.gaussian.stddev),
static_cast<Real>(dist.gaussian.mean),
static_cast<Real>(dist.gaussian.stddev),
dist.int_scale);
}
else if (dist.kind == Distribution::Uniform) {
@ -990,8 +1186,8 @@ void BlockFillRandom(
ptr,
capacity,
seed,
static_cast<Element>(dist.uniform.max),
static_cast<Element>(dist.uniform.min),
static_cast<Real>(dist.uniform.max),
static_cast<Real>(dist.uniform.min),
dist.int_scale);
}
}

View File

@ -72,6 +72,7 @@ void GemmComplex(
ComplexTransform transform_b,
ScalarType beta,
TensorRef<ElementC, LayoutC> tensor_c,
TensorRef<ElementC, LayoutC> tensor_d,
ComputeType initial_accum) {
static_assert(
@ -138,7 +139,7 @@ void GemmComplex(
if (row < M && col < N) {
tensor_c.at(coord) = convert_op(
tensor_d.at(coord) = convert_op(
alpha * ScalarType(accum[i][j]) +
beta * ScalarType(tensor_c.at(coord)));
}
@ -171,9 +172,10 @@ void GemmComplex(
TensorRef<ElementB, LayoutB> tensor_b,
ComplexTransform transform_b,
ScalarType beta,
TensorRef<ElementC, LayoutC> tensor_c) {
TensorRef<ElementC, LayoutC> tensor_c,
TensorRef<ElementC, LayoutC> tensor_d) {
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0));
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,223 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for complex-valued GEMM in host-side code.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_types.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_ref_planar_complex.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm.h"
namespace cutlass {
namespace reference {
namespace host {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects. Host-side reference implementation of D = alpha * op(A) * op(B) + beta * C
/// for planar-complex operands, where op(X) optionally conjugates the operand.
///
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
/// AccumulatorType(0) as the last function argument can be easier than naming all template
/// arguments explicitly.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
  typename InnerProductOp = multiply_add<complex<ComputeType>>
>
void GemmPlanarComplex(
  gemm::GemmCoord problem_size,
  complex<ScalarType> alpha,
  TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
  ComplexTransform transform_a,
  TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
  ComplexTransform transform_b,
  complex<ScalarType> beta,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
  complex<ComputeType> initial_accum) {

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
  using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
  using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();

  // Blocking necessary to speedup reference implementation
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  for (int row_block = 0; row_block < M; row_block += Mblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      complex<ComputeType> accum[Mblock][Nblock];

      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          accum[i][j] = initial_accum;
        }
      }

      // Accumulate over the K dimension for this tile
      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;
            if (row < M && col < N) {
              ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
              ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
              complex<ComputeType> a = complex<ComputeType>{
                ComputeType(a_ik.real()),
                ComputeType(a_ik.imag())
              };
              complex<ComputeType> b = complex<ComputeType>{
                ComputeType(b_kj.real()),
                ComputeType(b_kj.imag())
              };
              if (transform_a == ComplexTransform::kConjugate) {
                a = conj(a);
              }
              if (transform_b == ComplexTransform::kConjugate) {
                b = conj(b);
              }
              accum[i][j] = inner_product_op(a, b, accum[i][j]);
            }
          }
        }
      }

      // Epilogue: scale, add source, convert, and store
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          int row = row_block + i;
          int col = col_block + j;
          MatrixCoord coord = MatrixCoord(row, col);
          if (row < M && col < N) {
            complex<ScalarType> acc{
              ScalarType(accum[i][j].real()),
              ScalarType(accum[i][j].imag())
            };
            // Fix: only read C when beta is nonzero, matching the device reference;
            // previously C was read unconditionally, which could propagate
            // uninitialized data or NaNs even when beta == 0. Also renamed the
            // source operand (was confusingly reused as 'd_ij').
            ComplexC c_ij = ComplexC();
            if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
              c_ij = tensor_c.at(coord);
            }
            complex<ScalarType> src{
              ScalarType(c_ij.real()),
              ScalarType(c_ij.imag())
            };
            complex<ScalarType> result = alpha * acc + beta * src;
            ComplexC d_ij;
            d_ij.real() = convert_op(result.real());
            d_ij.imag() = convert_op(result.imag());  // fixed stray double semicolon
            tensor_d.at(coord) = d_ij;
          }
        }
      }
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// This convenience overload assumes the accumulator type is the same type as the scalars
/// and starts accumulation from zero.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ScalarType
>
void GemmPlanarComplex(
gemm::GemmCoord problem_size,
complex<ScalarType> alpha,
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
ComplexTransform transform_a,
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
ComplexTransform transform_b,
complex<ScalarType> beta,
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
// Delegate with a zero-valued initial accumulator
GemmPlanarComplex(
problem_size,
alpha,
tensor_a, transform_a,
tensor_b, transform_b,
beta,
tensor_c,
tensor_d,
complex<ScalarType>());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace host
} // namespace reference
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -33,6 +33,9 @@
// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"
#include "cutlass/tensor_view_planar_complex.h"
#include "cutlass/util/distribution.h"
//#include "cutlass/util/type_traits.h"
#include "tensor_foreach.h"
@ -112,6 +115,46 @@ bool TensorEquals(
return bool(func);
}
/// Returns true if two planar-complex tensor views are element-wise equal.
/// Extents must match, and both the real and imaginary planes must compare equal.
template <
typename Element, ///< Element type
typename Layout> ///< Layout function
bool TensorEquals(
TensorViewPlanarComplex<Element, Layout> const &lhs,
TensorViewPlanarComplex<Element, Layout> const &rhs) {
// Views of different extent can never be equal
if (lhs.extent() != rhs.extent()) {
return false;
}
// Compare the real-valued planes first
detail::TensorEqualsFunc<Element, Layout> compare_real(
{lhs.data(), lhs.layout(), lhs.extent()},
{rhs.data(), rhs.layout(), rhs.extent()}
);
TensorForEach(
lhs.extent(),
compare_real
);
if (!bool(compare_real)) {
return false;
}
// Then compare the imaginary planes, offset by imaginary_stride()
detail::TensorEqualsFunc<Element, Layout> compare_imag(
{lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
{rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}
);
TensorForEach(
lhs.extent(),
compare_imag
);
return bool(compare_imag);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -137,6 +180,17 @@ bool TensorNotEquals(
return !bool(func);
}
/// Returns true if two planar-complex tensor views are NOT equal
template <
typename Element, ///< Element type
typename Layout> ///< Layout function
bool TensorNotEquals(
TensorViewPlanarComplex<Element, Layout> const &lhs,
TensorViewPlanarComplex<Element, Layout> const &rhs) {
return !TensorEquals(lhs, rhs);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -38,6 +38,8 @@
#include "cutlass/complex.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_view.h"
#include "cutlass/tensor_view_planar_complex.h"
#include "cutlass/util/distribution.h"
#include "tensor_foreach.h"
@ -101,6 +103,18 @@ void TensorFill(
);
}
/// Fills a planar-complex tensor with a single complex value.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
void TensorFill(
  TensorViewPlanarComplex<Element, Layout> dst,                    ///< destination tensor
  cutlass::complex<Element> val = cutlass::complex<Element>(0)) {  ///< value to uniformly fill it with

  // Each plane is filled independently: the real component goes into the
  // real plane and the imaginary component into the imaginary plane.
  TensorFill(dst.view_real(), val.real());
  TensorFill(dst.view_imag(), val.imag());
}
///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -268,6 +282,23 @@ void TensorFillRandomGaussian(
);
}
/// Fills a planar-complex tensor with random values drawn from a Gaussian distribution.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
void TensorFillRandomGaussian(
  TensorViewPlanarComplex<Element, Layout> dst,    ///< destination tensor
  uint64_t seed,                                   ///< seed for RNG
  double mean = 0,                                 ///< Gaussian distribution's mean
  double stddev = 1,                               ///< Gaussian distribution's standard deviation
  int bits = -1) {                                 ///< If non-negative, specifies number of fractional bits that
                                                   ///  are not truncated to zero. Permits reducing precision of
                                                   ///  data.

  // The imaginary plane is seeded with the bitwise complement of 'seed' so
  // the two planes receive distinct random sequences.
  TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits);
  TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a tensor with random values with a Gaussian distribution.
@ -461,6 +492,23 @@ void TensorFillRandomUniform(
);
}
/// Fills a planar-complex tensor with random values drawn from a uniform distribution.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
void TensorFillRandomUniform(
  TensorViewPlanarComplex<Element, Layout> dst,    ///< destination tensor
  uint64_t seed,                                   ///< seed for RNG
  double max = 1,                                  ///< upper bound of distribution
  double min = 0,                                  ///< lower bound for distribution
  int bits = -1) {                                 ///< If non-negative, specifies number of fractional bits that
                                                   ///  are not truncated to zero. Permits reducing precision of
                                                   ///  data.

  // The imaginary plane is seeded with the bitwise complement of 'seed' so
  // the two planes receive distinct random sequences.
  TensorFillRandomUniform(dst.view_real(), seed, max, min, bits);
  TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a tensor with random values with a uniform random distribution.
@ -774,6 +822,27 @@ void BlockFillSequential(
}
}
/// Fills a block of data with sequential elements
template <
typename Element
>
void BlockFillSequentialModN(
Element *ptr,
int64_t capacity,
int64_t mod,
int64_t v = int64_t(1),
int64_t s = int64_t(0)) {
int i = 0;
while (i < capacity) {
cutlass::ReferenceFactory<Element, (cutlass::sizeof_bits<Element>::value <
8)>::get(ptr, i) = Element(s);
s = int64_t(s + v) % mod;
++i;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -26,6 +26,8 @@
#include "cutlass/core_io.h"
#include "cutlass/tensor_view.h"
#include "cutlass/tensor_view_planar_complex.h"
#include "cutlass/complex.h"
namespace cutlass {
@ -87,13 +89,13 @@ inline std::ostream & TensorView_WriteRank(
coord[rank] = idx;
if (rank + 2 == Layout::kRank) {
// Write least significant ranks as a matrix with rows delimited by ";\n"
out << (idx ? ";\n" : "");
// Write least significant ranks as a matrix with rows delimited by "\n"
out << (idx ? ",\n" : "");
TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width);
}
else {
// Higher ranks are separated by newlines
out << (idx ? "\n" : "");
out << (idx ? ",\n\n" : "");
TensorView_WriteRank(out, view, coord, rank + 1, width);
}
}
@ -101,6 +103,76 @@ inline std::ostream & TensorView_WriteRank(
return out;
}
/// Helper to write the least significant rank of a planar-complex TensorView
/// as a single comma-separated row of complex values.
template <
  typename Element,
  typename Layout
>
inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank(
  std::ostream& out,
  TensorViewPlanarComplex<Element, Layout> const& view,
  Coord<Layout::kRank> const &start_coord,  ///< coordinate fixed for all higher ranks
  int rank,                                 ///< rank being iterated (the last rank)
  std::streamsize width) {                  ///< field width applied to each element

  for (int idx = 0; idx < view.extent(rank); ++idx) {

    Coord<Layout::kRank> coord(start_coord);
    coord[rank] = idx;

    if (idx) {
      // Reset the width so the separator itself is not padded.
      out.width(0);
      out << ", ";
    }
    if (idx || coord) {
      // Re-apply the caller's field width before writing the next element.
      out.width(width);
    }

    // view.at() assembles the complex value from the real and imaginary planes.
    complex<Element> x = view.at(coord);
    out << x;
  }

  return out;
}
/// Helper to recursively write one rank of a planar-complex TensorView.
/// The least significant rank becomes a row; the next rank becomes a
/// ";\n"-delimited matrix; all higher ranks are separated by newlines.
template <
  typename Element,
  typename Layout
>
inline std::ostream & TensorViewPlanarComplex_WriteRank(
  std::ostream& out,
  TensorViewPlanarComplex<Element, Layout> const& view,
  Coord<Layout::kRank> const &start_coord,  ///< coordinate fixed by enclosing ranks
  int rank,                                 ///< rank currently being written
  std::streamsize width) {                  ///< field width applied to each element

  // If called on the least significant rank, write the result as a row
  if (rank + 1 == Layout::kRank) {
    return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width);
  }

  // Otherwise, write a sequence of rows and newlines
  for (int idx = 0; idx < view.extent(rank); ++idx) {

    Coord<Layout::kRank> coord(start_coord);
    coord[rank] = idx;

    if (rank + 2 == Layout::kRank) {

      // Write least significant ranks as a matrix with rows delimited by ";\n"
      out << (idx ? ";\n" : "");
      TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width);
    }
    else {

      // Higher ranks are separated by newlines
      out << (idx ? "\n" : "");
      TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width);
    }
  }

  return out;
}
} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -143,4 +215,42 @@ inline std::ostream& operator<<(
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Writes a human-readable rendering of a planar-complex TensorView to an ostream.
template <
  typename Element,
  typename Layout
>
inline std::ostream& TensorViewWrite(
  std::ostream& out,
  TensorViewPlanarComplex<Element, Layout> const& view) {

  // Output layout convention:
  //  - the least significant rank appears as rows separated by ";\n"
  //  - every higher rank is delimited with newlines
  // The net effect is a whitespace-delimited series of 2D matrices of
  // complex values. The stream's current field width is forwarded so each
  // element is padded consistently.
  return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord<Layout::kRank>(), 0, out.width());
}
/// Stream-insertion operator for planar-complex TensorViews.
template <
  typename Element,
  typename Layout
>
inline std::ostream& operator<<(
  std::ostream& out,
  TensorViewPlanarComplex<Element, Layout> const& view) {

  // Delegates to TensorViewWrite(), which renders the view as a series of
  // 2D matrices: the least significant rank as ";\n"-delimited rows, with
  // all greater ranks separated by newlines.
  return TensorViewWrite(out, view);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass