CUTLASS 2.1 (#83)
CUTLASS 2.1 contributes: - BLAS-style host-side API added to CUTLASS Library - Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores - Minor enhancements and bug fixes
This commit is contained in:
@ -52,13 +52,15 @@ install(
|
||||
#
|
||||
|
||||
cutlass_add_library(
|
||||
cutlass_lib
|
||||
SHARED
|
||||
src/library.cu
|
||||
cutlass_library_objs
|
||||
OBJECT
|
||||
src/handle.cu
|
||||
src/manifest.cpp
|
||||
src/operation_table.cu
|
||||
src/singleton.cu
|
||||
src/util.cu
|
||||
|
||||
)
|
||||
add_library(nvidia::cutlass::library ALIAS cutlass_lib)
|
||||
set_target_properties(cutlass_lib PROPERTIES EXPORT_NAME library)
|
||||
|
||||
file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/scripts/*.py)
|
||||
|
||||
@ -66,16 +68,19 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU
|
||||
# auto-instantiation of CUTLASS kernels
|
||||
#
|
||||
|
||||
# set cutlass generator compiler version to filter kernels in the generator not supported by a specific toolkit.
|
||||
set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
|
||||
execute_process(
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/scripts/generator.py
|
||||
--operations all
|
||||
--operations "${CUTLASS_LIBRARY_OPERATIONS}"
|
||||
--build-dir ${PROJECT_BINARY_DIR}
|
||||
--curr-build-dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--generator-target library
|
||||
--architectures "${CUTLASS_NVCC_ARCHS_ENABLED}"
|
||||
--kernels "${CUTLASS_LIBRARY_KERNELS}"
|
||||
--cuda-version "${CMAKE_CUDA_COMPILER_VERSION}"
|
||||
--cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}"
|
||||
RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT
|
||||
OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log
|
||||
@ -95,35 +100,70 @@ else()
|
||||
endif()
|
||||
|
||||
target_include_directories(
|
||||
cutlass_lib
|
||||
cutlass_library_objs
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
)
|
||||
|
||||
set_target_properties(
|
||||
cutlass_lib
|
||||
PROPERTIES
|
||||
OUTPUT_NAME cutlass
|
||||
WINDOWS_EXPORT_ALL_SYMBOLS 1
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
cutlass_lib
|
||||
cutlass_library_objs
|
||||
PUBLIC
|
||||
cutlass_library_includes
|
||||
)
|
||||
|
||||
function(cutlass_add_cutlass_library)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs NAME TYPE EXPORT_NAME)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_add_library(
|
||||
${__NAME}
|
||||
${__TYPE}
|
||||
EXPORT_NAME ${__EXPORT_NAME}
|
||||
$<TARGET_OBJECTS:cutlass_library_objs>
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
${__NAME}
|
||||
PUBLIC
|
||||
cutlass_library_includes
|
||||
)
|
||||
|
||||
set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX ${CUTLASS_LIBRARY_DEBUG_POSTFIX})
|
||||
|
||||
set(OUTPUT_NAME cutlass)
|
||||
|
||||
if (WIN32 AND ${__TYPE} STREQUAL "STATIC")
|
||||
set(OUTPUT_NAME "${OUTPUT_NAME}.static")
|
||||
endif()
|
||||
|
||||
set_target_properties(
|
||||
${__NAME}
|
||||
PROPERTIES
|
||||
OUTPUT_NAME ${OUTPUT_NAME}
|
||||
WINDOWS_EXPORT_ALL_SYMBOLS 1
|
||||
)
|
||||
|
||||
endfunction()
|
||||
|
||||
cutlass_add_cutlass_library(NAME cutlass_lib TYPE SHARED EXPORT_NAME library)
|
||||
cutlass_add_cutlass_library(NAME cutlass_library_static TYPE STATIC EXPORT_NAME library_static)
|
||||
|
||||
install(
|
||||
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include/
|
||||
DESTINATION
|
||||
${CMAKE_INSTALL_INCLUDEDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS cutlass_lib cutlass_library_includes
|
||||
TARGETS
|
||||
cutlass_lib
|
||||
cutlass_library_static
|
||||
cutlass_library_includes
|
||||
EXPORT NvidiaCutlass
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
)
|
||||
|
||||
284
tools/library/include/cutlass/library/handle.h
Normal file
284
tools/library/include/cutlass/library/handle.h
Normal file
@ -0,0 +1,284 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief BLAS-like handle used to launch operations on the CUDA device.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Handle object
|
||||
class Handle {
|
||||
private:
|
||||
|
||||
/// Host workspace
|
||||
static int const kHostWorkspaceSize = (4 << 10);
|
||||
|
||||
/// CUDA device properties
|
||||
cudaDeviceProp device_;
|
||||
|
||||
/// CUDA stream
|
||||
cudaStream_t stream_;
|
||||
|
||||
/// Device workspace
|
||||
void *workspace_;
|
||||
|
||||
/// Size of device workspace in bytes
|
||||
size_t workspace_size_;
|
||||
|
||||
/// Indicates whether scalars are host or device pointers
|
||||
ScalarPointerMode scalar_pointer_mode_;
|
||||
|
||||
/// Pointer to the most recently executed operation
|
||||
Operation const *last_operation_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20));
|
||||
|
||||
/// Destructor
|
||||
~Handle();
|
||||
|
||||
/// Move constructor
|
||||
Handle(Handle && handle);
|
||||
|
||||
/// Move assignment operator
|
||||
Handle &operator=(Handle && handle);
|
||||
|
||||
//
|
||||
// Persistent state accessors
|
||||
//
|
||||
|
||||
/// Returns compute capability of the selected device
|
||||
int compute_capability() const;
|
||||
|
||||
/// Sets the current CUDA stream
|
||||
void set_stream(cudaStream_t stream);
|
||||
|
||||
/// Gets the current CUDA stream
|
||||
cudaStream_t get_stream() const;
|
||||
|
||||
/// Gets the device workspace size
|
||||
size_t get_workspace_size() const;
|
||||
|
||||
/// Gets a pointer to the device workspace allocation in Global Memory
|
||||
void *get_workspace() const;
|
||||
|
||||
/// Sets the size of device workspace, invalidating calls to get_device_workspace()
|
||||
void set_workspace_size(size_t bytes);
|
||||
|
||||
/// Gets the scalar pointer mode
|
||||
ScalarPointerMode get_scalar_pointer_mode() const;
|
||||
|
||||
/// Sets the scalar pointer mode
|
||||
void set_scalar_pointer_mode(ScalarPointerMode mode);
|
||||
|
||||
/// Gets the most recently executed operation
|
||||
Operation const *get_last_operation() const;
|
||||
|
||||
//
|
||||
// Computations
|
||||
//
|
||||
|
||||
/// Executes a GEMM computation: D <= alpha * A*B + beta * C
|
||||
Status gemm(
|
||||
|
||||
int M, /// GEMM M dimension
|
||||
int N, /// GEMM N dimension
|
||||
int K, /// GEMM K dimension
|
||||
|
||||
NumericTypeID element_compute, /// Data type of internal accumulation
|
||||
|
||||
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
||||
|
||||
void const *alpha, /// Pointer to alpha scalar
|
||||
|
||||
NumericTypeID element_A, /// Data type of A matrix elements
|
||||
LayoutTypeID layout_A, /// Layout of A matrix
|
||||
ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices
|
||||
|
||||
void const * ptr_A, /// Pointer to A matrix in Global Memory
|
||||
int lda, /// Leading dimension of A matrix
|
||||
|
||||
NumericTypeID element_B, /// Data type of B matrix elements
|
||||
LayoutTypeID layout_B, /// Layout of B matrix
|
||||
ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices
|
||||
|
||||
void const * ptr_B, /// Pointer to B matrix in Global Memory
|
||||
int ldb, /// Leading dimension of B matrix
|
||||
|
||||
void const * beta, /// Pointer to beta scalar
|
||||
|
||||
NumericTypeID element_C, /// Data type of C and D matrices
|
||||
|
||||
void const * ptr_C, /// Pointer to C matrix
|
||||
int ldc, /// Leading dimension of C matrix
|
||||
|
||||
void * ptr_D, /// Pointer to D matrix
|
||||
int ldd /// Leading dimension of D matrix
|
||||
);
|
||||
|
||||
/// Planar complex GEMM
|
||||
///
|
||||
/// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel.
|
||||
///
|
||||
Status gemm_planar_complex(
|
||||
|
||||
int M, /// GEMM M dimension
|
||||
int N, /// GEMM N dimension
|
||||
int K, /// GEMM K dimension
|
||||
|
||||
NumericTypeID element_compute, /// Data type of internal accumulation
|
||||
|
||||
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
||||
|
||||
void const *alpha, /// Pointer to alpha scalar
|
||||
|
||||
NumericTypeID element_A, /// Data type of A matrix elements
|
||||
LayoutTypeID layout_A, /// Layout of A matrix
|
||||
ComplexTransform transform_A, /// Complex transformation applied to A matrix
|
||||
|
||||
void const * ptr_A_real, /// Pointer to real part of A matrix
|
||||
void const * ptr_A_imag, /// Pointer to imaginary part of A matrix
|
||||
int lda_real, /// Leading dimension of real part of A matrix
|
||||
int lda_imag, /// Leading dimension of imaginary part of A matrix
|
||||
|
||||
NumericTypeID element_B, /// Data type of B matrix elements
|
||||
LayoutTypeID layout_B, /// Layout of B matrix
|
||||
ComplexTransform transform_B, /// Complex transformation applied to B matrix
|
||||
|
||||
void const * ptr_B_real, /// Pointer to real part of B matrix
|
||||
void const * ptr_B_imag, /// Pointer to imaginary part of B matrix
|
||||
int ldb_real, /// Leading dimension of real part of B matrix
|
||||
int ldb_imag, /// Leading dimension of imaginary part of B matrix
|
||||
|
||||
void const * beta, /// Pointer to beta scalar
|
||||
|
||||
NumericTypeID element_C, /// Data type of C and D matrix
|
||||
|
||||
void const * ptr_C_real, /// Pointer to real part of C matrix
|
||||
void const * ptr_C_imag, /// Pointer to imaginary part of C matrix
|
||||
int ldc_real, /// Leading dimension of real part of C matrix
|
||||
int ldc_imag, /// Leading dimension of imaginary part of C matrix
|
||||
|
||||
void * ptr_D_real, /// Pointer to real part of D matrix
|
||||
void * ptr_D_imag, /// Pointer to imaginary part of D matrix
|
||||
int ldd_real, /// Leading dimension of real part of D matrix
|
||||
int ldd_imag, /// Leading dimension of imaginary part of D matrix
|
||||
|
||||
int batch_count = 1, /// Number of batched GEMMs to execute
|
||||
|
||||
int64_t batch_stride_A_real = 0,
|
||||
int64_t batch_stride_A_imag = 0,
|
||||
|
||||
int64_t batch_stride_B_real = 0,
|
||||
int64_t batch_stride_B_imag = 0,
|
||||
|
||||
int64_t batch_stride_C_real = 0,
|
||||
int64_t batch_stride_C_imag = 0,
|
||||
|
||||
int64_t batch_stride_D_real = 0,
|
||||
int64_t batch_stride_D_imag = 0
|
||||
);
|
||||
|
||||
/// Planar complex GEMM loading pointers from arrays in global memory
|
||||
Status gemm_planar_complex_array(
|
||||
|
||||
int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid)
|
||||
int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid)
|
||||
int expected_K, /// Expected GEMM K dimension
|
||||
int batch_count, /// Number of independent GEMM computations to execute
|
||||
|
||||
int const *M, /// Array containing the GEMM M dimension for each batch index
|
||||
int const *N, /// Array containing the GEMM N dimension for each batch index
|
||||
int const *K, /// Array containing the GEMM K dimension for each batch index
|
||||
|
||||
NumericTypeID element_compute, /// Data type of internal accumulation
|
||||
|
||||
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
||||
|
||||
void const *alpha, /// Pointer to alpha scalar
|
||||
|
||||
NumericTypeID element_A, /// Data type of A matrix elements
|
||||
LayoutTypeID layout_A, /// Layout of A matrix
|
||||
ComplexTransform transform_A, /// Complex transformation applied to A matrix
|
||||
|
||||
void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices
|
||||
void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices
|
||||
|
||||
int lda_real, /// Leading dimension of real part of A matrix
|
||||
int lda_imag, /// Leading dimension of imaginary part of A matrix
|
||||
|
||||
NumericTypeID element_B, /// Data type of B matrix elements
|
||||
LayoutTypeID layout_B, /// Layout of B matrix
|
||||
ComplexTransform transform_B, /// Complex transformation applied to B matrix
|
||||
|
||||
void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices
|
||||
void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices
|
||||
|
||||
int ldb_real, /// Leading dimension of real part of B matrix
|
||||
int ldb_imag, /// Leading dimension of imaginary part of B matrix
|
||||
|
||||
void const * beta, /// Pointer to beta scalar
|
||||
|
||||
NumericTypeID element_C, /// Data type of C and D matrix
|
||||
|
||||
void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices
|
||||
void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices
|
||||
|
||||
int ldc_real, /// Leading dimension of real part of C matrix
|
||||
int ldc_imag, /// Leading dimension of imaginary part of C matrix
|
||||
|
||||
void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices
|
||||
void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices
|
||||
|
||||
int ldd_real, /// Leading dimension of real part of D matrix
|
||||
int ldd_imag /// Leading dimension of imaginary part of D matrix
|
||||
);
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Unique pointer storing the handle
|
||||
using HandlePtr = std::unique_ptr<Handle>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -68,6 +68,10 @@ enum class LayoutTypeID {
|
||||
kRowMajorInterleavedK4,
|
||||
kColumnMajorInterleavedK16,
|
||||
kRowMajorInterleavedK16,
|
||||
kColumnMajorInterleavedK32,
|
||||
kRowMajorInterleavedK32,
|
||||
kColumnMajorInterleavedK64,
|
||||
kRowMajorInterleavedK64,
|
||||
kTensorNCHW,
|
||||
kTensorNHWC,
|
||||
kInvalid
|
||||
@ -110,9 +114,21 @@ enum class NumericTypeID {
|
||||
/// Enumeraed type describing a transformation on a complex value.
|
||||
enum class ComplexTransform {
|
||||
kNone,
|
||||
kConjugate
|
||||
kConjugate,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
/// Providers
|
||||
enum class Provider {
|
||||
kCUTLASS,
|
||||
kReferenceHost,
|
||||
kReferenceDevice,
|
||||
kCUBLAS,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Enumeration indicating the kind of operation
|
||||
enum class OperationKind {
|
||||
kGemm,
|
||||
@ -143,6 +159,14 @@ enum class OpcodeClassID {
|
||||
kInvalid
|
||||
};
|
||||
|
||||
enum class MathOperationID {
|
||||
kMultiplyAdd,
|
||||
kMultiplyAddSaturate,
|
||||
kMultiplyAddComplex,
|
||||
kXorPopc,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Enumeration indicating what kind of GEMM operation to perform
|
||||
@ -150,88 +174,20 @@ enum class GemmKind {
|
||||
kGemm,
|
||||
kBatched,
|
||||
kArray,
|
||||
kUniversal,
|
||||
kPlanarComplex,
|
||||
kPlanarComplexBatched,
|
||||
kPlanarComplexArray,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Lexical cast from string
|
||||
template <typename T> T from_string(std::string const &);
|
||||
|
||||
/// Converts a NumericType enumerant to a string
|
||||
char const *to_string(OperationKind type, bool pretty = false);
|
||||
|
||||
/// Parses a NumericType enumerant from a string
|
||||
template <> OperationKind from_string<OperationKind>(std::string const &str);
|
||||
|
||||
/// Converts a NumericType enumerant to a string
|
||||
char const *to_string(NumericTypeID type, bool pretty = false);
|
||||
|
||||
/// Parses a NumericType enumerant from a string
|
||||
template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);
|
||||
|
||||
/// Returns the size of a data type in bits
|
||||
int sizeof_bits(NumericTypeID type);
|
||||
|
||||
/// Returns true if the numeric type is a complex data type or false if real-valued.
|
||||
bool is_complex_type(NumericTypeID type);
|
||||
|
||||
/// Returns the real-valued type underlying a type (only different from 'type' if complex)
|
||||
NumericTypeID get_real_type(NumericTypeID type);
|
||||
|
||||
/// Returns true if numeric type is integer
|
||||
bool is_integer_type(NumericTypeID type);
|
||||
|
||||
/// Returns true if numeric type is signed
|
||||
bool is_signed_type(NumericTypeID type);
|
||||
|
||||
/// Returns true if numeric type is a signed integer
|
||||
bool is_signed_integer(NumericTypeID type);
|
||||
|
||||
/// returns true if numeric type is an unsigned integer
|
||||
bool is_unsigned_integer(NumericTypeID type);
|
||||
|
||||
/// Returns true if numeric type is floating-point type
|
||||
bool is_float_type(NumericTypeID type);
|
||||
|
||||
/// To string method for cutlass::Status
|
||||
char const *to_string(Status status, bool pretty = false);
|
||||
|
||||
/// Converts a LayoutTypeID enumerant to a string
|
||||
char const *to_string(LayoutTypeID layout, bool pretty = false);
|
||||
|
||||
/// Parses a LayoutType enumerant from a string
|
||||
template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);
|
||||
|
||||
/// Returns the rank of a layout's stride base on the LayoutTypeID
|
||||
int get_layout_stride_rank(LayoutTypeID layout_id);
|
||||
|
||||
/// Converts a OpcodeClassID enumerant to a string
|
||||
char const *to_string(OpcodeClassID type, bool pretty = false);
|
||||
|
||||
/// Converts a OpcodeClassID enumerant from a string
|
||||
template <>
|
||||
OpcodeClassID from_string<OpcodeClassID>(std::string const &str);
|
||||
|
||||
/// Lexical cast from int64_t to string
|
||||
std::string lexical_cast(int64_t int_value);
|
||||
|
||||
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
|
||||
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);
|
||||
|
||||
/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid.
|
||||
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);
|
||||
|
||||
/// Casts from a signed int64 to the destination type. Returns true if successful.
|
||||
bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);
|
||||
|
||||
/// Casts from an unsigned int64 to the destination type. Returns true if successful.
|
||||
bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);
|
||||
|
||||
/// Casts from a real value represented as a double to the destination type. Returns true if successful.
|
||||
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
|
||||
/// Mode of GEMM
|
||||
enum class GemmUniversalMode {
|
||||
kGemm,
|
||||
kGemmSplitKParallel,
|
||||
kBatched,
|
||||
kArray,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -246,6 +202,9 @@ struct MathInstructionDescription {
|
||||
/// Classification of math instruction
|
||||
OpcodeClassID opcode_class;
|
||||
|
||||
/// Type of math operation performed
|
||||
MathOperationID math_operation;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -253,9 +212,13 @@ struct MathInstructionDescription {
|
||||
MathInstructionDescription(
|
||||
cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(),
|
||||
NumericTypeID element_accumulator = NumericTypeID::kInvalid,
|
||||
OpcodeClassID opcode_class = OpcodeClassID::kInvalid
|
||||
OpcodeClassID opcode_class = OpcodeClassID::kInvalid,
|
||||
MathOperationID math_operation = MathOperationID::kMultiplyAdd
|
||||
):
|
||||
instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}
|
||||
instruction_shape(instruction_shape),
|
||||
element_accumulator(element_accumulator),
|
||||
opcode_class(opcode_class),
|
||||
math_operation(math_operation) {}
|
||||
|
||||
};
|
||||
|
||||
@ -306,6 +269,9 @@ struct OperationDescription {
|
||||
/// Unique identifier describing the operation
|
||||
char const * name;
|
||||
|
||||
/// Operation provider
|
||||
Provider provider;
|
||||
|
||||
/// Kind of operation
|
||||
OperationKind kind;
|
||||
|
||||
@ -317,6 +283,7 @@ struct OperationDescription {
|
||||
//
|
||||
OperationDescription(
|
||||
char const * name = "unknown",
|
||||
Provider Provider = Provider::kInvalid,
|
||||
OperationKind kind = OperationKind::kInvalid,
|
||||
TileDescription const & tile_description = TileDescription()
|
||||
):
|
||||
@ -340,10 +307,11 @@ struct TensorDescription {
|
||||
|
||||
/// log2() of the maximum value each relevant stride may have
|
||||
int log_stride_range;
|
||||
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
TensorDescription(
|
||||
NumericTypeID element = NumericTypeID::kInvalid,
|
||||
LayoutTypeID layout = LayoutTypeID::kInvalid,
|
||||
@ -355,7 +323,7 @@ struct TensorDescription {
|
||||
layout(layout),
|
||||
alignment(alignment),
|
||||
log_extent_range(log_extent_range),
|
||||
log_stride_range(log_stride_range) { }
|
||||
log_stride_range(log_stride_range) { }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -414,7 +382,7 @@ struct GemmDescription : public OperationDescription {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Base class for all device-wide operations
|
||||
/// Base class for all operations
|
||||
class Operation {
|
||||
public:
|
||||
|
||||
@ -435,7 +403,7 @@ public:
|
||||
virtual Status initialize(
|
||||
void const *configuration,
|
||||
void *host_workspace,
|
||||
void *device_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const = 0;
|
||||
|
||||
virtual Status run(
|
||||
@ -443,6 +411,7 @@ public:
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const = 0;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -551,11 +520,18 @@ using GemmBatchedArguments = GemmArguments;
|
||||
struct GemmArrayConfiguration {
|
||||
|
||||
gemm::GemmCoord problem_size;
|
||||
|
||||
/// Leading dimension of A matrix
|
||||
int64_t lda;
|
||||
|
||||
int64_t const *lda;
|
||||
int64_t const *ldb;
|
||||
int64_t const *ldc;
|
||||
int64_t const *ldd;
|
||||
/// Leading dimension of B matrix
|
||||
int64_t ldb;
|
||||
|
||||
/// Leading dimension of C matrix
|
||||
int64_t ldc;
|
||||
|
||||
/// Leading dimension of D matrix
|
||||
int64_t ldd;
|
||||
|
||||
int batch_count;
|
||||
};
|
||||
@ -580,49 +556,98 @@ struct GemmArrayArguments {
|
||||
|
||||
struct GemmPlanarComplexConfiguration {
|
||||
|
||||
GemmUniversalMode mode;
|
||||
gemm::GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
int64_t lda;
|
||||
int64_t ldb;
|
||||
int64_t ldc;
|
||||
int64_t ldd;
|
||||
int64_t lda_real;
|
||||
int64_t lda_imag;
|
||||
|
||||
int64_t imag_stride_A;
|
||||
int64_t imag_stride_B;
|
||||
int64_t imag_stride_C;
|
||||
int64_t imag_stride_D;
|
||||
int64_t ldb_real;
|
||||
int64_t ldb_imag;
|
||||
|
||||
int64_t ldc_real;
|
||||
int64_t ldc_imag;
|
||||
|
||||
int64_t ldd_real;
|
||||
int64_t ldd_imag;
|
||||
};
|
||||
|
||||
using GemmPlanarComplexArgments = GemmArguments;
|
||||
/// Arguments for planar complex GEMMs
|
||||
struct GemmPlanarComplexArguments {
|
||||
|
||||
void const *A_real;
|
||||
void const *A_imag;
|
||||
|
||||
void const *B_real;
|
||||
void const *B_imag;
|
||||
|
||||
void const *C_real;
|
||||
void const *C_imag;
|
||||
|
||||
void *D_real;
|
||||
void *D_imag;
|
||||
|
||||
void const *alpha;
|
||||
void const *beta;
|
||||
ScalarPointerMode pointer_mode;
|
||||
|
||||
int64_t batch_stride_A_real;
|
||||
int64_t batch_stride_A_imag;
|
||||
|
||||
int64_t batch_stride_B_real;
|
||||
int64_t batch_stride_B_imag;
|
||||
|
||||
int64_t batch_stride_C_real;
|
||||
int64_t batch_stride_C_imag;
|
||||
|
||||
int64_t batch_stride_D_real;
|
||||
int64_t batch_stride_D_imag;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Batched complex valued GEMM in which real and imaginary parts are separated by a stride
|
||||
//
|
||||
// OperationKind: Gemm
|
||||
// GemmKind: Planar complex batched
|
||||
//
|
||||
struct GemmPlanarComplexBatchedConfiguration {
|
||||
/// This is a special form of planar complex which loads pointers and problem size
|
||||
/// from memory.
|
||||
struct GemmPlanarComplexArrayConfiguration {
|
||||
|
||||
gemm::GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
int64_t lda;
|
||||
int64_t ldb;
|
||||
int64_t ldc;
|
||||
int64_t ldd;
|
||||
int64_t lda_real;
|
||||
int64_t lda_imag;
|
||||
|
||||
int64_t imag_stride_A;
|
||||
int64_t imag_stride_B;
|
||||
int64_t imag_stride_C;
|
||||
int64_t imag_stride_D;
|
||||
int64_t ldb_real;
|
||||
int64_t ldb_imag;
|
||||
|
||||
int64_t batched_stride_A;
|
||||
int64_t batched_stride_B;
|
||||
int64_t batched_stride_C;
|
||||
int64_t batched_stride_D;
|
||||
int64_t ldc_real;
|
||||
int64_t ldc_imag;
|
||||
|
||||
int64_t ldd_real;
|
||||
int64_t ldd_imag;
|
||||
};
|
||||
|
||||
/// Arguments for planar complex GEMMs
|
||||
struct GemmPlanarComplexArrayArguments {
|
||||
|
||||
int const *M;
|
||||
int const *N;
|
||||
int const *K;
|
||||
|
||||
void const * const * A_real;
|
||||
void const * const * A_imag;
|
||||
void const * const * B_real;
|
||||
void const * const * B_imag;
|
||||
void const * const * C_real;
|
||||
void const * const * C_imag;
|
||||
void * const * D_real;
|
||||
void * const * D_imag;
|
||||
|
||||
void const * alpha;
|
||||
void const * beta;
|
||||
ScalarPointerMode pointer_mode;
|
||||
};
|
||||
|
||||
using GemmPlanarComplexBatchedArguments = GemmArguments;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -55,10 +55,14 @@ using OperationVector = std::vector<std::unique_ptr<Operation>>;
|
||||
class Manifest {
|
||||
private:
|
||||
|
||||
/// Operation provider
|
||||
Provider provider_;
|
||||
|
||||
/// Global list of operations
|
||||
OperationVector operations_;
|
||||
|
||||
public:
|
||||
Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { }
|
||||
|
||||
/// Top-level initialization
|
||||
Status initialize();
|
||||
|
||||
205
tools/library/include/cutlass/library/operation_table.h
Normal file
205
tools/library/include/cutlass/library/operation_table.h
Normal file
@ -0,0 +1,205 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
\file
|
||||
\brief Defines a data structure in which a set of functionally equivalent library::Operation
|
||||
instances may be queried.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iosfwd>
#include <map>
#include <unordered_map>
#include <vector>
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tuple uniquely identifying functional behavior
|
||||
struct GemmFunctionalKey {
|
||||
|
||||
NumericTypeID element_compute;
|
||||
NumericTypeID element_scalar;
|
||||
NumericTypeID element_A;
|
||||
LayoutTypeID layout_A;
|
||||
ComplexTransform transform_A;
|
||||
NumericTypeID element_B;
|
||||
LayoutTypeID layout_B;
|
||||
ComplexTransform transform_B;
|
||||
NumericTypeID element_C;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
inline
|
||||
GemmFunctionalKey(
|
||||
NumericTypeID element_compute = NumericTypeID::kF32,
|
||||
NumericTypeID element_scalar = NumericTypeID::kF32,
|
||||
NumericTypeID element_A = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_A = LayoutTypeID::kColumnMajor,
|
||||
ComplexTransform transform_A = ComplexTransform::kNone,
|
||||
NumericTypeID element_B = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_B = LayoutTypeID::kColumnMajor,
|
||||
ComplexTransform transform_B = ComplexTransform::kNone,
|
||||
NumericTypeID element_C = NumericTypeID::kF16
|
||||
):
|
||||
element_compute(element_compute),
|
||||
element_scalar(element_scalar),
|
||||
element_A(element_A),
|
||||
layout_A(layout_A),
|
||||
transform_A(transform_A),
|
||||
element_B(element_B),
|
||||
layout_B(layout_B),
|
||||
transform_B(transform_B),
|
||||
element_C(element_C)
|
||||
{ }
|
||||
|
||||
inline
|
||||
bool operator==(GemmFunctionalKey const &rhs) const {
|
||||
return
|
||||
(element_compute == rhs.element_compute) &&
|
||||
(element_scalar == rhs.element_scalar) &&
|
||||
(element_A == rhs.element_A) &&
|
||||
(layout_A == rhs.layout_A) &&
|
||||
(transform_A == rhs.transform_A) &&
|
||||
(element_B == rhs.element_B) &&
|
||||
(layout_B == rhs.layout_B) &&
|
||||
(transform_B == rhs.transform_B) &&
|
||||
(element_C == rhs.element_C);
|
||||
}
|
||||
|
||||
inline
|
||||
bool operator!=(GemmFunctionalKey const &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Hash function for GemmFunctionalKey
|
||||
struct GemmFunctionalKeyHasher {
|
||||
using IntHash = std::hash<int>;
|
||||
|
||||
inline
|
||||
static size_t rotl(size_t key, int shl) {
|
||||
return (key << shl) | (key >> (sizeof(key)*8 - shl));
|
||||
}
|
||||
|
||||
inline
|
||||
size_t operator()(GemmFunctionalKey const &key) const {
|
||||
IntHash hash;
|
||||
|
||||
return
|
||||
rotl(hash(int(key.element_compute)), 2) ^
|
||||
rotl(hash(int(key.element_scalar)), 3) ^
|
||||
rotl(hash(int(key.element_A)), 4) ^
|
||||
rotl(hash(int(key.layout_A)), 5) ^
|
||||
rotl(hash(int(key.transform_A)), 6) ^
|
||||
rotl(hash(int(key.element_B)), 7) ^
|
||||
rotl(hash(int(key.layout_B)), 8) ^
|
||||
rotl(hash(int(key.transform_B)), 9) ^
|
||||
rotl(hash(int(key.element_C)), 10);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Establishes a partial ordering to search for GEMM operators.
struct GemmPreferenceKey {

  /// Minimum compute capability (e.g. 70, 75)
  int compute_capability;

  /// Alignment requirement
  int alignment;

  //
  // Methods
  //

  /// Default key: zero compute capability and alignment.
  GemmPreferenceKey(): compute_capability(), alignment() { }

  GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { }

  /// Lexicographic ordering: compute capability first, then alignment.
  bool operator<(GemmPreferenceKey const &rhs) const {
    if (compute_capability != rhs.compute_capability) {
      return compute_capability < rhs.compute_capability;
    }
    return alignment < rhs.alignment;
  }

  /// NOTE(review): equality deliberately(?) ignores alignment, so two keys can
  /// compare equal yet be ordered by operator< — confirm this asymmetry is intended.
  bool operator==(GemmPreferenceKey const &rhs) const {
    return compute_capability == rhs.compute_capability;
  }
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Maps minimum compute capability onto a vector of possible operations,
/// ordered by GemmPreferenceKey::operator< (compute capability, then alignment).
using GemmOperationVectorMap = std::map<
  GemmPreferenceKey,
  std::vector<Operation const *>
>;

/// Maps a GemmFunctionalKey onto a GemmOperationVectorMap (operations bucketed by
/// preference key); contained Operation pointers are expected to be of kind kGemm.
using GemmOperationFunctionalMap = std::unordered_map<
  GemmFunctionalKey,
  GemmOperationVectorMap,
  GemmFunctionalKeyHasher
>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Table of cutlass::library::Operation instances, indexed by functional key
/// and partitioned by GEMM kind.
class OperationTable {
public:

  /// Map of all operations of type kGemm and gemm_kind of type kGemm
  GemmOperationFunctionalMap gemm_operations;

  /// Map of all operations of type kGemm and gemm_kind of type kPlanarComplex
  GemmOperationFunctionalMap gemm_planar_complex_operations;

  /// Map of all operations of type kGemm and gemm_kind of type kPlanarComplexArray
  GemmOperationFunctionalMap gemm_planar_complex_array_operations;

public:

  /// Inserts the manifest's operations into the maps above
  /// (implementation out of line; presumably in operation_table.cu — confirm).
  void append(Manifest const &manifest);

};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k);
|
||||
|
||||
62
tools/library/include/cutlass/library/singleton.h
Normal file
62
tools/library/include/cutlass/library/singleton.h
Normal file
@ -0,0 +1,62 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
#include "cutlass/library/operation_table.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Singleton instance stores a Manifest and Operation table
class Singleton {
public:

  /// Manifest object
  Manifest manifest;

  /// Operation table referencing the Manifest
  OperationTable operation_table;

public:

  // Default constructor; presumably initializes the manifest and populates
  // operation_table from it (implementation not visible here — confirm).
  Singleton();

  /// Access to the shared instance (defined out of line).
  static Singleton const &get();
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
138
tools/library/include/cutlass/library/util.h
Normal file
138
tools/library/include/cutlass/library/util.h
Normal file
@ -0,0 +1,138 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
|
||||
\brief Utilities accompanying the CUTLASS library for interacting with Library types.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Lexical cast from string; specialized below for each supported enumerant type
template <typename T> T from_string(std::string const &);

/// Converts a Provider enumerant to a string
char const *to_string(Provider provider, bool pretty = false);

/// Parses a Provider enumerant from a string
template <> Provider from_string<Provider>(std::string const &str);

/// Converts an OperationKind enumerant to a string
char const *to_string(OperationKind type, bool pretty = false);

/// Parses an OperationKind enumerant from a string
template <> OperationKind from_string<OperationKind>(std::string const &str);

/// Converts a NumericType enumerant to a string
char const *to_string(NumericTypeID type, bool pretty = false);

/// Parses a NumericType enumerant from a string
template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);

/// Returns the size of a data type in bits
int sizeof_bits(NumericTypeID type);

/// Returns true if the numeric type is a complex data type or false if real-valued.
bool is_complex_type(NumericTypeID type);

/// Returns the real-valued type underlying a type (only different from 'type' if complex)
NumericTypeID get_real_type(NumericTypeID type);

/// Returns true if numeric type is integer
bool is_integer_type(NumericTypeID type);

/// Returns true if numeric type is signed
bool is_signed_type(NumericTypeID type);

/// Returns true if numeric type is a signed integer
bool is_signed_integer(NumericTypeID type);

/// Returns true if numeric type is an unsigned integer
bool is_unsigned_integer(NumericTypeID type);

/// Returns true if numeric type is floating-point type
bool is_float_type(NumericTypeID type);

/// To string method for cutlass::Status
char const *to_string(Status status, bool pretty = false);

/// Converts a LayoutTypeID enumerant to a string
char const *to_string(LayoutTypeID layout, bool pretty = false);

/// Parses a LayoutType enumerant from a string
template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);

/// Returns the rank of a layout's stride based on the LayoutTypeID
int get_layout_stride_rank(LayoutTypeID layout_id);

/// Converts an OpcodeClassID enumerant to a string
char const *to_string(OpcodeClassID type, bool pretty = false);

/// Parses an OpcodeClassID enumerant from a string
template <>
OpcodeClassID from_string<OpcodeClassID>(std::string const &str);

/// Converts a ComplexTransform enumerant to a string
char const *to_string(ComplexTransform type, bool pretty = false);

/// Parses a ComplexTransform enumerant from a string
template <>
ComplexTransform from_string<ComplexTransform>(std::string const &str);

/// Lexical cast from int64_t to string
std::string lexical_cast(int64_t int_value);

/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);

/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid.
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);

/// Casts from a signed int64 to the destination type. Returns true if successful.
bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);

/// Casts from an unsigned int64 to the destination type. Returns true if successful.
bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);

/// Casts from a real value represented as a double to the destination type. Returns true if successful.
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -22,7 +22,9 @@ from library import *
|
||||
#
|
||||
class GemmOperation:
|
||||
#
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue):
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Cohort):
|
||||
|
||||
self.operation_kind = OperationKind.Gemm
|
||||
self.arch = arch
|
||||
self.tile_description = tile_description
|
||||
@ -31,29 +33,75 @@ class GemmOperation:
|
||||
self.B = B
|
||||
self.C = C
|
||||
self.element_epilogue = element_epilogue
|
||||
self.epilogue_functor = epilogue_functor
|
||||
self.swizzling_functor = swizzling_functor
|
||||
|
||||
#
|
||||
def is_complex(self):
  """True if the tile description's math instruction is a complex-valued operation."""
  math_op = self.tile_description.math_instruction.math_operation
  return math_op == MathOperation.multiply_add_complex
|
||||
|
||||
#
|
||||
def is_planar_complex(self):
  """True if gemm_kind is one of the planar-complex GEMM kinds."""
  planar_kinds = (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  return self.gemm_kind in planar_kinds
|
||||
|
||||
#
|
||||
def accumulator_type(self):
  """Accumulator data type, promoted via get_complex_from_real for complex math."""
  accum = self.tile_description.math_instruction.element_accumulator

  if not self.is_complex():
    return accum

  return get_complex_from_real(accum)
|
||||
|
||||
#
|
||||
def short_math_name(self):
  """Short name for the accumulator type, looked up in ShortDataTypeNames."""
  accum = self.accumulator_type()
  return ShortDataTypeNames[accum]
|
||||
|
||||
|
||||
#
|
||||
def core_name(self):
|
||||
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
||||
|
||||
inst_shape = ''
|
||||
inst_operation = ''
|
||||
intermediate_type = ''
|
||||
|
||||
math_operations_map = {
|
||||
MathOperation.xor_popc: 'xor',
|
||||
}
|
||||
|
||||
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
|
||||
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
|
||||
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
||||
else:
|
||||
inst_shape = ''
|
||||
|
||||
return "%s%s%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], inst_shape, GemmKindNames[self.gemm_kind])
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
||||
inst_shape += math_op_string
|
||||
|
||||
if self.tile_description.math_instruction.element_a != self.A.element and \
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
|
||||
#
|
||||
def extended_name(self):
|
||||
''' Append data types if they differ from compute type. '''
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_${core_name}_${element_a}"
|
||||
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${core_name}_${element_a}"
|
||||
else:
|
||||
if self.is_complex():
|
||||
extended_name = "${core_name}"
|
||||
else:
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_${core_name}_${element_a}"
|
||||
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${core_name}_${element_a}"
|
||||
else:
|
||||
extended_name = "${core_name}"
|
||||
|
||||
extended_name = SubstituteTemplate(extended_name, {
|
||||
'element_a': DataTypeNames[self.A.element],
|
||||
@ -63,28 +111,32 @@ class GemmOperation:
|
||||
|
||||
return extended_name
|
||||
|
||||
#
|
||||
def layout_name(self):
  """Concatenated short layout names for operands A and B.

  Complex and planar-complex operations also encode each operand's
  complex transform in the name.
  """
  if self.is_complex() or self.is_planar_complex():
    name_a = ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
    name_b = ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
  else:
    name_a = ShortLayoutTypeNames[self.A.layout]
    name_b = ShortLayoutTypeNames[self.B.layout]
  return "%s%s" % (name_a, name_b)
|
||||
|
||||
#
|
||||
def procedural_name(self):
|
||||
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
||||
if self.tile_description.stages > 2:
|
||||
threadblock = "%dx%d_%dx%d" % (
|
||||
self.tile_description.threadblock_shape[0],
|
||||
self.tile_description.threadblock_shape[1],
|
||||
self.tile_description.threadblock_shape[2],
|
||||
self.tile_description.stages
|
||||
)
|
||||
else:
|
||||
threadblock = "%dx%d" % (self.tile_description.threadblock_shape[0], self.tile_description.threadblock_shape[1])
|
||||
threadblock = self.tile_description.procedural_name()
|
||||
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
|
||||
alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
|
||||
|
||||
return SubstituteTemplate(
|
||||
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}",
|
||||
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
|
||||
{
|
||||
'opcode_class': opcode_class_name,
|
||||
'extended_name': self.extended_name(),
|
||||
'threadblock': threadblock,
|
||||
'layout': "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]),
|
||||
'layout': self.layout_name(),
|
||||
'alignment': "%d" % self.A.alignment,
|
||||
}
|
||||
)
|
||||
|
||||
@ -104,7 +156,7 @@ class EmitGemmInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
self.template = """
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
|
||||
${element_a}, ${layout_a},
|
||||
@ -116,14 +168,45 @@ class EmitGemmInstance:
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
${stages}
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${align_a},
|
||||
${align_b},
|
||||
false,
|
||||
${math_operation}
|
||||
${residual}
|
||||
>;
|
||||
"""
|
||||
self.gemm_complex_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
|
||||
${element_a}, ${layout_a},
|
||||
${element_b}, ${layout_b},
|
||||
${element_c}, ${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor}<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${transform_a},
|
||||
${transform_b},
|
||||
${math_operation}
|
||||
${residual}
|
||||
>;
|
||||
"""
|
||||
|
||||
@ -135,6 +218,8 @@ class EmitGemmInstance:
|
||||
|
||||
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
residual = ''
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
@ -143,7 +228,7 @@ class EmitGemmInstance:
|
||||
'layout_b': LayoutTag[operation.B.layout],
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[operation.C.layout],
|
||||
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
||||
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
||||
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
||||
@ -157,57 +242,72 @@ class EmitGemmInstance:
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'stages': str(operation.tile_description.stages)
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
||||
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
||||
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
||||
'residual': residual
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
|
||||
|
||||
return SubstituteTemplate(template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class EmitGemmBatchedInstance:
|
||||
class EmitGemmPlanarComplexInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
self.template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = cutlass::gemm::device::GemmBatched<
|
||||
${element_a}, ${layout_a},
|
||||
${element_b}, ${layout_b},
|
||||
${element_c}, ${layout_c},
|
||||
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
|
||||
${element_c}, cutlass::layout::RowMajor,
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${alignment_c},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
${stages},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>;
|
||||
${math_operator}
|
||||
>::GemmKernel;
|
||||
|
||||
struct ${operation_name} : public Operation_${operation_name} { };
|
||||
"""
|
||||
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
#warp_shape[2] = operation.tile_description.math_instruction.instruction_shape[2]
|
||||
warp_shape[2] = operation.tile_description.threadblock_shape[2]
|
||||
|
||||
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
|
||||
transposed_layout_A = TransposedLayout[operation.A.layout]
|
||||
transposed_layout_B = TransposedLayout[operation.B.layout]
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
'layout_a': LayoutTag[operation.A.layout],
|
||||
'element_b': DataTypeTag[operation.B.element],
|
||||
'layout_b': LayoutTag[operation.B.layout],
|
||||
'element_a': DataTypeTag[operation.B.element],
|
||||
'layout_a': LayoutTag[transposed_layout_B],
|
||||
'transform_a': ComplexTransformTag[operation.B.complex_transform],
|
||||
'alignment_a': str(operation.B.alignment),
|
||||
'element_b': DataTypeTag[operation.A.element],
|
||||
'layout_b': LayoutTag[transposed_layout_A],
|
||||
'transform_b': ComplexTransformTag[operation.A.complex_transform],
|
||||
'alignment_b': str(operation.A.alignment),
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[operation.C.layout],
|
||||
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
||||
@ -222,139 +322,89 @@ class EmitGemmBatchedInstance:
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'alignment_c': str(operation.C.alignment),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'align_a': str(operation.A.alignment),
|
||||
'align_b': str(operation.B.alignment),
|
||||
'math_operator': 'cutlass::arch::OpMultiplyAdd'
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
# Generator functions for all layouts
|
||||
#
|
||||
class EmitGemmPlanarComplexArrayInstance:
|
||||
''' Responsible for emitting a CUTLASS template definition'''
|
||||
|
||||
def __init__(self):
|
||||
self.template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
|
||||
${element_c}, cutlass::layout::RowMajor,
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
|
||||
${element_c},
|
||||
${alignment_c},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,
|
||||
${stages},
|
||||
${math_operator}
|
||||
>::GemmArrayKernel;
|
||||
|
||||
struct ${operation_name} : public Operation_${operation_name} { };
|
||||
"""
|
||||
|
||||
def emit(self, operation):
|
||||
|
||||
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
||||
|
||||
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
|
||||
transposed_layout_A = TransposedLayout[operation.A.layout]
|
||||
transposed_layout_B = TransposedLayout[operation.B.layout]
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.B.element],
|
||||
'layout_a': LayoutTag[transposed_layout_B],
|
||||
'transform_a': ComplexTransformTag[operation.B.complex_transform],
|
||||
'alignment_a': str(operation.B.alignment),
|
||||
'element_b': DataTypeTag[operation.A.element],
|
||||
'layout_b': LayoutTag[transposed_layout_A],
|
||||
'transform_b': ComplexTransformTag[operation.A.complex_transform],
|
||||
'alignment_b': str(operation.A.alignment),
|
||||
'element_c': DataTypeTag[operation.C.element],
|
||||
'layout_c': LayoutTag[operation.C.layout],
|
||||
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
||||
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
||||
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
||||
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
||||
'warp_shape_m': str(warp_shape[0]),
|
||||
'warp_shape_n': str(warp_shape[1]),
|
||||
'warp_shape_k': str(warp_shape[2]),
|
||||
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
||||
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
||||
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
||||
'alignment_c': str(operation.C.alignment),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'stages': str(operation.tile_description.stages),
|
||||
'math_operator': 'cutlass::arch::OpMultiplyAdd'
|
||||
}
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def GenerateGemmSimt(gemm_kind, manifest, tile_descriptions, min_cc):
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
# for each tile configuration, emit a GEMM
|
||||
for tile in tile_descriptions:
|
||||
for layout in layouts:
|
||||
|
||||
A = TensorDescription(tile.math_instruction.element_a, layout[0], 1)
|
||||
B = TensorDescription(tile.math_instruction.element_b, layout[1], 1)
|
||||
C = TensorDescription(tile.math_instruction.element_accumulator, layout[2], 1)
|
||||
|
||||
manifest.append(GemmOperation(gemm_kind, 50, tile, A, B, C, tile.math_instruction.element_accumulator))
|
||||
|
||||
#
|
||||
def GenerateGemmTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]):
|
||||
|
||||
# Canonical matrix layouts
|
||||
canonical_layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
# Interleaved matrix layouts
|
||||
interleaved_layouts = {
|
||||
8: [
|
||||
#(LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
],
|
||||
4: [
|
||||
#(LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
}
|
||||
|
||||
# for each tile configuration, emit a GEMM
|
||||
for align in minimum_alignment:
|
||||
for tile in tile_descriptions:
|
||||
|
||||
min_input_size = min(DataTypeSize[tile.math_instruction.element_a], DataTypeSize[tile.math_instruction.element_a])
|
||||
|
||||
# If the data type is large enough, use canonical layouts.
|
||||
if min_input_size >= 16:
|
||||
layouts = canonical_layouts
|
||||
else:
|
||||
layouts = interleaved_layouts[min_input_size]
|
||||
|
||||
for layout in layouts:
|
||||
|
||||
#
|
||||
output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
|
||||
if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
|
||||
else [tile.math_instruction.element_accumulator,]
|
||||
|
||||
align_a = align // DataTypeSize[tile.math_instruction.element_a]
|
||||
align_b = align // DataTypeSize[tile.math_instruction.element_b]
|
||||
|
||||
|
||||
for output_type in output_types:
|
||||
|
||||
rows_per_warp = 8 // tile.warp_count[1]
|
||||
align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32)
|
||||
|
||||
A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a)
|
||||
B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b)
|
||||
C = TensorDescription(output_type, layout[2], max(1, align_c))
|
||||
|
||||
element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \
|
||||
else tile.math_instruction.element_accumulator
|
||||
|
||||
manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue))
|
||||
|
||||
|
||||
#
|
||||
def GenerateGemmWmmaTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]):
|
||||
|
||||
# Wmma supported matrix layouts
|
||||
layouts = [
|
||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
|
||||
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor),
|
||||
]
|
||||
|
||||
# for each tile configuration, emit a GEMM
|
||||
for align in minimum_alignment:
|
||||
for tile in tile_descriptions:
|
||||
for layout in layouts:
|
||||
|
||||
#
|
||||
output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
|
||||
if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
|
||||
else [tile.math_instruction.element_accumulator,]
|
||||
|
||||
align_a = align // DataTypeSize[tile.math_instruction.element_a]
|
||||
align_b = align // DataTypeSize[tile.math_instruction.element_b]
|
||||
|
||||
|
||||
for output_type in output_types:
|
||||
|
||||
rows_per_warp = 8 // tile.warp_count[1]
|
||||
align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32)
|
||||
|
||||
A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a)
|
||||
B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b)
|
||||
C = TensorDescription(output_type, layout[2], max(1, align_c))
|
||||
|
||||
element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \
|
||||
else tile.math_instruction.element_accumulator
|
||||
|
||||
manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue))
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@ -369,21 +419,40 @@ class EmitGemmConfigurationLibrary:
|
||||
|
||||
self.instance_emitter = {
|
||||
GemmKind.Gemm: EmitGemmInstance,
|
||||
GemmKind.Batched: EmitGemmBatchedInstance
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
|
||||
}
|
||||
|
||||
self.gemm_kind_wrappers = {
|
||||
GemmKind.Gemm: 'GemmOperation',
|
||||
GemmKind.Batched: 'GemmBatchedOperation',
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
|
||||
}
|
||||
|
||||
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
||||
|
||||
self.instance_template = """
|
||||
self.instance_template = {
|
||||
GemmKind.Gemm: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
""",
|
||||
GemmKind.PlanarComplex: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
""",
|
||||
GemmKind.PlanarComplexArray: """
|
||||
${compile_guard_start}
|
||||
manifest.append(new ${gemm_kind}<
|
||||
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
||||
>("${operation_name}"));
|
||||
${compile_guard_end}
|
||||
"""
|
||||
}
|
||||
|
||||
self.header_template = """
|
||||
/*
|
||||
Generated by gemm_operation.py - Do not edit.
|
||||
@ -398,6 +467,14 @@ ${compile_guard_end}
|
||||
#include "library_internal.h"
|
||||
#include "gemm_operation.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
"""
|
||||
|
||||
self.initialize_function_template = """
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
@ -421,9 +498,11 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
|
||||
def __enter__(self):
|
||||
self.configuration_file = open(self.configuration_path, "w")
|
||||
self.configuration_file.write(SubstituteTemplate(self.header_template, {
|
||||
'configuration_name': self.configuration_name
|
||||
}))
|
||||
self.configuration_file.write(self.header_template)
|
||||
|
||||
self.instance_definitions = []
|
||||
self.instance_wrappers = []
|
||||
|
||||
self.operations = []
|
||||
return self
|
||||
|
||||
@ -431,8 +510,10 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
emitter = self.instance_emitter[operation.gemm_kind]()
|
||||
|
||||
self.operations.append(operation)
|
||||
self.configuration_file.write(emitter.emit(operation))
|
||||
self.configuration_file.write(SubstituteTemplate(self.instance_template, {
|
||||
|
||||
self.instance_definitions.append(emitter.emit(operation))
|
||||
|
||||
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
|
||||
'configuration_name': self.configuration_name,
|
||||
'operation_name': operation.procedural_name(),
|
||||
'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
|
||||
@ -443,6 +524,19 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
}))
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
|
||||
# Write instance definitions in top-level namespace
|
||||
for instance_definition in self.instance_definitions:
|
||||
self.configuration_file.write(instance_definition)
|
||||
|
||||
# Add wrapper objects within initialize() function
|
||||
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
||||
'configuration_name': self.configuration_name
|
||||
}))
|
||||
|
||||
for instance_wrapper in self.instance_wrappers:
|
||||
self.configuration_file.write(instance_wrapper)
|
||||
|
||||
self.configuration_file.write(self.epilogue_template)
|
||||
self.configuration_file.close()
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -153,6 +153,68 @@ DataTypeSize = {
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class ComplexTransform(enum.Enum):
|
||||
none = enum.auto()
|
||||
conj = enum.auto()
|
||||
|
||||
#
|
||||
ComplexTransformTag = {
|
||||
ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
|
||||
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
|
||||
}
|
||||
|
||||
#
|
||||
RealComplexBijection = [
|
||||
(DataType.f16, DataType.cf16),
|
||||
(DataType.f32, DataType.cf32),
|
||||
(DataType.f64, DataType.cf64),
|
||||
]
|
||||
|
||||
#
|
||||
def is_complex(data_type):
|
||||
for r, c in RealComplexBijection:
|
||||
if data_type == c:
|
||||
return True
|
||||
return False
|
||||
|
||||
#
|
||||
def get_complex_from_real(real_type):
|
||||
for r, c in RealComplexBijection:
|
||||
if real_type == r:
|
||||
return c
|
||||
return DataType.invalid
|
||||
|
||||
#
|
||||
def get_real_from_complex(complex_type):
|
||||
for r, c in RealComplexBijection:
|
||||
if complex_type == c:
|
||||
return r
|
||||
return DataType.invalid
|
||||
|
||||
#
|
||||
class ComplexMultiplyOp(enum.Enum):
|
||||
multiply_add = enum.auto()
|
||||
gaussian = enum.auto()
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class MathOperation(enum.Enum):
|
||||
multiply_add = enum.auto()
|
||||
multiply_add_saturate = enum.auto()
|
||||
xor_popc = enum.auto()
|
||||
multiply_add_complex = enum.auto()
|
||||
#
|
||||
MathOperationTag = {
|
||||
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
|
||||
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
|
||||
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
|
||||
MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class LayoutType(enum.Enum):
|
||||
ColumnMajor = enum.auto()
|
||||
@ -182,6 +244,17 @@ LayoutTag = {
|
||||
LayoutType.TensorNCxHW64: 'cutlass::layout::TensorNCxHW64'
|
||||
}
|
||||
|
||||
#
|
||||
TransposedLayout = {
|
||||
LayoutType.ColumnMajor: LayoutType.RowMajor,
|
||||
LayoutType.RowMajor: LayoutType.ColumnMajor,
|
||||
LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
|
||||
LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
|
||||
LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
|
||||
LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
|
||||
LayoutType.TensorNHWC: LayoutType.TensorNHWC
|
||||
}
|
||||
|
||||
#
|
||||
ShortLayoutTypeNames = {
|
||||
LayoutType.ColumnMajor: 'n',
|
||||
@ -197,6 +270,14 @@ ShortLayoutTypeNames = {
|
||||
LayoutType.TensorNCxHW64: 'ncxhw64'
|
||||
}
|
||||
|
||||
#
|
||||
ShortComplexLayoutNames = {
|
||||
(LayoutType.ColumnMajor, ComplexTransform.none): 'n',
|
||||
(LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
|
||||
(LayoutType.RowMajor, ComplexTransform.none): 't',
|
||||
(LayoutType.RowMajor, ComplexTransform.conj): 'h'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
@ -244,9 +325,15 @@ ArchitectureNames = {
|
||||
#
|
||||
def SubstituteTemplate(template, values):
|
||||
text = template
|
||||
for key, value in values.items():
|
||||
regex = "\\$\\{%s\\}" % key
|
||||
text = re.sub(regex, value, text)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for key, value in values.items():
|
||||
regex = "\\$\\{%s\\}" % key
|
||||
newtext = re.sub(regex, value, text)
|
||||
if newtext != text:
|
||||
changed = True
|
||||
text = newtext
|
||||
return text
|
||||
|
||||
###################################################################################################
|
||||
@ -256,28 +343,52 @@ class GemmKind(enum.Enum):
|
||||
Gemm = enum.auto()
|
||||
Batched = enum.auto()
|
||||
Array = enum.auto()
|
||||
Universal = enum.auto()
|
||||
PlanarComplex = enum.auto()
|
||||
PlanarComplexBatched = enum.auto()
|
||||
PlanarComplexArray = enum.auto()
|
||||
|
||||
#
|
||||
GemmKindNames = {
|
||||
GemmKind.Gemm: "gemm",
|
||||
GemmKind.Batched: "gemm_batched",
|
||||
GemmKind.Array: "gemm_array",
|
||||
GemmKind.Universal: "gemm_universal",
|
||||
GemmKind.PlanarComplex: "gemm_planar_complex",
|
||||
GemmKind.PlanarComplexBatched: "gemm_planar_complex_batched",
|
||||
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
|
||||
}
|
||||
|
||||
#
|
||||
class EpilogueFunctor(enum.Enum):
|
||||
LinearCombination = enum.auto()
|
||||
LinearCombinationClamp = enum.auto()
|
||||
|
||||
#
|
||||
EpilogueFunctorTag = {
|
||||
EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
|
||||
EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
|
||||
}
|
||||
|
||||
#
|
||||
class SwizzlingFunctor(enum.Enum):
|
||||
Cohort = enum.auto()
|
||||
Identity = enum.auto()
|
||||
|
||||
#
|
||||
SwizzlingFunctorTag = {
|
||||
SwizzlingFunctor.Cohort: 'cutlass::gemm::threadblock::GemmCohortThreadblockSwizzle<${layout_a}, ${layout_b}>',
|
||||
SwizzlingFunctor.Identity: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle',
|
||||
}
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class MathInstruction:
|
||||
def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class):
|
||||
def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add):
|
||||
self.instruction_shape = instruction_shape
|
||||
self.element_a = element_a
|
||||
self.element_b = element_b
|
||||
self.element_accumulator = element_accumulator
|
||||
self.opcode_class = opcode_class
|
||||
self.math_operation = math_operation
|
||||
|
||||
|
||||
#
|
||||
@ -292,16 +403,14 @@ class TileDescription:
|
||||
self.maximum_compute_capability = max_compute
|
||||
|
||||
def procedural_name(self):
|
||||
if self.stages == 2:
|
||||
return "%dx%dx%d" % self.threadblock_shape
|
||||
elif self.stages > 2:
|
||||
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
|
||||
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
|
||||
|
||||
#
|
||||
class TensorDescription:
|
||||
def __init__(self, element, layout, alignment = 1):
|
||||
def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
|
||||
self.element = element
|
||||
self.layout = layout
|
||||
self.alignment = alignment
|
||||
self.complex_transform = complex_transform
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@ -114,6 +114,16 @@ class Manifest:
|
||||
self.args = args
|
||||
self.compute_capabilities = [int(x) for x in args.architectures.split(';')]
|
||||
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
]
|
||||
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
|
||||
if args.kernels == 'all':
|
||||
self.kernel_names = []
|
||||
else:
|
||||
@ -142,6 +152,16 @@ void initialize_all(Manifest &manifest) {
|
||||
} // namespace cutlass
|
||||
|
||||
'''
|
||||
#
|
||||
def _filter_string_matches(self, filter_string, haystack):
|
||||
''' Returns true if all substrings appear in the haystack in order'''
|
||||
substrings = filter_string.split('*')
|
||||
for sub in substrings:
|
||||
idx = haystack.find(sub)
|
||||
if idx < 0:
|
||||
return False
|
||||
haystack = haystack[idx + len(sub):]
|
||||
return True
|
||||
|
||||
#
|
||||
def filter(self, operation):
|
||||
@ -159,6 +179,9 @@ void initialize_all(Manifest &manifest) {
|
||||
if not enabled:
|
||||
return False
|
||||
|
||||
if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
|
||||
return False
|
||||
|
||||
# eliminate duplicates
|
||||
if operation.procedural_name() in self.operations_by_name.keys():
|
||||
return False
|
||||
@ -168,11 +191,10 @@ void initialize_all(Manifest &manifest) {
|
||||
name = operation.procedural_name()
|
||||
enabled = False
|
||||
for name_substr in self.kernel_names:
|
||||
if name_substr in name:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
enabled = True
|
||||
break
|
||||
|
||||
# todo: filter based on operation kind
|
||||
# todo: filter based on compute data type
|
||||
return enabled
|
||||
#
|
||||
@ -255,10 +277,11 @@ void initialize_all(Manifest &manifest) {
|
||||
manifest_path = os.path.join(generated_path, "manifest.cmake")
|
||||
with open(manifest_path, "w") as manifest_file:
|
||||
|
||||
target_name = 'cutlass_lib'
|
||||
target_name = 'cutlass_library_objs'
|
||||
|
||||
target_text = SubstituteTemplate("""cutlass_target_sources(
|
||||
${target_name}
|
||||
BATCH_SOURCES ON
|
||||
PRIVATE
|
||||
""", { 'target_name': target_name})
|
||||
|
||||
|
||||
@ -29,8 +29,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
#include "cutlass/gemm/device/gemm_batched.h"
|
||||
#include "cutlass/gemm/device/gemm_array.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "library_internal.h"
|
||||
@ -68,8 +73,10 @@ public:
|
||||
GemmOperationBase(char const *name = "unknown_gemm") {
|
||||
|
||||
description_.name = name;
|
||||
description_.provider = Provider::kCUTLASS;
|
||||
description_.kind = OperationKind::kGemm;
|
||||
|
||||
description_.gemm_kind = GemmKind::kGemm;
|
||||
|
||||
description_.tile_description.threadblock_shape = make_Coord(
|
||||
Operator::ThreadblockShape::kM,
|
||||
Operator::ThreadblockShape::kN,
|
||||
@ -93,22 +100,23 @@ public:
|
||||
description_.tile_description.math_instruction.opcode_class =
|
||||
OpcodeClassMap<typename Operator::OperatorClass>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.math_operation =
|
||||
MathOperationMap<typename Operator::Operator>::kId;
|
||||
|
||||
description_.tile_description.minimum_compute_capability =
|
||||
ArchMap<typename Operator::ArchTag>::kMin;
|
||||
|
||||
description_.tile_description.maximum_compute_capability =
|
||||
ArchMap<typename Operator::ArchTag>::kMax;
|
||||
|
||||
description_.gemm_kind = GemmKind::kGemm;
|
||||
|
||||
description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
|
||||
description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
|
||||
description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.split_k_mode = Operator::kSplitKSerial ? SplitKMode::kSerial : SplitKMode::kNone;
|
||||
description_.transform_A = ComplexTransform::kNone;
|
||||
description_.transform_B = ComplexTransform::kNone;
|
||||
description_.split_k_mode = SplitKMode::kNone;
|
||||
description_.transform_A = ComplexTransformMap<Operator::kTransformA>::kId;
|
||||
description_.transform_B = ComplexTransformMap<Operator::kTransformB>::kId;
|
||||
}
|
||||
|
||||
/// Returns the description of the GEMM operation
|
||||
@ -294,8 +302,24 @@ public:
|
||||
|
||||
return op->run(stream);
|
||||
}
|
||||
};
|
||||
|
||||
void print_operator_args(OperatorArguments &operator_args) const {
|
||||
#if 0
|
||||
std::cout << "GemmOperation::OperatorArguments" << std::endl;
|
||||
std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl;
|
||||
std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl;
|
||||
std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl;
|
||||
std::cout << " beta: " << operator_args.epilogue.beta << std::endl;
|
||||
std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl;
|
||||
std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl;
|
||||
std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl;
|
||||
std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl;
|
||||
std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl;
|
||||
std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl;
|
||||
std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -360,6 +384,7 @@ protected:
|
||||
*static_cast<ElementCompute const *>(arguments->alpha),
|
||||
*static_cast<ElementCompute const *>(arguments->beta)
|
||||
);
|
||||
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
|
||||
@ -491,6 +516,593 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Operator_>
|
||||
class GemmArrayOperation : public GemmOperationBase<Operator_> {
|
||||
public:
|
||||
|
||||
using Operator = Operator_;
|
||||
using ElementA = typename Operator::ElementA;
|
||||
using LayoutA = typename Operator::LayoutA;
|
||||
using ElementB = typename Operator::ElementB;
|
||||
using LayoutB = typename Operator::LayoutB;
|
||||
using ElementC = typename Operator::ElementC;
|
||||
using LayoutC = typename Operator::LayoutC;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
|
||||
protected:
|
||||
|
||||
///
|
||||
GemmDescription description_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
GemmArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase<Operator_>(name) {
|
||||
|
||||
description_.gemm_kind = GemmKind::kArray;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status construct_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmArrayConfiguration const *configuration) {
|
||||
|
||||
operator_args.problem_size = configuration->problem_size;
|
||||
|
||||
operator_args.batch_count = configuration->batch_count;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmArrayArguments const *arguments) {
|
||||
|
||||
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
|
||||
typename Operator::EpilogueOutputOp::Params params(
|
||||
*static_cast<ElementCompute const *>(arguments->alpha),
|
||||
*static_cast<ElementCompute const *>(arguments->beta)
|
||||
);
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
|
||||
typename Operator::EpilogueOutputOp::Params params(
|
||||
static_cast<ElementCompute const *>(arguments->alpha),
|
||||
static_cast<ElementCompute const *>(arguments->beta)
|
||||
);
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Returns the description of the GEMM operation
|
||||
virtual OperationDescription const & description() const {
|
||||
return description_;
|
||||
}
|
||||
|
||||
/// Returns success if the operation can proceed
|
||||
virtual Status can_implement(
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr) const {
|
||||
|
||||
GemmArrayConfiguration const *configuration =
|
||||
static_cast<GemmArrayConfiguration const *>(configuration_ptr);
|
||||
|
||||
GemmArrayArguments const *arguments =
|
||||
static_cast<GemmArrayArguments const *>(arguments_ptr);
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(args, configuration);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = update_arguments_(args, arguments);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return Operator::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
virtual uint64_t get_host_workspace_size(
|
||||
void const *configuration) const {
|
||||
|
||||
return sizeof(Operator);
|
||||
}
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(
|
||||
args,
|
||||
static_cast<GemmArrayConfiguration const *>(configuration_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return Operator::get_workspace_size(args);
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
virtual Status initialize(
|
||||
void const *configuration_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(
|
||||
args,
|
||||
static_cast<GemmArrayConfiguration const *>(configuration_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = new (host_workspace) Operator;
|
||||
|
||||
return op->initialize(args, device_workspace, stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel
|
||||
virtual Status run(
|
||||
void const *arguments_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = update_arguments_(
|
||||
args,
|
||||
static_cast<GemmArrayArguments const *>(arguments_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = static_cast<Operator *>(host_workspace);
|
||||
|
||||
status = op->update(args, device_workspace);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return op->run(stream);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Operator_>
|
||||
class GemmPlanarComplexOperation : public GemmOperationBase<Operator_> {
|
||||
public:
|
||||
|
||||
using Operator = Operator_;
|
||||
using ElementA = typename Operator::ElementA;
|
||||
using LayoutA = typename Operator::LayoutA;
|
||||
using ElementB = typename Operator::ElementB;
|
||||
using LayoutB = typename Operator::LayoutB;
|
||||
using ElementC = typename Operator::ElementC;
|
||||
using LayoutC = typename Operator::LayoutC;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase<Operator_>(name) {
|
||||
|
||||
this->description_.gemm_kind = GemmKind::kPlanarComplex;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status construct_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmPlanarComplexConfiguration const *configuration) {
|
||||
|
||||
operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched;
|
||||
operator_args.problem_size = configuration->problem_size;
|
||||
operator_args.batch_count = configuration->batch_count;
|
||||
|
||||
operator_args.lda_real = int(configuration->lda_real);
|
||||
operator_args.lda_imag = int(configuration->lda_imag);
|
||||
operator_args.ldb_real = int(configuration->ldb_real);
|
||||
operator_args.ldb_imag = int(configuration->ldb_imag);
|
||||
operator_args.ldc_real = int(configuration->ldc_real);
|
||||
operator_args.ldc_imag = int(configuration->ldc_imag);
|
||||
operator_args.ldd_real = int(configuration->ldd_real);
|
||||
operator_args.ldd_imag = int(configuration->ldd_imag);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmPlanarComplexArguments const *arguments) {
|
||||
|
||||
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
|
||||
typename Operator::EpilogueOutputOp::Params params(
|
||||
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
|
||||
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
|
||||
);
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
|
||||
typename Operator::EpilogueOutputOp::Params params(
|
||||
static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
|
||||
static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
|
||||
);
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
// update arguments
|
||||
operator_args.ptr_A_real = arguments->A_real;
|
||||
operator_args.ptr_A_imag = arguments->A_imag;
|
||||
operator_args.ptr_B_real = arguments->B_real;
|
||||
operator_args.ptr_B_imag = arguments->B_imag;
|
||||
operator_args.ptr_C_real = arguments->C_real;
|
||||
operator_args.ptr_C_imag = arguments->C_imag;
|
||||
operator_args.ptr_D_real = arguments->D_real;
|
||||
operator_args.ptr_D_imag = arguments->D_imag;
|
||||
|
||||
operator_args.batch_stride_A = arguments->batch_stride_A_real;
|
||||
operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag;
|
||||
operator_args.batch_stride_B = arguments->batch_stride_B_real;
|
||||
operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag;
|
||||
operator_args.batch_stride_C = arguments->batch_stride_C_real;
|
||||
operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag;
|
||||
operator_args.batch_stride_D = arguments->batch_stride_D_real;
|
||||
operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Returns success if the operation can proceed
|
||||
virtual Status can_implement(
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr) const {
|
||||
|
||||
GemmPlanarComplexConfiguration const *configuration =
|
||||
static_cast<GemmPlanarComplexConfiguration const *>(configuration_ptr);
|
||||
|
||||
GemmPlanarComplexArguments const *arguments =
|
||||
static_cast<GemmPlanarComplexArguments const *>(arguments_ptr);
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(args, configuration);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = update_arguments_(args, arguments);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return Operator::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
virtual uint64_t get_host_workspace_size(
|
||||
void const *configuration) const {
|
||||
|
||||
return sizeof(Operator);
|
||||
}
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(
|
||||
args,
|
||||
static_cast<GemmPlanarComplexConfiguration const *>(configuration_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
virtual Status initialize(
|
||||
void const *configuration_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(
|
||||
args,
|
||||
static_cast<GemmPlanarComplexConfiguration const *>(configuration_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = new (host_workspace) Operator;
|
||||
|
||||
status = op->initialize(args, device_workspace, stream);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Runs the kernel
|
||||
virtual Status run(
|
||||
void const *arguments_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = update_arguments_(
|
||||
args,
|
||||
static_cast<GemmPlanarComplexArguments const *>(arguments_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = static_cast<Operator *>(host_workspace);
|
||||
|
||||
status = op->update(args, device_workspace);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = op->run(stream);
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Operator_>
|
||||
class GemmPlanarComplexArrayOperation : public GemmOperationBase<Operator_> {
|
||||
public:
|
||||
|
||||
using Operator = Operator_;
|
||||
using ElementA = typename Operator::ElementA;
|
||||
using LayoutA = typename Operator::LayoutA;
|
||||
using ElementB = typename Operator::ElementB;
|
||||
using LayoutB = typename Operator::LayoutB;
|
||||
using ElementC = typename Operator::ElementC;
|
||||
using LayoutC = typename Operator::LayoutC;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase<Operator_>(name) {
|
||||
|
||||
this->description_.gemm_kind = GemmKind::kPlanarComplexArray;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status construct_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmPlanarComplexArrayConfiguration const *configuration) {
|
||||
|
||||
operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray;
|
||||
operator_args.problem_size = configuration->problem_size;
|
||||
operator_args.batch_count = configuration->batch_count;
|
||||
|
||||
operator_args.lda_real = int(configuration->lda_real);
|
||||
operator_args.lda_imag = int(configuration->lda_imag);
|
||||
operator_args.ldb_real = int(configuration->ldb_real);
|
||||
operator_args.ldb_imag = int(configuration->ldb_imag);
|
||||
operator_args.ldc_real = int(configuration->ldc_real);
|
||||
operator_args.ldc_imag = int(configuration->ldc_imag);
|
||||
operator_args.ldd_real = int(configuration->ldd_real);
|
||||
operator_args.ldd_imag = int(configuration->ldd_imag);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmPlanarComplexArrayArguments const *arguments) {
|
||||
|
||||
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
|
||||
typename Operator::EpilogueOutputOp::Params params(
|
||||
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
|
||||
*static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
|
||||
);
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else if (arguments->pointer_mode == ScalarPointerMode::kDevice){
|
||||
typename Operator::EpilogueOutputOp::Params params(
|
||||
static_cast<cutlass::complex<ElementCompute> const *>(arguments->alpha),
|
||||
static_cast<cutlass::complex<ElementCompute> const *>(arguments->beta)
|
||||
);
|
||||
operator_args.epilogue = params;
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
// update arguments
|
||||
operator_args.ptr_A_real = arguments->A_real;
|
||||
operator_args.ptr_A_imag = arguments->A_imag;
|
||||
operator_args.ptr_B_real = arguments->B_real;
|
||||
operator_args.ptr_B_imag = arguments->B_imag;
|
||||
operator_args.ptr_C_real = arguments->C_real;
|
||||
operator_args.ptr_C_imag = arguments->C_imag;
|
||||
operator_args.ptr_D_real = arguments->D_real;
|
||||
operator_args.ptr_D_imag = arguments->D_imag;
|
||||
|
||||
operator_args.ptr_M = arguments->M;
|
||||
operator_args.ptr_N = arguments->N;
|
||||
operator_args.ptr_K = arguments->K;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Returns success if the operation can proceed
|
||||
virtual Status can_implement(
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr) const {
|
||||
|
||||
GemmPlanarComplexArrayConfiguration const *configuration =
|
||||
static_cast<GemmPlanarComplexArrayConfiguration const *>(configuration_ptr);
|
||||
|
||||
GemmPlanarComplexArrayArguments const *arguments =
|
||||
static_cast<GemmPlanarComplexArrayArguments const *>(arguments_ptr);
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(args, configuration);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = update_arguments_(args, arguments);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return Operator::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
virtual uint64_t get_host_workspace_size(
|
||||
void const *configuration) const {
|
||||
|
||||
return sizeof(Operator);
|
||||
}
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(
|
||||
args,
|
||||
static_cast<GemmPlanarComplexArrayConfiguration const *>(configuration_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
virtual Status initialize(
|
||||
void const *configuration_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = construct_arguments_(
|
||||
args,
|
||||
static_cast<GemmPlanarComplexArrayConfiguration const *>(configuration_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = new (host_workspace) Operator;
|
||||
|
||||
status = op->initialize(args, device_workspace, stream);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Runs the kernel
|
||||
virtual Status run(
|
||||
void const *arguments_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
Status status = update_arguments_(
|
||||
args,
|
||||
static_cast<GemmPlanarComplexArrayArguments const *>(arguments_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = static_cast<Operator *>(host_workspace);
|
||||
|
||||
status = op->update(args, device_workspace);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = op->run(stream);
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
|
||||
845
tools/library/src/handle.cu
Normal file
845
tools/library/src/handle.cu
Normal file
@ -0,0 +1,845 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief CUTLASS Library handle.
|
||||
*/
|
||||
|
||||
#include <stdexcept>
|
||||
#include <cstdint>
|
||||
|
||||
#include "cutlass/library/handle.h"
|
||||
#include "cutlass/library/singleton.h"
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constructor
|
||||
Handle::Handle(
|
||||
cudaStream_t stream,
|
||||
size_t workspace_size
|
||||
):
|
||||
stream_(stream),
|
||||
workspace_(nullptr),
|
||||
workspace_size_(0),
|
||||
scalar_pointer_mode_(ScalarPointerMode::kHost),
|
||||
last_operation_(nullptr) {
|
||||
|
||||
int device_idx = -1;
|
||||
|
||||
cudaError_t error = cudaGetDevice(&device_idx);
|
||||
if (error != cudaSuccess) {
|
||||
throw std::runtime_error("cudaGetDevice() failed");
|
||||
}
|
||||
|
||||
error = cudaGetDeviceProperties(&device_, device_idx);
|
||||
if (error != cudaSuccess) {
|
||||
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
||||
}
|
||||
|
||||
set_workspace_size(workspace_size);
|
||||
|
||||
Singleton::get();
|
||||
}
|
||||
|
||||
/// Destructor
|
||||
Handle::~Handle() {
|
||||
if (workspace_) {
|
||||
|
||||
if (workspace_) {
|
||||
cudaFree(workspace_);
|
||||
}
|
||||
|
||||
workspace_ = nullptr;
|
||||
workspace_size_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Move constructor
|
||||
Handle::Handle(Handle && handle) {
|
||||
device_ = handle.device_;
|
||||
workspace_size_ = handle.workspace_size_;
|
||||
workspace_ = handle.workspace_;
|
||||
stream_ = handle.stream_;
|
||||
scalar_pointer_mode_ = handle.scalar_pointer_mode_;
|
||||
|
||||
handle.workspace_ = nullptr;
|
||||
handle.workspace_size_ = 0;
|
||||
}
|
||||
|
||||
/// Move assignment operator
|
||||
Handle & Handle::operator=(Handle && handle) {
|
||||
|
||||
device_ = handle.device_;
|
||||
workspace_size_ = handle.workspace_size_;
|
||||
workspace_ = handle.workspace_;
|
||||
stream_ = handle.stream_;
|
||||
scalar_pointer_mode_ = handle.scalar_pointer_mode_;
|
||||
|
||||
handle.workspace_ = nullptr;
|
||||
handle.workspace_size_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
int Handle::compute_capability() const {
|
||||
return device_.major * 10 + device_.minor;
|
||||
}
|
||||
|
||||
/// Sets the current CUDA stream
|
||||
void Handle::set_stream(cudaStream_t stream) {
|
||||
stream_ = stream;
|
||||
}
|
||||
|
||||
/// Gets the current CUDA stream
|
||||
cudaStream_t Handle::get_stream() const {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
/// Gets the device workspace size
|
||||
size_t Handle::get_workspace_size() const {
|
||||
return workspace_size_;
|
||||
}
|
||||
|
||||
/// Gets a pointer to the device workspace allocation in Global Memory
|
||||
void *Handle::get_workspace() const {
|
||||
return workspace_;
|
||||
}
|
||||
|
||||
/// Sets the size of device workspace, invalidating previous calls to get_device_workspace()
|
||||
void Handle::set_workspace_size(size_t bytes) {
|
||||
if (bytes != workspace_size_) {
|
||||
|
||||
if (workspace_) {
|
||||
cudaFree(workspace_);
|
||||
}
|
||||
|
||||
workspace_ = nullptr;
|
||||
workspace_size_ = bytes;
|
||||
|
||||
if (workspace_size_) {
|
||||
|
||||
cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_);
|
||||
|
||||
if (error != cudaSuccess) {
|
||||
throw std::runtime_error("Failed to allocate workspace");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (workspace_) {
|
||||
cudaError_t error = cudaMemset(workspace_, 0, workspace_size_);
|
||||
|
||||
if (error != cudaSuccess) {
|
||||
throw std::runtime_error("Failed to clear workspace");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the scalar pointer mode
|
||||
ScalarPointerMode Handle::get_scalar_pointer_mode() const {
|
||||
return scalar_pointer_mode_;
|
||||
}
|
||||
|
||||
/// Sets the scalar pointer mode
|
||||
void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) {
|
||||
scalar_pointer_mode_ = mode;
|
||||
}
|
||||
|
||||
/// Gets the last operation
|
||||
Operation const *Handle::get_last_operation() const {
|
||||
return last_operation_;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Returns the maximum required alignment for each operator
|
||||
static int maximum_alignment_requirement(GemmDescription const &desc) {
|
||||
return std::max(
|
||||
std::max(desc.A.alignment, desc.B.alignment), desc.C.alignment);
|
||||
}
|
||||
|
||||
/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a
|
||||
/// given upper limit.
|
||||
static int gemm_problem_alignment(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
NumericTypeID element_A,
|
||||
void const *ptr_A,
|
||||
int lda,
|
||||
int64_t batch_stride_A,
|
||||
NumericTypeID element_B,
|
||||
void const *ptr_B,
|
||||
int ldb,
|
||||
int64_t batch_stride_B,
|
||||
NumericTypeID element_C,
|
||||
void const * ptr_C,
|
||||
int ldc,
|
||||
int64_t batch_stride_C,
|
||||
void const * ptr_D,
|
||||
int ldd,
|
||||
int64_t batch_stride_D,
|
||||
int max_alignment_in_bytes = 16
|
||||
) {
|
||||
|
||||
void const *pointers[] = {
|
||||
ptr_A, ptr_B, ptr_C, ptr_D
|
||||
};
|
||||
|
||||
int64_t extents[] = {
|
||||
M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D
|
||||
};
|
||||
|
||||
NumericTypeID elements[] = {
|
||||
element_A, element_B, element_C
|
||||
};
|
||||
|
||||
for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) {
|
||||
|
||||
bool satisfied = true;
|
||||
|
||||
// Can pointers satisfy this?
|
||||
for (void const *ptr : pointers) {
|
||||
std::uintptr_t int_ptr = reinterpret_cast<std::uintptr_t>(ptr);
|
||||
|
||||
if (int_ptr % max_alignment_in_bytes) {
|
||||
satisfied = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!satisfied) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute the maximum alignment based on element data types
|
||||
int max_element_alignment = 0;
|
||||
|
||||
for (NumericTypeID type_id : elements) {
|
||||
int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id);
|
||||
max_element_alignment = std::max(max_element_alignment, element_alignment);
|
||||
}
|
||||
|
||||
// Can the problem size and leading dimensions satisfy this?
|
||||
for (int64_t extent : extents) {
|
||||
if (extent % max_element_alignment) {
|
||||
satisfied = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!satisfied) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Yes
|
||||
return max_element_alignment;
|
||||
}
|
||||
|
||||
// No alignment satisfies this problem
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Find the best kernel in descending order of preference.
|
||||
static Operation const * find_gemm_operation(
|
||||
GemmOperationFunctionalMap::const_iterator operators_it,
|
||||
GemmPreferenceKey const preference_key) {
|
||||
|
||||
auto cc_it = operators_it->second.upper_bound(preference_key);
|
||||
|
||||
if (cc_it == operators_it->second.begin()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Operation const *operation = nullptr;
|
||||
|
||||
// Search in descending order of compute capability
|
||||
do {
|
||||
--cc_it;
|
||||
|
||||
// Search tile sizes in order, for now.
|
||||
for (auto const * op : cc_it->second) {
|
||||
|
||||
GemmDescription const &desc = static_cast<GemmDescription const &>(op->description());
|
||||
|
||||
int min_cc = desc.tile_description.minimum_compute_capability;
|
||||
int max_cc = desc.tile_description.maximum_compute_capability;
|
||||
|
||||
int op_alignment = maximum_alignment_requirement(desc);
|
||||
|
||||
if ((min_cc <= preference_key.compute_capability) &&
|
||||
(preference_key.compute_capability <= max_cc) &&
|
||||
(op_alignment <= preference_key.alignment)) {
|
||||
|
||||
operation = op;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} while (!operation && cc_it != operators_it->second.begin());
|
||||
|
||||
return operation;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Executes a GEMM computation: D <= alpha * A*B + beta * C
|
||||
Status Handle::gemm(
|
||||
|
||||
int M, /// GEMM M dimension
|
||||
int N, /// GEMM N dimension
|
||||
int K, /// GEMM K dimension
|
||||
|
||||
NumericTypeID element_compute, /// Data type of internal accumulation
|
||||
|
||||
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
||||
|
||||
void const *alpha, /// Pointer to alpha scalar
|
||||
|
||||
NumericTypeID element_A, /// Data type of A matrix elements
|
||||
LayoutTypeID layout_A, /// Layout of A matrix
|
||||
ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices
|
||||
|
||||
void const * ptr_A, /// Pointer to A matrix in Global Memory
|
||||
int lda, /// Leading dimension of A matrix
|
||||
|
||||
NumericTypeID element_B, /// Data type of B matrix elements
|
||||
LayoutTypeID layout_B, /// Layout of B matrix
|
||||
ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices
|
||||
|
||||
void const * ptr_B, /// Pointer to B matrix in Global Memory
|
||||
int ldb, /// Leading dimension of B matrix
|
||||
|
||||
void const * beta, /// Pointer to beta scalar
|
||||
|
||||
NumericTypeID element_C, /// Data type of C and D matrices
|
||||
|
||||
void const * ptr_C, /// Pointer to C matrix
|
||||
int ldc, /// Leading dimension of C matrix
|
||||
|
||||
void * ptr_D, /// Pointer to D matrix
|
||||
int ldd /// Leading dimension of D matrix
|
||||
) {
|
||||
|
||||
//
|
||||
// Find the operation
|
||||
//
|
||||
|
||||
GemmFunctionalKey key(
|
||||
element_compute,
|
||||
element_scalar,
|
||||
element_A,
|
||||
layout_A,
|
||||
transform_A,
|
||||
element_B,
|
||||
layout_B,
|
||||
transform_B,
|
||||
element_C
|
||||
);
|
||||
|
||||
auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
|
||||
|
||||
if (operators_it == Singleton::get().operation_table.gemm_operations.end()) {
|
||||
return cutlass::Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
if (operators_it->second.empty()) {
|
||||
return cutlass::Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
//
|
||||
// Compute the largest alignment restriction the kernel can satisfy.
|
||||
//
|
||||
|
||||
// Maximum alignment expectation among all kernels (in units of bytes)
|
||||
int const kMaximumAlignmentSize = 16;
|
||||
|
||||
int alignment = gemm_problem_alignment(
|
||||
M, N, K,
|
||||
element_A, ptr_A, lda, 0,
|
||||
element_B, ptr_B, ldb, 0,
|
||||
element_C, ptr_C, ldc, 0,
|
||||
ptr_D, ldd, 0, kMaximumAlignmentSize
|
||||
);
|
||||
|
||||
//
|
||||
// Find the best kernel in descending order of preference.
|
||||
//
|
||||
|
||||
GemmPreferenceKey preference_key(compute_capability(), alignment);
|
||||
|
||||
Operation const *operation = find_gemm_operation(operators_it, preference_key);
|
||||
|
||||
if (!operation) {
|
||||
return cutlass::Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
last_operation_ = operation;
|
||||
|
||||
//
|
||||
// Configure operation
|
||||
//
|
||||
|
||||
GemmConfiguration configuration{
|
||||
{M, N, K},
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
1
|
||||
};
|
||||
|
||||
// Query host work space size
|
||||
uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
|
||||
|
||||
if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
|
||||
return cutlass::Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
char host_workspace[kHostWorkspaceSize];
|
||||
|
||||
// Query device workspace size
|
||||
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
|
||||
|
||||
if (uint64_t(workspace_size_) < device_workspace_size_needed) {
|
||||
return cutlass::Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
// Initialize host and device workspaces
|
||||
Status status = operation->initialize(
|
||||
&configuration,
|
||||
host_workspace,
|
||||
workspace_,
|
||||
stream_);
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Run the operator
|
||||
GemmArguments arguments{
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
alpha,
|
||||
beta,
|
||||
scalar_pointer_mode_
|
||||
};
|
||||
|
||||
return operation->run(&arguments, host_workspace, workspace_, stream_);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Planar complex GEMM
|
||||
/// Planar complex GEMM
///
/// Executes a batched planar-complex GEMM in which the real and imaginary parts
/// of each operand are stored in separate planes. The routine locates a kernel
/// in the operation table that matches the problem's functional key, validates
/// alignment and workspace requirements, then initializes and launches the
/// operation on the handle's stream.
///
/// Returns kErrorNotSupported if no suitable kernel exists or a workspace is
/// too small; otherwise propagates the status of the selected operation.
Status Handle::gemm_planar_complex(

  int M,                                    /// GEMM M dimension
  int N,                                    /// GEMM N dimension
  int K,                                    /// GEMM K dimension

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix

  void const * ptr_A_real,                  /// Pointer to real part of A matrix
  void const * ptr_A_imag,                  /// Pointer to imaginary part of A matrix
  int lda_real,                             /// Leading dimension of real part of A matrix
  int lda_imag,                             /// Leading dimension of imaginary part of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix

  void const * ptr_B_real,                  /// Pointer to real part of B matrix
  void const * ptr_B_imag,                  /// Pointer to imaginary part of B matrix
  int ldb_real,                             /// Leading dimension of real part of B matrix
  int ldb_imag,                             /// Leading dimension of imaginary part of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrix

  void const * ptr_C_real,                  /// Pointer to real part of C matrix
  void const * ptr_C_imag,                  /// Pointer to imaginary part of C matrix
  int ldc_real,                             /// Leading dimension of real part of C matrix
  int ldc_imag,                             /// Leading dimension of imaginary part of C matrix

  void * ptr_D_real,                        /// Pointer to real part of D matrix
  void * ptr_D_imag,                        /// Pointer to imaginary part of D matrix
  int ldd_real,                             /// Leading dimension of real part of D matrix
  int ldd_imag,                             /// Leading dimension of imaginary part of D matrix

  int batch_count,                          /// Number of batched GEMMs to execute

  // Per-batch strides for each operand plane — presumably in elements; confirm
  // against the operation's contract.
  int64_t batch_stride_A_real,
  int64_t batch_stride_A_imag,

  int64_t batch_stride_B_real,
  int64_t batch_stride_B_imag,

  int64_t batch_stride_C_real,
  int64_t batch_stride_C_imag,

  int64_t batch_stride_D_real,
  int64_t batch_stride_D_imag
) {

  //
  // Find the operation
  //

  // Functional key identifies the family of kernels matching this problem's
  // data types, layouts, and complex transforms.
  GemmFunctionalKey key(
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_planar_complex_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_planar_complex_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  // Alignment is evaluated separately for the real and imaginary planes.
  // NOTE(review): the two results are combined with std::max — confirm that
  // kernel selection tolerates the less-aligned plane.
  int alignment = std::max(
    gemm_problem_alignment(
      M, N, K,
      element_A, ptr_A_real, lda_real, batch_stride_A_real,
      element_B, ptr_B_real, ldb_real, batch_stride_B_real,
      element_C, ptr_C_real, ldc_real, batch_stride_C_real,
      ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize
    ),
    gemm_problem_alignment(
      M, N, K,
      element_A, ptr_A_imag, lda_imag, batch_stride_A_imag,
      element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag,
      element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag,
      ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize
    )
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Remember the selected operation so callers can inspect what was launched.
  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmPlanarComplexConfiguration configuration{
    GemmUniversalMode::kBatched,
    {M, N, K},
    batch_count,
    lda_real,
    lda_imag,
    ldb_real,
    ldb_imag,
    ldc_real,
    ldc_imag,
    ldd_real,
    ldd_imag
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Fixed-size stack buffer; the check above guarantees it is large enough.
  char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmPlanarComplexArguments arguments{
    ptr_A_real,
    ptr_A_imag,
    ptr_B_real,
    ptr_B_imag,
    ptr_C_real,
    ptr_C_imag,
    ptr_D_real,
    ptr_D_imag,
    alpha,
    beta,
    scalar_pointer_mode_,
    batch_stride_A_real,
    batch_stride_A_imag,
    batch_stride_B_real,
    batch_stride_B_imag,
    batch_stride_C_real,
    batch_stride_C_imag,
    batch_stride_D_real,
    batch_stride_D_imag
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Planar complex batched GEMM loading pointers from arrays in global memory
|
||||
/// Planar complex batched GEMM loading pointers from arrays in global memory
///
/// Unlike gemm_planar_complex, each batch's problem size and operand pointers
/// are read from device-resident arrays; only the expected (maximum) problem
/// extents are known on the host and are used for kernel selection and grid
/// sizing. Returns kErrorNotSupported if no suitable kernel exists or a
/// workspace is too small; otherwise propagates the operation's status.
Status Handle::gemm_planar_complex_array(

  int expected_M,                           /// Expected GEMM M dimension (used for sizing CUDA grid)
  int expected_N,                           /// Expected GEMM N dimension (used for sizing CUDA grid)
  int expected_K,                           /// Expected GEMM K dimension
  int batch_count,                          /// Number of independent GEMM computations to execute

  int const *M,                             /// Array containing the GEMM M dimension for each batch index
  int const *N,                             /// Array containing the GEMM N dimension for each batch index
  int const *K,                             /// Array containing the GEMM K dimension for each batch index

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix

  void const * const * ptr_A_real,          /// Pointer to array containing pointers to real part of A matrices
  void const * const * ptr_A_imag,          /// Pointer to array containing pointers to imaginary part of A matrices

  int lda_real,                             /// Leading dimension of real part of A matrix
  int lda_imag,                             /// Leading dimension of imaginary part of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix

  void const * const * ptr_B_real,          /// Pointer to array containing pointers to real part of B matrices
  void const * const * ptr_B_imag,          /// Pointer to array containing pointers to imaginary part of B matrices

  int ldb_real,                             /// Leading dimension of real part of B matrix
  int ldb_imag,                             /// Leading dimension of imaginary part of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrix

  void const * const * ptr_C_real,          /// Pointer to array containing pointers to real part of C matrices
  void const * const * ptr_C_imag,          /// Pointer to array containing pointers to imaginary part of C matrices

  int ldc_real,                             /// Leading dimension of real part of C matrix
  int ldc_imag,                             /// Leading dimension of imaginary part of C matrix

  void * const * ptr_D_real,                /// Pointer to array containing pointers to real part of D matrices
  void * const * ptr_D_imag,                /// Pointer to array containing pointers to imaginary part of D matrices

  int ldd_real,                             /// Leading dimension of real part of D matrix
  int ldd_imag                              /// Leading dimension of imaginary part of D matrix
) {

  //
  // Find the operation
  //

  // Functional key identifies the family of kernels matching this problem's
  // data types, layouts, and complex transforms.
  GemmFunctionalKey key(
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_planar_complex_array_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_planar_complex_array_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  // The per-batch pointers live in device memory and are unknown on the host,
  // so null pointers and zero batch strides are passed here — presumably so
  // alignment is derived from problem extents and leading dimensions only.
  int alignment = std::max(
    gemm_problem_alignment(
      expected_M, expected_N, expected_K,
      element_A, nullptr, lda_real, 0,
      element_B, nullptr, ldb_real, 0,
      element_C, nullptr, ldc_real, 0,
      nullptr, ldd_real, 0, kMaximumAlignmentSize
    ),
    gemm_problem_alignment(
      expected_M, expected_N, expected_K,
      element_A, nullptr, lda_imag, 0,
      element_B, nullptr, ldb_imag, 0,
      element_C, nullptr, ldc_imag, 0,
      nullptr, ldd_imag, 0, kMaximumAlignmentSize
    )
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Remember the selected operation so callers can inspect what was launched.
  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmPlanarComplexArrayConfiguration configuration{
    {expected_M, expected_N, expected_K},
    batch_count,
    lda_real,
    lda_imag,
    ldb_real,
    ldb_imag,
    ldc_real,
    ldc_imag,
    ldd_real,
    ldd_imag
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Fixed-size stack buffer; the check above guarantees it is large enough.
  char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmPlanarComplexArrayArguments arguments{
    M, N, K,
    ptr_A_real,
    ptr_A_imag,
    ptr_B_real,
    ptr_B_imag,
    ptr_C_real,
    ptr_C_imag,
    ptr_D_real,
    ptr_D_imag,
    alpha,
    beta,
    scalar_pointer_mode_
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -57,6 +57,10 @@ namespace library {
|
||||
|
||||
// Compile-time trait mapping a CUTLASS numeric type to its library
// NumericTypeID enumerant. The primary template is declared only, so an
// unmapped type is a compile error.
template <typename T> struct NumericTypeMap;

// 1-bit value
template <> struct NumericTypeMap<cutlass::uint1b_t> {
  static NumericTypeID const kId = NumericTypeID::kB1;
};

// 4-bit signed integer
template <> struct NumericTypeMap<cutlass::int4b_t> {
  static NumericTypeID const kId = NumericTypeID::kS4;
};
|
||||
@ -123,6 +127,28 @@ template <> struct NumericTypeMap<cutlass::complex<double> > {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Compile-time trait mapping a CUTLASS arch-level math-operation tag to its
// library MathOperationID enumerant. Unlike NumericTypeMap, the primary
// template is defined and reports kInvalid for unmapped tags instead of
// failing to compile.
template <typename T> struct MathOperationMap {
  static MathOperationID const kId = MathOperationID::kInvalid;
};

template <> struct MathOperationMap<cutlass::arch::OpMultiplyAdd> {
  static MathOperationID const kId = MathOperationID::kMultiplyAdd;
};

template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddSaturate> {
  static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate;
};

template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddComplex> {
  static MathOperationID const kId = MathOperationID::kMultiplyAddComplex;
};

template <> struct MathOperationMap<cutlass::arch::OpXorPopc> {
  static MathOperationID const kId = MathOperationID::kXorPopc;
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T> struct LayoutMap;
|
||||
|
||||
template <> struct LayoutMap<cutlass::layout::ColumnMajor> {
|
||||
@ -133,6 +159,34 @@ template <> struct LayoutMap<cutlass::layout::RowMajor> {
|
||||
static LayoutTypeID const kId = LayoutTypeID::kRowMajor;
|
||||
};
|
||||
|
||||
// LayoutMap specializations for interleaved layouts (interleaving factors
// 16, 32, and 64) and the NHWC tensor layout.
template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<16>> {
  static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16;
};

template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<16>> {
  static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16;
};

template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<32>> {
  static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32;
};

template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<32>> {
  static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32;
};

template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<64>> {
  static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64;
};

template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<64>> {
  static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64;
};

template <> struct LayoutMap<cutlass::layout::TensorNHWC> {
  static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC;
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T> struct OpcodeClassMap;
|
||||
@ -148,6 +202,19 @@ template <> struct OpcodeClassMap<arch::OpClassTensorOp> {
|
||||
template <> struct OpcodeClassMap<arch::OpClassWmmaTensorOp> {
|
||||
static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Compile-time trait mapping a cutlass::ComplexTransform enumerant to the
// corresponding cutlass::library::ComplexTransform enumerant. Declared only,
// so an unmapped transform is a compile error.
template <cutlass::ComplexTransform Transform> struct ComplexTransformMap;

template <> struct ComplexTransformMap<cutlass::ComplexTransform::kNone> {
  static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone;
};

template <> struct ComplexTransformMap<cutlass::ComplexTransform::kConjugate> {
  static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate;
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T> struct ArchMap;
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
/*!
|
||||
|
||||
*//***************************************************************************************************
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
@ -37,11 +35,12 @@
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_all(Manifest &manifest);
|
||||
// init and insert all cutlass op in manifest object (procedurally generated using generator.py)
|
||||
void initialize_all(Manifest &manifest);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Top-level initialization
|
||||
Status Manifest::initialize() {
|
||||
@ -50,7 +49,13 @@ Status Manifest::initialize() {
|
||||
operations_.clear();
|
||||
}
|
||||
|
||||
initialize_all(*this);
|
||||
switch(provider_) {
|
||||
case Provider::kCUTLASS:
|
||||
initialize_all(*this); break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
159
tools/library/src/operation_table.cu
Normal file
159
tools/library/src/operation_table.cu
Normal file
@ -0,0 +1,159 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
\file
|
||||
\brief Defines a data structure in which a set of functionally equivalent library::Operation
|
||||
instances may be queried.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/operation_table.h"
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Streams a human-readable rendering of a GemmFunctionalKey, one field per
/// line, for diagnostic output. Returns the stream to allow chaining.
std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) {

  out << "{\n"
    << "  element_compute: " << to_string(k.element_compute) << "\n"
    << "   element_scalar: " << to_string(k.element_scalar) << "\n"
    << "        element_A: " << to_string(k.element_A) << "\n"
    << "         layout_A: " << to_string(k.layout_A) << "\n"
    << "      transform_A: " << to_string(k.transform_A) << "\n"
    << "        element_B: " << to_string(k.element_B) << "\n"
    << "         layout_B: " << to_string(k.layout_B) << "\n"
    << "      transform_B: " << to_string(k.transform_B) << "\n"
    << "        element_C: " << to_string(k.element_C) << "\n"
    << "}";

  return out;
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Inserts all operations from the manifest into the operation table.
///
/// Each GEMM operation is indexed first by a functional key (accumulator and
/// epilogue types, operand element types, layouts, and complex transforms) and
/// then by a preference key (minimum compute capability and strictest operand
/// alignment). The original implementation repeated the identical key
/// construction in each GemmKind branch; here the common logic is shared and
/// only the destination map differs.
void OperationTable::append(Manifest const &manifest) {

  // Insert operations into appropriate data structure
  for (auto const & operation : manifest) {

    OperationDescription const &desc = operation->description();

    if (desc.kind != OperationKind::kGemm) {
      continue;
    }

    GemmDescription const &gemm_desc = static_cast<GemmDescription const &>(desc);

    // Select the map corresponding to this GEMM variant. All three maps share
    // the same key structure, so the remaining logic is common.
    decltype(&gemm_operations) table = nullptr;

    if (gemm_desc.gemm_kind == GemmKind::kGemm) {
      table = &gemm_operations;
    }
    else if (gemm_desc.gemm_kind == GemmKind::kPlanarComplex) {
      table = &gemm_planar_complex_operations;
    }
    else if (gemm_desc.gemm_kind == GemmKind::kPlanarComplexArray) {
      table = &gemm_planar_complex_array_operations;
    }

    if (!table) {
      // GEMM kind not indexed by this table.
      continue;
    }

    GemmFunctionalKey functional_key(
      gemm_desc.tile_description.math_instruction.element_accumulator,
      gemm_desc.element_epilogue,
      gemm_desc.A.element,
      gemm_desc.A.layout,
      gemm_desc.transform_A,
      gemm_desc.B.element,
      gemm_desc.B.layout,
      gemm_desc.transform_B,
      gemm_desc.C.element
    );

    Operation const *op = operation.get();

    int cc = gemm_desc.tile_description.minimum_compute_capability;

    // Strictest alignment requirement among the three operands.
    int alignment = std::max(std::max(
      gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);

    GemmPreferenceKey preference_key(cc, alignment);

    (*table)[functional_key][preference_key].push_back(op);
  }
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
63
tools/library/src/singleton.cu
Normal file
63
tools/library/src/singleton.cu
Normal file
@ -0,0 +1,63 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <memory>
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
#include "cutlass/library/operation_table.h"
|
||||
#include "cutlass/library/singleton.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static std::unique_ptr<Singleton> instance;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constructs the Singleton: initializes the manifest (registering the
/// procedurally generated CUTLASS operations) and builds the operation table
/// index over its contents.
Singleton::Singleton() {

  manifest.initialize();

  operation_table.append(manifest);
}
|
||||
|
||||
/// Returns the process-wide Singleton, constructing it lazily on first use.
///
/// Uses a function-local static ("Meyers singleton") whose initialization is
/// thread-safe under C++11 magic statics. The previous implementation lazily
/// reset a file-scope unique_ptr without synchronization, which raced (and
/// could double-construct) when two threads called get() concurrently.
Singleton const & Singleton::get() {
  static Singleton singleton;
  return singleton;
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -25,17 +25,65 @@
|
||||
|
||||
#include <iosfwd>
|
||||
#include <complex>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/complex.h"
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Lookup table pairing each Provider enumerant with its short command-line
// spelling ("text") and its display spelling ("pretty"); consumed by the
// to_string / from_string conversions below.
static struct {
  char const *text;
  char const *pretty;
  Provider enumerant;
}
Provider_enumerants[] = {
  {"cutlass", "CUTLASS", Provider::kCUTLASS},
  {"host", "reference_host", Provider::kReferenceHost},
  {"device", "reference_device", Provider::kReferenceDevice},
  {"cublas", "cuBLAS", Provider::kCUBLAS},
};
|
||||
|
||||
/// Converts a Provider enumerant to a string
|
||||
char const *to_string(Provider provider, bool pretty) {
|
||||
|
||||
for (auto const & possible : Provider_enumerants) {
|
||||
if (provider == possible.enumerant) {
|
||||
if (pretty) {
|
||||
return possible.pretty;
|
||||
}
|
||||
else {
|
||||
return possible.text;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pretty ? "Invalid" : "invalid";
|
||||
}
|
||||
|
||||
/// Parses a Provider enumerant from a string
|
||||
template <>
|
||||
Provider from_string<Provider>(std::string const &str) {
|
||||
|
||||
for (auto const & possible : Provider_enumerants) {
|
||||
if ((str.compare(possible.text) == 0) ||
|
||||
(str.compare(possible.pretty) == 0)) {
|
||||
return possible.enumerant;
|
||||
}
|
||||
}
|
||||
|
||||
return Provider::kInvalid;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static struct {
|
||||
@ -44,7 +92,7 @@ static struct {
|
||||
OperationKind enumerant;
|
||||
}
|
||||
OperationKind_enumerants[] = {
|
||||
{"gemm", "Gemm", OperationKind::kGemm},
|
||||
{"gemm", "Gemm", OperationKind::kGemm},
|
||||
};
|
||||
|
||||
/// Converts a Status enumerant to a string
|
||||
@ -203,6 +251,9 @@ int sizeof_bits(NumericTypeID type) {
|
||||
case NumericTypeID::kF16: return 16;
|
||||
case NumericTypeID::kF32: return 32;
|
||||
case NumericTypeID::kF64: return 64;
|
||||
case NumericTypeID::kCF16: return 32;
|
||||
case NumericTypeID::kCF32: return 64;
|
||||
case NumericTypeID::kCF64: return 128;
|
||||
case NumericTypeID::kS4: return 4;
|
||||
case NumericTypeID::kS8: return 8;
|
||||
case NumericTypeID::kS16: return 16;
|
||||
@ -291,6 +342,9 @@ bool is_float_type(NumericTypeID type) {
|
||||
case NumericTypeID::kF16: return true;
|
||||
case NumericTypeID::kF32: return true;
|
||||
case NumericTypeID::kF64: return true;
|
||||
case NumericTypeID::kCF16: return true;
|
||||
case NumericTypeID::kCF32: return true;
|
||||
case NumericTypeID::kCF64: return true;
|
||||
default: break;
|
||||
}
|
||||
return false;
|
||||
@ -309,8 +363,18 @@ layout_aliases[] = {
|
||||
{LayoutTypeID::kColumnMajor, "column"},
|
||||
{LayoutTypeID::kColumnMajor, "col"},
|
||||
{LayoutTypeID::kColumnMajor, "n"},
|
||||
|
||||
{LayoutTypeID::kColumnMajorInterleavedK16, "nk16"},
|
||||
{LayoutTypeID::kRowMajorInterleavedK16, "tk16"},
|
||||
|
||||
{LayoutTypeID::kColumnMajorInterleavedK32, "nk32"},
|
||||
{LayoutTypeID::kRowMajorInterleavedK32, "tk32"},
|
||||
|
||||
{LayoutTypeID::kColumnMajorInterleavedK64, "nk64"},
|
||||
{LayoutTypeID::kRowMajorInterleavedK64, "tk64"},
|
||||
|
||||
{LayoutTypeID::kTensorNCHW, "nchw"},
|
||||
{LayoutTypeID::kTensorNHWC, "packed_nhwc"},
|
||||
{LayoutTypeID::kTensorNHWC, "nhwc"},
|
||||
{LayoutTypeID::kUnknown, "*"},
|
||||
{LayoutTypeID::kInvalid, nullptr}
|
||||
};
|
||||
@ -344,7 +408,12 @@ int get_layout_stride_rank(LayoutTypeID layout_id) {
|
||||
case LayoutTypeID::kColumnMajorInterleavedK4:
|
||||
case LayoutTypeID::kRowMajorInterleavedK4:
|
||||
case LayoutTypeID::kColumnMajorInterleavedK16:
|
||||
case LayoutTypeID::kRowMajorInterleavedK16: return 1;
|
||||
case LayoutTypeID::kRowMajorInterleavedK16:
|
||||
case LayoutTypeID::kColumnMajorInterleavedK32:
|
||||
case LayoutTypeID::kRowMajorInterleavedK32:
|
||||
case LayoutTypeID::kColumnMajorInterleavedK64:
|
||||
case LayoutTypeID::kRowMajorInterleavedK64:
|
||||
return 1;
|
||||
case LayoutTypeID::kTensorNCHW:
|
||||
case LayoutTypeID::kTensorNHWC: return 3;
|
||||
default : throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank");
|
||||
@ -396,8 +465,51 @@ OpcodeClassID from_string<OpcodeClassID>(std::string const &str) {
|
||||
return OpcodeClassID::kInvalid;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static struct {
|
||||
char const *text;
|
||||
char const *pretty;
|
||||
ComplexTransform enumerant;
|
||||
}
|
||||
ComplexTransform_enumerants[] = {
|
||||
{"n", "none", ComplexTransform::kNone},
|
||||
{"c", "conj", ComplexTransform::kConjugate}
|
||||
};
|
||||
|
||||
/// Converts a ComplexTransform enumerant to a string
|
||||
char const *to_string(ComplexTransform type, bool pretty) {
|
||||
|
||||
for (auto const & possible : ComplexTransform_enumerants) {
|
||||
if (type == possible.enumerant) {
|
||||
if (pretty) {
|
||||
return possible.pretty;
|
||||
}
|
||||
else {
|
||||
return possible.text;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pretty ? "Invalid" : "invalid";
|
||||
}
|
||||
|
||||
/// Converts a ComplexTransform enumerant from a string
|
||||
template <>
|
||||
ComplexTransform from_string<ComplexTransform>(std::string const &str) {
|
||||
|
||||
for (auto const & possible : ComplexTransform_enumerants) {
|
||||
if ((str.compare(possible.text) == 0) ||
|
||||
(str.compare(possible.pretty) == 0)) {
|
||||
return possible.enumerant;
|
||||
}
|
||||
}
|
||||
|
||||
return ComplexTransform::kInvalid;
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
|
||||
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str) {
|
||||
int size_bytes = sizeof_bits(type) / 8;
|
||||
@ -574,25 +686,36 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
|
||||
break;
|
||||
case NumericTypeID::kCF16:
|
||||
{
|
||||
std::complex<float> tmp;
|
||||
|
||||
cutlass::complex<half_t> const *x =
|
||||
reinterpret_cast<cutlass::complex<half_t> const *>(bytes.data());
|
||||
|
||||
tmp.real(x->real());
|
||||
tmp.imag(x->imag());
|
||||
ss << float(x->real());
|
||||
|
||||
ss << tmp;
|
||||
if (x->imag() != cutlass::half_t()) {
|
||||
ss << "+i" << float(x->imag());
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kCF32:
|
||||
{
|
||||
ss << *reinterpret_cast<std::complex<float>*>(bytes.data());
|
||||
cutlass::complex<float> const * x = reinterpret_cast<cutlass::complex<float> const *>(bytes.data());
|
||||
|
||||
ss << x->real();
|
||||
|
||||
if (x->imag() != float()) {
|
||||
ss << "+i" << x->imag();
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kCF64:
|
||||
{
|
||||
ss << *reinterpret_cast<std::complex<double>*>(bytes.data());
|
||||
cutlass::complex<double> const * x = reinterpret_cast<cutlass::complex<double> const *>(bytes.data());
|
||||
|
||||
ss << x->real();
|
||||
|
||||
if (x->imag() != double()) {
|
||||
ss << "+i" << x->imag();
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
@ -33,7 +33,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
|
||||
src/gpu_timer.cpp
|
||||
src/device_allocation.cu
|
||||
src/device_context.cu
|
||||
src/cublas_helpers.cpp
|
||||
src/cublas_helpers.cpp
|
||||
src/problem_space.cpp
|
||||
src/operation_profiler.cu
|
||||
src/gemm_operation_profiler.cu
|
||||
@ -54,11 +54,11 @@ set_target_properties(cutlass_profiler PROPERTIES EXPORT_NAME profiler)
|
||||
# Include paths
|
||||
#
|
||||
|
||||
target_include_directories(cutlass_profiler
|
||||
target_include_directories(
|
||||
cutlass_profiler
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src # Source directory
|
||||
../../tools/util/include
|
||||
)
|
||||
)
|
||||
|
||||
#
|
||||
# Library dependencies
|
||||
@ -68,8 +68,8 @@ target_link_libraries(
|
||||
cutlass_profiler
|
||||
PRIVATE
|
||||
cutlass_lib
|
||||
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:cublas>
|
||||
gtest
|
||||
cutlass_tools_util_includes
|
||||
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
|
||||
cudart
|
||||
)
|
||||
|
||||
|
||||
@ -39,14 +39,14 @@ namespace profiler {
|
||||
/// Converts a cuBLAS status to cutlass::Status
|
||||
Status get_cutlass_status(cublasStatus_t cublas) {
|
||||
|
||||
if (cublas == CUBLAS_STATUS_SUCCESS) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else if (cublas == CUBLAS_STATUS_INVALID_VALUE) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
if (cublas == CUBLAS_STATUS_NOT_SUPPORTED) {
|
||||
return Status::kErrorNotSupported;
|
||||
switch (cublas) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return Status::kSuccess;
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return Status::kErrorInvalidProblem;
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return Status::kErrorNotSupported;
|
||||
default: break;
|
||||
}
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
@ -145,6 +145,13 @@ Status cublas_satisfies(library::GemmDescription const &desc) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
// output type S4 and S8 not supported in cuBLAS
|
||||
if (desc.C.element == library::NumericTypeID::kS4 ||
|
||||
desc.C.element == library::NumericTypeID::kS8) {
|
||||
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
#include "options.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@ -86,6 +86,161 @@ public:
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Selects one or more cuBLAS algorithms.
|
||||
static void select_cublas_algorithms(
|
||||
std::vector<cublasGemmAlgo_t> &algorithms,
|
||||
Options const &options,
|
||||
library::GemmDescription const &op_desc) {
|
||||
|
||||
library::OpcodeClassID const & opcode_class =
|
||||
op_desc.tile_description.math_instruction.opcode_class;
|
||||
|
||||
switch (options.library.algorithm_mode) {
|
||||
case AlgorithmMode::kMatching:
|
||||
{
|
||||
algorithms.push_back(get_cublas_gemm_algo(
|
||||
op_desc.tile_description.threadblock_shape.m(),
|
||||
op_desc.tile_description.threadblock_shape.n(),
|
||||
op_desc.tile_description.threadblock_shape.k(),
|
||||
opcode_class));
|
||||
break;
|
||||
}
|
||||
|
||||
case AlgorithmMode::kBest:
|
||||
{
|
||||
// Choose first enumerated mode. If none are enumerated, choose based on opcode class
|
||||
// and evaluate all of them.
|
||||
|
||||
if (options.library.algorithms.empty()) {
|
||||
// Enumerate all algorithms
|
||||
if (opcode_class == library::OpcodeClassID::kSimt) {
|
||||
|
||||
for (int algo = CUBLAS_GEMM_DEFAULT;
|
||||
algo <= CUBLAS_GEMM_ALGO23;
|
||||
++algo) {
|
||||
|
||||
algorithms.push_back(cublasGemmAlgo_t(algo));
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP;
|
||||
++algo) {
|
||||
|
||||
algorithms.push_back(cublasGemmAlgo_t(algo));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Use the listed algorithms
|
||||
algorithms.reserve(options.library.algorithms.size());
|
||||
|
||||
for (int algo : options.library.algorithms) {
|
||||
algorithms.push_back(reinterpret_cast<cublasGemmAlgo_t const &>(algo));
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case AlgorithmMode::kDefault:
|
||||
{
|
||||
|
||||
// Use the library's default algorithm
|
||||
algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ?
|
||||
CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatcher to cublasGemmEx()
|
||||
struct cublasGemmExDispatcher {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
library::GemmConfiguration configuration;
|
||||
library::GemmArguments arguments;
|
||||
|
||||
// cublass-specific data structures to fill cublas API call arguments
|
||||
cublasOperation_t trans_A;
|
||||
cublasOperation_t trans_B;
|
||||
cudaDataType_t data_type_A;
|
||||
cudaDataType_t data_type_B;
|
||||
cudaDataType_t data_type_C;
|
||||
cudaDataType_t compute_type;
|
||||
cublasGemmAlgo_t algo;
|
||||
Status status;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
cublasGemmExDispatcher(
|
||||
library::GemmDescription const &op_desc,
|
||||
library::GemmConfiguration configuration_,
|
||||
library::GemmArguments arguments_,
|
||||
cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT
|
||||
):
|
||||
configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) {
|
||||
|
||||
trans_A = get_cublas_transpose_operation(op_desc.A.layout);
|
||||
trans_B = get_cublas_transpose_operation(op_desc.B.layout);
|
||||
|
||||
bool good = true;
|
||||
good = (good && get_cublas_datatype(data_type_A, op_desc.A.element));
|
||||
good = (good && get_cublas_datatype(data_type_B, op_desc.B.element));
|
||||
good = (good && get_cublas_datatype(data_type_C, op_desc.C.element));
|
||||
|
||||
good = (good && get_cublas_datatype(
|
||||
compute_type,
|
||||
op_desc.tile_description.math_instruction.element_accumulator));
|
||||
|
||||
if (!good) {
|
||||
status = Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes GEMM using these arguments
|
||||
cublasStatus_t operator()(cublasHandle_t handle) {
|
||||
|
||||
return cublasGemmEx(
|
||||
handle,
|
||||
trans_A,
|
||||
trans_B,
|
||||
configuration.problem_size.m(),
|
||||
configuration.problem_size.n(),
|
||||
configuration.problem_size.k(),
|
||||
arguments.alpha,
|
||||
arguments.A,
|
||||
data_type_A,
|
||||
int(configuration.lda),
|
||||
arguments.B,
|
||||
data_type_B,
|
||||
int(configuration.ldb),
|
||||
arguments.beta,
|
||||
arguments.D,
|
||||
data_type_C,
|
||||
int(configuration.ldc),
|
||||
compute_type,
|
||||
algo
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -29,14 +29,9 @@
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
// Profiler includes
|
||||
#include "cutlass_profiler.h"
|
||||
#include "gemm_operation_profiler.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@ -49,7 +44,8 @@ CutlassProfiler::CutlassProfiler(
|
||||
):
|
||||
options_(options) {
|
||||
|
||||
operation_profilers_.emplace_back(new GemmOperationProfiler);
|
||||
operation_profilers_.emplace_back(new GemmOperationProfiler);
|
||||
|
||||
}
|
||||
|
||||
CutlassProfiler::~CutlassProfiler() {
|
||||
@ -112,7 +108,7 @@ void CutlassProfiler::enumerate_() {
|
||||
/// Profiles all operations
|
||||
int CutlassProfiler::profile_() {
|
||||
|
||||
library::Manifest manifest;
|
||||
library::Manifest manifest(library::Provider::kCUTLASS);
|
||||
Status status = manifest.initialize();
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
@ -165,7 +161,8 @@ void CutlassProfiler::print_usage_(std::ostream &out) {
|
||||
}
|
||||
|
||||
out << "\n\nFor details about a particular function, specify the function name with --help.\n\nExample:\n\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --help\n\n";
|
||||
<< " $ cutlass_profiler --operation=Gemm --help\n\n"
|
||||
;
|
||||
}
|
||||
|
||||
/// Prints usage
|
||||
|
||||
@ -27,6 +27,9 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "options.h"
|
||||
#include "operation_profiler.h"
|
||||
|
||||
@ -30,11 +30,11 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; }
|
||||
//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; }
|
||||
//#define report(x) {}
|
||||
|
||||
// Enable/Disble Profiler debug prints
|
||||
#define DEBUG_PROFILER
|
||||
//#define DEBUG_PROFILER
|
||||
|
||||
//RED 31m // profiler prints debug messages in red
|
||||
//YELLOW 33m // ir prints debug messages in yellow
|
||||
@ -43,7 +43,7 @@
|
||||
#define debugprof(...)
|
||||
#else
|
||||
#define debugprof(...) do { \
|
||||
printf("\033[31m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \
|
||||
printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \
|
||||
printf(__VA_ARGS__); \
|
||||
printf("\033[0m\n"); \
|
||||
} while (0)
|
||||
|
||||
@ -34,12 +34,12 @@
|
||||
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
#include "device_allocation.h"
|
||||
|
||||
namespace cutlass {
|
||||
@ -106,6 +106,18 @@ std::vector<int> DeviceAllocation::get_packed_layout(
|
||||
case library::LayoutTypeID::kRowMajorInterleavedK16:
|
||||
stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<16>>(extent);
|
||||
break;
|
||||
case library::LayoutTypeID::kColumnMajorInterleavedK32:
|
||||
stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<32>>(extent);
|
||||
break;
|
||||
case library::LayoutTypeID::kRowMajorInterleavedK32:
|
||||
stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<32>>(extent);
|
||||
break;
|
||||
case library::LayoutTypeID::kColumnMajorInterleavedK64:
|
||||
stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<64>>(extent);
|
||||
break;
|
||||
case library::LayoutTypeID::kRowMajorInterleavedK64:
|
||||
stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<64>>(extent);
|
||||
break;
|
||||
case library::LayoutTypeID::kTensorNCHW:
|
||||
stride = get_packed_layout_stride<cutlass::layout::TensorNCHW>(extent);
|
||||
break;
|
||||
@ -200,6 +212,18 @@ size_t DeviceAllocation::construct_layout(
|
||||
case library::LayoutTypeID::kRowMajorInterleavedK16:
|
||||
return construct_layout_<cutlass::layout::RowMajorInterleaved<16>>(bytes, layout_id, extent, stride);
|
||||
|
||||
case library::LayoutTypeID::kColumnMajorInterleavedK32:
|
||||
return construct_layout_<cutlass::layout::ColumnMajorInterleaved<32>>(bytes, layout_id, extent, stride);
|
||||
|
||||
case library::LayoutTypeID::kRowMajorInterleavedK32:
|
||||
return construct_layout_<cutlass::layout::RowMajorInterleaved<32>>(bytes, layout_id, extent, stride);
|
||||
|
||||
case library::LayoutTypeID::kColumnMajorInterleavedK64:
|
||||
return construct_layout_<cutlass::layout::ColumnMajorInterleaved<64>>(bytes, layout_id, extent, stride);
|
||||
|
||||
case library::LayoutTypeID::kRowMajorInterleavedK64:
|
||||
return construct_layout_<cutlass::layout::RowMajorInterleaved<64>>(bytes, layout_id, extent, stride);
|
||||
|
||||
case library::LayoutTypeID::kTensorNCHW:
|
||||
return construct_layout_<cutlass::layout::TensorNHWC>(bytes, layout_id, extent, stride);
|
||||
|
||||
@ -415,6 +439,14 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kCF64:
|
||||
cutlass::reference::device::BlockFillRandom<complex<double>>(
|
||||
reinterpret_cast<complex<double> *>(pointer_),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kS8:
|
||||
cutlass::reference::device::BlockFillRandom<int8_t>(
|
||||
reinterpret_cast<int8_t *>(pointer_),
|
||||
@ -508,6 +540,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kCF16:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::half_t>>(
|
||||
reinterpret_cast<cutlass::complex<cutlass::half_t> *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kF64:
|
||||
cutlass::reference::host::BlockFillRandom<double>(
|
||||
reinterpret_cast<double *>(host_data.data()),
|
||||
@ -516,6 +556,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kCF64:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::complex<double>>(
|
||||
reinterpret_cast<cutlass::complex<double> *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kS8:
|
||||
cutlass::reference::host::BlockFillRandom<int8_t>(
|
||||
reinterpret_cast<int8_t *>(host_data.data()),
|
||||
@ -607,13 +655,25 @@ bool DeviceAllocation::block_compare_equal(
|
||||
reinterpret_cast<float const *>(ptr_A),
|
||||
reinterpret_cast<float const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
|
||||
case library::NumericTypeID::kCF16:
|
||||
return reference::device::BlockCompareEqual<complex<half_t>>(
|
||||
reinterpret_cast<complex<half_t> const *>(ptr_A),
|
||||
reinterpret_cast<complex<half_t> const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kF64:
|
||||
return reference::device::BlockCompareEqual<double>(
|
||||
reinterpret_cast<double const *>(ptr_A),
|
||||
reinterpret_cast<double const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kCF64:
|
||||
return reference::device::BlockCompareEqual<complex<double>>(
|
||||
reinterpret_cast<complex<double> const *>(ptr_A),
|
||||
reinterpret_cast<complex<double> const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kS8:
|
||||
return reference::device::BlockCompareEqual<int8_t>(
|
||||
reinterpret_cast<int8_t const *>(ptr_A),
|
||||
|
||||
@ -74,16 +74,34 @@ DeviceAllocation *DeviceContext::allocate_tensor(
|
||||
allocate_tensor(name, type, layout_id, extent, stride);
|
||||
|
||||
if (options.initialization.enabled) {
|
||||
Distribution data_distribution = options.initialization.data_distribution;
|
||||
|
||||
if (options.initialization.provider == Provider::kReferenceDevice) {
|
||||
// check if data distribution is allowed to change
|
||||
if(!options.initialization.fix_data_distribution) {
|
||||
// change data distribution based on bit width
|
||||
switch(type) {
|
||||
case library::NumericTypeID::kB1:
|
||||
data_distribution.set_uniform(0, 2, 0);
|
||||
break;
|
||||
case library::NumericTypeID::kS8:
|
||||
data_distribution.set_uniform(-2, 2, 0);
|
||||
break;
|
||||
case library::NumericTypeID::kU8:
|
||||
data_distribution.set_uniform(0, 4, 0);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
if (options.initialization.provider == library::Provider::kReferenceDevice) {
|
||||
allocation->initialize_random_device(
|
||||
options.initialization.seed,
|
||||
options.initialization.data_distribution);
|
||||
data_distribution);
|
||||
}
|
||||
else if (options.initialization.provider == Provider::kReferenceHost) {
|
||||
else if (options.initialization.provider == library::Provider::kReferenceHost) {
|
||||
allocation->initialize_random_host(
|
||||
options.initialization.seed,
|
||||
options.initialization.data_distribution);
|
||||
data_distribution);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -123,53 +123,6 @@ AlgorithmMode from_string<AlgorithmMode>(std::string const &str) {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
static struct {
|
||||
char const *text;
|
||||
char const *pretty;
|
||||
Provider enumerant;
|
||||
}
|
||||
Provider_enumerants[] = {
|
||||
{"cutlass", "CUTLASS", Provider::kCUTLASS},
|
||||
{"host", "reference_host", Provider::kReferenceHost},
|
||||
{"device", "reference_device", Provider::kReferenceDevice},
|
||||
{"cublas", "cuBLAS", Provider::kCUBLAS},
|
||||
};
|
||||
|
||||
/// Converts a Provider enumerant to a string
|
||||
char const *to_string(Provider provider, bool pretty) {
|
||||
|
||||
for (auto const & possible : Provider_enumerants) {
|
||||
if (provider == possible.enumerant) {
|
||||
if (pretty) {
|
||||
return possible.pretty;
|
||||
}
|
||||
else {
|
||||
return possible.text;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pretty ? "Invalid" : "invalid";
|
||||
}
|
||||
|
||||
/// Parses a Provider enumerant from a string
|
||||
template <>
|
||||
Provider from_string<Provider>(std::string const &str) {
|
||||
|
||||
for (auto const & possible : Provider_enumerants) {
|
||||
if ((str.compare(possible.text) == 0) ||
|
||||
(str.compare(possible.pretty) == 0)) {
|
||||
return possible.enumerant;
|
||||
}
|
||||
}
|
||||
|
||||
return Provider::kInvalid;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
static struct {
|
||||
char const *text;
|
||||
char const *pretty;
|
||||
@ -180,6 +133,7 @@ Disposition_enumerants[] = {
|
||||
{"failed", "Failed", Disposition::kFailed},
|
||||
{"not_run", "Not run", Disposition::kNotRun},
|
||||
{"not_verified", "Not verified", Disposition::kNotVerified},
|
||||
{"invalid_problem", "Invalid problem", Disposition::kInvalidProblem},
|
||||
{"not_supported", "Not supported", Disposition::kNotSupported},
|
||||
{"incorrect", "Incorrect", Disposition::kIncorrect}
|
||||
};
|
||||
|
||||
@ -30,7 +30,9 @@
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; }
|
||||
|
||||
@ -79,26 +81,6 @@ AlgorithmMode from_string<AlgorithmMode>(std::string const &str);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Providers
|
||||
enum class Provider {
|
||||
kCUTLASS,
|
||||
kReferenceHost,
|
||||
kReferenceDevice,
|
||||
kCUBLAS,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
using ProviderVector = std::vector<Provider>;
|
||||
|
||||
/// Converts a Provider enumerant to a string
|
||||
char const *to_string(Provider provider, bool pretty = false);
|
||||
|
||||
/// Parses a Provider enumerant from a string
|
||||
template <>
|
||||
Provider from_string<Provider>(std::string const &str);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Outcome of a performance test
|
||||
enum class Disposition {
|
||||
kPassed,
|
||||
@ -106,12 +88,13 @@ enum class Disposition {
|
||||
kNotRun,
|
||||
kIncorrect,
|
||||
kNotVerified,
|
||||
kInvalidProblem,
|
||||
kNotSupported,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
/// Converts a Disposition enumerant to a string
|
||||
char const *to_string(Disposition provider, bool pretty = false);
|
||||
char const *to_string(Disposition disposition, bool pretty = false);
|
||||
|
||||
/// Parses a Disposition enumerant from a string
|
||||
template <>
|
||||
@ -159,6 +142,21 @@ char const *to_string(ArgumentTypeID type, bool pretty = false);
|
||||
template <>
|
||||
ArgumentTypeID from_string<ArgumentTypeID>(std::string const &str);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Profiler typedefs
|
||||
using ProviderVector = std::vector<library::Provider>;
|
||||
using DispositionMap = std::map<library::Provider, Disposition>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Print vector for the report
|
||||
template <typename T>
|
||||
std::ostream& operator<< (std::ostream& out, const std::vector<T>& v) {
|
||||
for(int i = 0; i < v.size(); ++i) {
|
||||
out << to_string(v[i], true) << (i+1 != v.size() ? "," : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace profiler
|
||||
|
||||
@ -140,7 +140,6 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_int(problem_.m, "m", problem_space, problem)) {
|
||||
// default value
|
||||
problem_.m = 1024;
|
||||
@ -201,7 +200,7 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
problem_.lda = DeviceAllocation::get_packed_layout(
|
||||
operation_desc.A.layout, {int(problem_.m), int(problem_.k)}).front();
|
||||
|
||||
@ -240,7 +239,7 @@ void GemmOperationProfiler::initialize_result_(
|
||||
library::GemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space) {
|
||||
|
||||
result.provider = Provider::kCUTLASS;
|
||||
result.provider = library::Provider::kCUTLASS;
|
||||
result.disposition = Disposition::kNotRun;
|
||||
result.status = Status::kSuccess;
|
||||
result.operation_name = operation_desc.name;
|
||||
@ -277,9 +276,17 @@ void GemmOperationProfiler::initialize_result_(
|
||||
int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n * 2;
|
||||
|
||||
result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n);
|
||||
|
||||
result.runtime = 0;
|
||||
|
||||
// complex-valued support
|
||||
switch (operation_desc.tile_description.math_instruction.math_operation) {
|
||||
case library::MathOperationID::kMultiplyAddComplex:
|
||||
result.flops *= 4;
|
||||
break;
|
||||
|
||||
default: break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Initializes workspace
|
||||
@ -290,7 +297,7 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
|
||||
library::GemmDescription const &operation_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
@ -348,7 +355,7 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
//
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
if (options.profiling.provider_enabled(Provider::kCUTLASS)) {
|
||||
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
|
||||
|
||||
if (options.execution_mode != ExecutionMode::kDryRun) {
|
||||
|
||||
@ -368,8 +375,12 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
// If CUTLASS is enabled, generate a result for it
|
||||
//
|
||||
results_.push_back(model_result_);
|
||||
results_.back().provider = Provider::kCUTLASS;
|
||||
results_.back().provider = library::Provider::kCUTLASS;
|
||||
results_.back().op_kind = library::OperationKind::kGemm;
|
||||
results_.back().disposition = Disposition::kNotRun;
|
||||
for(auto &verification_provider : options.verification.providers) {
|
||||
results_.back().verification_map[verification_provider] = Disposition::kNotRun;
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
@ -386,7 +397,7 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
if (!options.profiling.provider_enabled(Provider::kCUTLASS)) {
|
||||
if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -423,198 +434,62 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
return false;
|
||||
}
|
||||
|
||||
// CUTLASS op ran the but not yet verified against any verification provider
|
||||
results_.back().disposition = Disposition::kNotVerified;
|
||||
|
||||
//
|
||||
// Run verification providers
|
||||
//
|
||||
|
||||
if (options.verification.enabled) {
|
||||
|
||||
#if CUTLASS_ENABLE_CUBLAS
|
||||
if (options.verification.provider_enabled(Provider::kCUBLAS)) {
|
||||
if (options.verification.provider_enabled(library::Provider::kCUBLAS)) {
|
||||
|
||||
// Guard against unsupported cases
|
||||
auto const & gemm_desc = static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
if (cublas_satisfies(gemm_desc) != Status::kSuccess) {
|
||||
return true;
|
||||
}
|
||||
if (cublas_satisfies(gemm_desc) == Status::kSuccess) {
|
||||
|
||||
return verify_with_cublas_(
|
||||
options,
|
||||
report,
|
||||
device_context,
|
||||
operation,
|
||||
problem_space,
|
||||
problem);
|
||||
// call cublas verification if supported
|
||||
verify_with_cublas_(
|
||||
options,
|
||||
report,
|
||||
device_context,
|
||||
operation,
|
||||
problem_space,
|
||||
problem);
|
||||
}
|
||||
|
||||
else {
|
||||
// set verification map for cublas to not supported
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported;
|
||||
}
|
||||
}
|
||||
#endif // #if CUTLASS_ENABLE_CUBLAS
|
||||
|
||||
// Update disposition to worst case verification outcome among all
|
||||
// verification providers which are supported
|
||||
bool is_any_verification_run_passed = false;
|
||||
for(auto &m : results_.back().verification_map) {
|
||||
if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) {
|
||||
results_.back().disposition = m.second;
|
||||
return true;
|
||||
}
|
||||
if(!is_any_verification_run_passed && m.second == Disposition::kPassed) {
|
||||
is_any_verification_run_passed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if(is_any_verification_run_passed) {
|
||||
results_.back().disposition = Disposition::kPassed;
|
||||
}
|
||||
}
|
||||
|
||||
// Return true means continue profiling
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUTLASS_ENABLE_CUBLAS
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Selects one or more cuBLAS algorithms.
|
||||
static void select_cublas_algorithms(
|
||||
std::vector<cublasGemmAlgo_t> &algorithms,
|
||||
Options const &options,
|
||||
library::GemmDescription const &op_desc) {
|
||||
|
||||
library::OpcodeClassID const & opcode_class =
|
||||
op_desc.tile_description.math_instruction.opcode_class;
|
||||
|
||||
switch (options.library.algorithm_mode) {
|
||||
case AlgorithmMode::kMatching:
|
||||
{
|
||||
algorithms.push_back(get_cublas_gemm_algo(
|
||||
op_desc.tile_description.threadblock_shape.m(),
|
||||
op_desc.tile_description.threadblock_shape.n(),
|
||||
op_desc.tile_description.threadblock_shape.k(),
|
||||
opcode_class));
|
||||
break;
|
||||
}
|
||||
|
||||
case AlgorithmMode::kBest:
|
||||
{
|
||||
// Choose first enumerated mode. If none are enumerated, choose based on opcode class
|
||||
// and evaluate all of them.
|
||||
|
||||
if (options.library.algorithms.empty()) {
|
||||
// Enumerate all algorithms
|
||||
if (opcode_class == library::OpcodeClassID::kSimt) {
|
||||
|
||||
for (int algo = CUBLAS_GEMM_DEFAULT;
|
||||
algo <= CUBLAS_GEMM_ALGO23;
|
||||
++algo) {
|
||||
|
||||
algorithms.push_back(cublasGemmAlgo_t(algo));
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP;
|
||||
++algo) {
|
||||
|
||||
algorithms.push_back(cublasGemmAlgo_t(algo));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Use the listed algorithms
|
||||
algorithms.reserve(options.library.algorithms.size());
|
||||
|
||||
for (int algo : options.library.algorithms) {
|
||||
algorithms.push_back(reinterpret_cast<cublasGemmAlgo_t const &>(algo));
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case AlgorithmMode::kDefault:
|
||||
{
|
||||
|
||||
// Use the library's default algorithm
|
||||
algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ?
|
||||
CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatcher to cublasGemmEx()
|
||||
struct cublasGemmExDispatcher {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
library::GemmConfiguration configuration;
|
||||
library::GemmArguments arguments;
|
||||
|
||||
cublasOperation_t trans_A;
|
||||
cublasOperation_t trans_B;
|
||||
cudaDataType_t data_type_A;
|
||||
cudaDataType_t data_type_B;
|
||||
cudaDataType_t data_type_C;
|
||||
cudaDataType_t compute_type;
|
||||
cublasGemmAlgo_t algo;
|
||||
Status status;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
cublasGemmExDispatcher(
|
||||
library::GemmDescription const &op_desc,
|
||||
library::GemmConfiguration configuration_,
|
||||
library::GemmArguments arguments_,
|
||||
cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT
|
||||
):
|
||||
configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) {
|
||||
|
||||
trans_A = get_cublas_transpose_operation(op_desc.A.layout);
|
||||
trans_B = get_cublas_transpose_operation(op_desc.B.layout);
|
||||
|
||||
bool good = true;
|
||||
good = (good && get_cublas_datatype(data_type_A, op_desc.A.element));
|
||||
good = (good && get_cublas_datatype(data_type_B, op_desc.B.element));
|
||||
good = (good && get_cublas_datatype(data_type_C, op_desc.C.element));
|
||||
|
||||
good = (good && get_cublas_datatype(
|
||||
compute_type,
|
||||
op_desc.tile_description.math_instruction.element_accumulator));
|
||||
|
||||
if (!good) {
|
||||
status = Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes GEMM using these arguments
|
||||
cublasStatus_t operator()(cublasHandle_t handle) {
|
||||
|
||||
return cublasGemmEx(
|
||||
handle,
|
||||
trans_A,
|
||||
trans_B,
|
||||
configuration.problem_size.m(),
|
||||
configuration.problem_size.n(),
|
||||
configuration.problem_size.k(),
|
||||
arguments.alpha,
|
||||
arguments.A,
|
||||
data_type_A,
|
||||
int(configuration.lda),
|
||||
arguments.B,
|
||||
data_type_B,
|
||||
int(configuration.ldb),
|
||||
arguments.beta,
|
||||
arguments.D,
|
||||
data_type_C,
|
||||
int(configuration.ldc),
|
||||
compute_type,
|
||||
algo
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace detail
|
||||
|
||||
#endif // CUTLASS_ENABLE_CUBLAS
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
@ -632,14 +507,16 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
library::GemmDescription const &gemm_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
//
|
||||
// Construct cuBLAS operators
|
||||
//
|
||||
|
||||
CublasCreate handle;
|
||||
cublasStatus_t status = handle.get_cublas_create_status();
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
||||
results_.back().status = get_cutlass_status(status);
|
||||
results_.back().disposition = Disposition::kFailed;
|
||||
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -682,7 +559,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
);
|
||||
|
||||
if (gemm_op.status != Status::kSuccess) {
|
||||
results_.back().disposition = Disposition::kNotVerified;
|
||||
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -692,8 +570,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
|
||||
// Handle errors
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
results_.back().status = get_cutlass_status(status);
|
||||
results_.back().disposition = Disposition::kNotVerified;
|
||||
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -701,7 +579,7 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
// Verify results
|
||||
//
|
||||
|
||||
results_.back().disposition = compare_tensors(
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors(
|
||||
options,
|
||||
*gemm_workspace_.Computed,
|
||||
*gemm_workspace_.Reference
|
||||
@ -709,19 +587,18 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
|
||||
// Save workspace if incorrect
|
||||
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
|
||||
results_.back().disposition == Disposition::kIncorrect) {
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) {
|
||||
|
||||
save_workspace(
|
||||
device_context,
|
||||
options,
|
||||
gemm_desc,
|
||||
Provider::kCUTLASS,
|
||||
Provider::kCUBLAS);
|
||||
library::Provider::kCUTLASS,
|
||||
library::Provider::kCUBLAS);
|
||||
}
|
||||
}
|
||||
catch (...) {
|
||||
results_.back().disposition = Disposition::kFailed;
|
||||
results_.back().status = Status::kErrorNotSupported;
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed;
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -741,7 +618,7 @@ bool GemmOperationProfiler::profile(
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
if (options.profiling.provider_enabled(Provider::kCUTLASS)) {
|
||||
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
|
||||
|
||||
// Initialize structure containing GEMM arguments
|
||||
gemm_workspace_.arguments.A = gemm_workspace_.A->data();
|
||||
|
||||
@ -35,6 +35,7 @@
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
// Profiler includes
|
||||
|
||||
@ -225,7 +225,7 @@ int OperationProfiler::profile_all(
|
||||
ProblemSpace problem_space(arguments_, options.cmdline);
|
||||
|
||||
// 1. Construct performance report
|
||||
PerformanceReport report(options, problem_space.argument_names());
|
||||
PerformanceReport report(options, problem_space.argument_names(), kind_);
|
||||
|
||||
// 2. For each problem in problem space
|
||||
ProblemSpace::Iterator problem_it = problem_space.begin();
|
||||
@ -269,7 +269,7 @@ int OperationProfiler::profile_all(
|
||||
if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// A. Initialize configuration
|
||||
Status status = this->initialize_configuration(
|
||||
options,
|
||||
@ -278,7 +278,7 @@ int OperationProfiler::profile_all(
|
||||
operation,
|
||||
problem_space,
|
||||
problem);
|
||||
|
||||
|
||||
if (status == Status::kErrorInternal) {
|
||||
// Stop profiling if there was an internal error
|
||||
return false;
|
||||
@ -341,7 +341,7 @@ int OperationProfiler::profile_all(
|
||||
device_context,
|
||||
options,
|
||||
operation->description(),
|
||||
Provider::kCUTLASS);
|
||||
library::Provider::kCUTLASS);
|
||||
}
|
||||
|
||||
//
|
||||
@ -434,8 +434,8 @@ void OperationProfiler::save_workspace(
|
||||
DeviceContext &device_context,
|
||||
Options const &options,
|
||||
library::OperationDescription const &desc,
|
||||
Provider provider,
|
||||
Provider verification_provider) {
|
||||
library::Provider provider,
|
||||
library::Provider verification_provider) {
|
||||
|
||||
for (auto const & named_allocation : device_context) {
|
||||
|
||||
@ -443,10 +443,10 @@ void OperationProfiler::save_workspace(
|
||||
|
||||
std::stringstream filename;
|
||||
|
||||
filename << desc.name << "_" << to_string(provider) << "_";
|
||||
filename << desc.name << "_" << library::to_string(provider) << "_";
|
||||
|
||||
if (verification_provider != Provider::kInvalid) {
|
||||
filename << "verified_by_" << to_string(verification_provider) << "_";
|
||||
if (verification_provider != library::Provider::kInvalid) {
|
||||
filename << "verified_by_" << library::to_string(verification_provider) << "_";
|
||||
}
|
||||
|
||||
filename << named_allocation.first + ".mat";
|
||||
@ -454,6 +454,7 @@ void OperationProfiler::save_workspace(
|
||||
std::ofstream out(filename.str());
|
||||
|
||||
allocation->write_tensor_csv(out);
|
||||
out << "\n";
|
||||
|
||||
if (options.report.verbose) {
|
||||
std::cout << "wrote '" << filename.str() << "'" << std::endl;
|
||||
|
||||
@ -35,6 +35,7 @@
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
// Profiler includes
|
||||
@ -43,6 +44,7 @@
|
||||
#include "performance_result.h"
|
||||
#include "performance_report.h"
|
||||
#include "problem_space.h"
|
||||
#include "debug.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -192,8 +194,8 @@ public:
|
||||
DeviceContext &device_context,
|
||||
Options const &options,
|
||||
library::OperationDescription const &desc,
|
||||
Provider provider,
|
||||
Provider verification_provider = Provider::kInvalid);
|
||||
library::Provider provider,
|
||||
library::Provider verification_provider = library::Provider::kInvalid);
|
||||
|
||||
protected:
|
||||
|
||||
|
||||
@ -31,6 +31,8 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/version.h"
|
||||
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
#include "options.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -161,24 +163,30 @@ Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) {
|
||||
if (cmdline.check_cmd_line_flag("initialization-provider")) {
|
||||
std::string str;
|
||||
cmdline.get_cmd_line_argument("initialization-provider", str);
|
||||
provider = from_string<Provider>(str);
|
||||
if (provider == Provider::kInvalid) {
|
||||
provider = library::from_string<library::Provider>(str);
|
||||
if (provider == library::Provider::kInvalid) {
|
||||
enabled = false;
|
||||
}
|
||||
else if (provider != Provider::kReferenceHost && provider != Provider::kReferenceDevice) {
|
||||
else if (provider != library::Provider::kReferenceHost && provider != library::Provider::kReferenceDevice) {
|
||||
throw std::runtime_error("Unsupported intialization provider specified.");
|
||||
}
|
||||
}
|
||||
else {
|
||||
provider = Provider::kReferenceDevice;
|
||||
provider = library::Provider::kReferenceDevice;
|
||||
}
|
||||
|
||||
cmdline.get_cmd_line_argument("seed", seed, 2019);
|
||||
|
||||
if (cmdline.check_cmd_line_flag("dist")) {
|
||||
// user has set the data distribution (fix data distribution once set)
|
||||
fix_data_distribution = true;
|
||||
// set user provided data distribution
|
||||
get_distribution(cmdline, "dist", data_distribution);
|
||||
}
|
||||
else {
|
||||
// profiler choosen data distribution (allowed to change based on numeric types)
|
||||
fix_data_distribution = false;
|
||||
// set uniform data distribution with range [-4, 4]
|
||||
data_distribution.set_uniform(-4, 4, 0);
|
||||
}
|
||||
|
||||
@ -372,12 +380,12 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) {
|
||||
providers.clear();
|
||||
|
||||
for (auto const &token : tokens) {
|
||||
providers.push_back(from_string<Provider>(token));
|
||||
providers.push_back(library::from_string<library::Provider>(token));
|
||||
}
|
||||
}
|
||||
else {
|
||||
providers.push_back(Provider::kCUTLASS);
|
||||
providers.push_back(Provider::kCUBLAS);
|
||||
providers.push_back(library::Provider::kCUTLASS);
|
||||
providers.push_back(library::Provider::kCUBLAS);
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,18 +420,18 @@ void Options::Profiling::print_options(std::ostream &out, int indent) const {
|
||||
|
||||
int j = 0;
|
||||
for (auto const & provider : providers) {
|
||||
out << (j++ ? ", " : "") << to_string(provider);
|
||||
out << (j++ ? ", " : "") << library::to_string(provider);
|
||||
}
|
||||
out << "]\n";
|
||||
}
|
||||
|
||||
/// Returns true if a provider is enabled
|
||||
bool Options::Profiling::provider_enabled(Provider provider) const {
|
||||
bool Options::Profiling::provider_enabled(library::Provider provider) const {
|
||||
return std::find(providers.begin(), providers.end(), provider) != providers.end();
|
||||
}
|
||||
|
||||
/// Returns the index of a provider if its enabled
|
||||
size_t Options::Profiling::index(Provider provider) const {
|
||||
size_t Options::Profiling::index(library::Provider provider) const {
|
||||
size_t idx = 0;
|
||||
for (auto const & x : providers) {
|
||||
if (x == provider) {
|
||||
@ -461,14 +469,14 @@ Options::Verification::Verification(cutlass::CommandLine const &cmdline) {
|
||||
providers.clear();
|
||||
|
||||
for (auto const &token : tokens) {
|
||||
Provider provider = from_string<Provider>(token);
|
||||
if (provider != Provider::kInvalid) {
|
||||
library::Provider provider = library::from_string<library::Provider>(token);
|
||||
if (provider != library::Provider::kInvalid) {
|
||||
providers.push_back(provider);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
providers.push_back(Provider::kCUBLAS);
|
||||
providers.push_back(library::Provider::kCUBLAS);
|
||||
}
|
||||
}
|
||||
|
||||
@ -504,18 +512,18 @@ void Options::Verification::print_options(std::ostream &out, int indent) const {
|
||||
|
||||
int j = 0;
|
||||
for (auto const & provider : providers) {
|
||||
out << (j++ ? ", " : "") << to_string(provider);
|
||||
out << (j++ ? ", " : "") << library::to_string(provider);
|
||||
}
|
||||
out << "]\n";
|
||||
}
|
||||
|
||||
/// Returns true if a provider is enabled
|
||||
bool Options::Verification::provider_enabled(Provider provider) const {
|
||||
bool Options::Verification::provider_enabled(library::Provider provider) const {
|
||||
return std::find(providers.begin(), providers.end(), provider) != providers.end();
|
||||
}
|
||||
|
||||
/// Returns the index of a provider if its enabled
|
||||
size_t Options::Verification::index(Provider provider) const {
|
||||
size_t Options::Verification::index(library::Provider provider) const {
|
||||
size_t idx = 0;
|
||||
for (auto const & x : providers) {
|
||||
if (x == provider) {
|
||||
@ -658,7 +666,7 @@ Options::Options(cutlass::CommandLine const &cmdline):
|
||||
|
||||
// Prevent launches on the device for anything other than CUTLASS operation
|
||||
if (execution_mode == ExecutionMode::kTrace) {
|
||||
initialization.provider = Provider::kReferenceHost;
|
||||
initialization.provider = library::Provider::kReferenceHost;
|
||||
verification.enabled = false;
|
||||
profiling.enabled = false;
|
||||
}
|
||||
|
||||
@ -105,11 +105,15 @@ public:
|
||||
/// allocating tensors.
|
||||
bool enabled;
|
||||
|
||||
/// If true, data distribution is set by the user and is not allowed to change
|
||||
/// If false, data distribution is allowed to change based on element_type (library::NumericTypeID)
|
||||
bool fix_data_distribution;
|
||||
|
||||
/// Data distribution for input tensors
|
||||
Distribution data_distribution;
|
||||
|
||||
/// Source of random tensor elements
|
||||
Provider provider;
|
||||
library::Provider provider;
|
||||
|
||||
/// Random number generator seed.
|
||||
int seed;
|
||||
@ -162,10 +166,10 @@ public:
|
||||
void print_options(std::ostream &out, int indent = 0) const;
|
||||
|
||||
/// Returns true if a provider is enabled
|
||||
bool provider_enabled(Provider provider) const;
|
||||
bool provider_enabled(library::Provider provider) const;
|
||||
|
||||
/// Returns the index of a provider if its enabled
|
||||
size_t index(Provider provider) const;
|
||||
size_t index(library::Provider provider) const;
|
||||
};
|
||||
|
||||
/// Options related to profiling
|
||||
@ -196,10 +200,10 @@ public:
|
||||
void print_options(std::ostream &out, int indent = 0) const;
|
||||
|
||||
/// Returns true if a provider is enabled
|
||||
bool provider_enabled(Provider provider) const;
|
||||
bool provider_enabled(library::Provider provider) const;
|
||||
|
||||
/// Returns the index of a provider if its enabled
|
||||
size_t index(Provider provider) const;
|
||||
size_t index(library::Provider provider) const;
|
||||
};
|
||||
|
||||
/// Options related to reporting
|
||||
|
||||
@ -29,9 +29,15 @@
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
#include "performance_report.h"
|
||||
|
||||
#include "debug.h"
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
@ -57,12 +63,17 @@ namespace profiler {
|
||||
|
||||
PerformanceReport::PerformanceReport(
|
||||
Options const &options,
|
||||
std::vector<std::string> const &argument_names
|
||||
std::vector<std::string> const &argument_names,
|
||||
library::OperationKind const &op_kind
|
||||
):
|
||||
options_(options), argument_names_(argument_names), problem_index_(0), good_(true) {
|
||||
options_(options), argument_names_(argument_names), problem_index_(0), good_(true), op_kind_(op_kind) {
|
||||
|
||||
std::string file_name = options_.report.output_path.substr(0, options_.report.output_path.rfind("."));
|
||||
std::string file_extension = options_.report.output_path.substr(options_.report.output_path.rfind(".") + 1);
|
||||
op_file_name_ = file_name + "." + to_string(op_kind_) + "." + file_extension;
|
||||
|
||||
//
|
||||
// Open output file
|
||||
// Open output file for operation of PerformanceReport::op_kind
|
||||
//
|
||||
if (!options_.report.output_path.empty()) {
|
||||
|
||||
@ -70,17 +81,17 @@ PerformanceReport::PerformanceReport(
|
||||
|
||||
if (options_.report.append) {
|
||||
|
||||
std::ifstream test_output_file(options_.report.output_path.c_str());
|
||||
std::ifstream test_output_file(op_file_name_);
|
||||
|
||||
if (test_output_file.is_open()) {
|
||||
print_header = false;
|
||||
test_output_file.close();
|
||||
}
|
||||
|
||||
output_file_.open(options_.report.output_path.c_str(), std::ios::app);
|
||||
output_file_.open(op_file_name_, std::ios::app);
|
||||
}
|
||||
else {
|
||||
output_file_.open(options_.report.output_path.c_str());
|
||||
output_file_.open(op_file_name_);
|
||||
}
|
||||
|
||||
if (!output_file_.good()) {
|
||||
@ -148,7 +159,7 @@ void PerformanceReport::close() {
|
||||
}
|
||||
}
|
||||
else if (output_file_.is_open() && options_.report.verbose) {
|
||||
std::cout << "\n\nWrote results to '" << options_.report.output_path << "'" << std::endl;
|
||||
std::cout << "\n\nWrote results to '" << op_file_name_ << "'" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -184,19 +195,30 @@ std::ostream & PerformanceReport::print_result_pretty_(
|
||||
|
||||
out
|
||||
<< "\n"
|
||||
<< " Provider: " << SHELL_COLOR_BRIGHT() << to_string(result.provider, true) << SHELL_COLOR_END() << "\n"
|
||||
<< " Operation: " << result.operation_name << "\n\n"
|
||||
<< " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n"
|
||||
<< " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n";
|
||||
<< " Provider: " << SHELL_COLOR_BRIGHT() << library::to_string(result.provider, true) << SHELL_COLOR_END() << "\n"
|
||||
<< " Operation: " << result.operation_name << "\n\n"
|
||||
<< " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n"
|
||||
<< " Verification: " << SHELL_COLOR_BRIGHT() << (options_.verification.enabled ? "ON":"OFF") << SHELL_COLOR_END() << "\n"
|
||||
<< " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n\n";
|
||||
|
||||
// Display individual verification results for each verification-provider
|
||||
if (options_.verification.enabled) {
|
||||
|
||||
static int const indent_spaces = 22;
|
||||
|
||||
for(auto & m : result.verification_map) {
|
||||
out << std::right << std::setw(indent_spaces) << library::to_string(m.first, true) << ": " << to_string(m.second, true) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
<< "\n Arguments: ";
|
||||
<< "\n Arguments: ";
|
||||
|
||||
int column_idx = 0;
|
||||
for (auto const &arg : result.arguments) {
|
||||
if (!arg.second.empty()) {
|
||||
out << " --" << arg.first << "=" << arg.second;
|
||||
column_idx += 4 + arg.first.size() + arg.second.size();
|
||||
column_idx += int(4 + arg.first.size() + arg.second.size());
|
||||
if (column_idx > 90) {
|
||||
out << " \\\n ";
|
||||
column_idx = 0;
|
||||
@ -206,15 +228,15 @@ std::ostream & PerformanceReport::print_result_pretty_(
|
||||
out << "\n\n";
|
||||
|
||||
out
|
||||
<< " Bytes: " << result.bytes << " bytes\n"
|
||||
<< " FLOPs: " << result.flops << " flops\n\n";
|
||||
<< " Bytes: " << result.bytes << " bytes\n"
|
||||
<< " FLOPs: " << result.flops << " flops\n\n";
|
||||
|
||||
if (result.good()) {
|
||||
|
||||
out
|
||||
<< " Runtime: " << result.runtime << " ms\n"
|
||||
<< " Memory: " << result.gbytes_per_sec() << " GiB/s\n"
|
||||
<< "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n";
|
||||
<< " Runtime: " << result.runtime << " ms\n"
|
||||
<< " Memory: " << result.gbytes_per_sec() << " GiB/s\n"
|
||||
<< "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n";
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -31,10 +31,14 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
// CUTLASS Profiler includes
|
||||
#include "options.h"
|
||||
#include "enumerated_types.h"
|
||||
#include "performance_result.h"
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
@ -46,6 +50,12 @@ private:
|
||||
/// Reference to options
|
||||
Options const &options_;
|
||||
|
||||
/// Operation kind
|
||||
library::OperationKind op_kind_;
|
||||
|
||||
/// Operation file name containing performance report of op_kind
|
||||
std::string op_file_name_;
|
||||
|
||||
/// Output file containing results
|
||||
std::ofstream output_file_;
|
||||
|
||||
@ -63,7 +73,7 @@ private:
|
||||
|
||||
public:
|
||||
|
||||
PerformanceReport(Options const &options, std::vector<std::string> const &argument_names);
|
||||
PerformanceReport(Options const &options, std::vector<std::string> const &argument_names, library::OperationKind const &op_kind);
|
||||
|
||||
bool good() const { return good_; }
|
||||
|
||||
|
||||
@ -32,8 +32,12 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
// CUTLASS Profiler includes
|
||||
#include "enumerated_types.h"
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
@ -45,15 +49,22 @@ struct PerformanceResult {
|
||||
/// Index of problem
|
||||
size_t problem_index;
|
||||
|
||||
/// Provider
|
||||
Provider provider;
|
||||
/// library::Provider
|
||||
library::Provider provider;
|
||||
|
||||
/// Outcome of test
|
||||
Disposition disposition;
|
||||
/// Operation kind
|
||||
library::OperationKind op_kind;
|
||||
|
||||
/// CUTLASS status result from kernels
|
||||
/// CUTLASS status result from kernels (success or failure)
|
||||
// Status does information on verification
|
||||
Status status;
|
||||
|
||||
/// Outcome of verification (worst case verification result)
|
||||
Disposition disposition;
|
||||
|
||||
/// Outcome of verification (all verification results)
|
||||
DispositionMap verification_map;
|
||||
|
||||
/// Operation object
|
||||
std::string operation_name;
|
||||
|
||||
@ -76,7 +87,8 @@ struct PerformanceResult {
|
||||
/// Ctor
|
||||
PerformanceResult():
|
||||
problem_index(0),
|
||||
provider(Provider::kInvalid),
|
||||
op_kind(library::OperationKind::kInvalid),
|
||||
provider(library::Provider::kInvalid),
|
||||
disposition(Disposition::kNotRun),
|
||||
status(Status::kInvalid),
|
||||
bytes(0),
|
||||
|
||||
@ -27,10 +27,11 @@
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/library/util.h"
|
||||
|
||||
#include "problem_space.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -849,17 +850,16 @@ bool arg_as_OpcodeClassID(
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null.
|
||||
bool arg_as_scalar(
|
||||
std::vector<uint8_t> &bytes,
|
||||
library::NumericTypeID numeric_type,
|
||||
KernelArgument::Value const *value_ptr) {
|
||||
|
||||
|
||||
if (value_ptr->not_null) {
|
||||
if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) {
|
||||
int64_t int_value = static_cast<IntegerArgument::IntegerValue const *>(value_ptr)->value;
|
||||
|
||||
|
||||
// TODO - convert int64_t => destination type
|
||||
}
|
||||
else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) {
|
||||
|
||||
@ -31,6 +31,12 @@ target_include_directories(
|
||||
$<BUILD_INTERFACE:${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}>
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
cutlass_tools_util_includes
|
||||
INTERFACE
|
||||
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:cublas>
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/
|
||||
@ -40,3 +46,4 @@ install(
|
||||
TARGETS cutlass_tools_util_includes
|
||||
EXPORT NvidiaCutlass
|
||||
)
|
||||
|
||||
|
||||
@ -119,6 +119,16 @@ struct CommandLine {
|
||||
val = !(value == "0" || value == "false");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains the value specified for a given commandline parameter --<flag>=<value>
|
||||
*/
|
||||
template <typename value_t>
|
||||
void get_cmd_line_argument(const char* arg_name,
|
||||
value_t& val) const {
|
||||
|
||||
get_cmd_line_argument(arg_name, val, val);
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains the value specified for a given commandline parameter --<flag>=<value>
|
||||
@ -126,7 +136,7 @@ struct CommandLine {
|
||||
template <typename value_t>
|
||||
void get_cmd_line_argument(const char* arg_name,
|
||||
value_t& val,
|
||||
value_t const& _default = value_t()) const {
|
||||
value_t const& _default) const {
|
||||
using namespace std;
|
||||
|
||||
val = _default;
|
||||
|
||||
@ -40,10 +40,14 @@ namespace device_memory {
|
||||
/// Allocate a buffer of \p count elements of type \p T on the current CUDA device
|
||||
template <typename T>
|
||||
T* allocate(size_t count = 1) {
|
||||
|
||||
T* ptr = 0;
|
||||
size_t bytes = sizeof(T) * count;
|
||||
size_t bytes = 0;
|
||||
|
||||
bytes = count * sizeof(T);
|
||||
|
||||
cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes);
|
||||
|
||||
if (cuda_error != cudaSuccess) {
|
||||
throw cuda_exception("Failed to allocate memory", cuda_error);
|
||||
}
|
||||
@ -111,13 +115,16 @@ void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) {
|
||||
copy_to_device(device_begin, &*begin, elements);
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* "Smart" device memory allocation
|
||||
******************************************************************************/
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device_memory
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Device allocation abstraction that tracks size and capacity
|
||||
template <typename T>
|
||||
struct allocation {
|
||||
class DeviceAllocation {
|
||||
public:
|
||||
|
||||
/// Delete functor for CUDA device memory
|
||||
struct deleter {
|
||||
void operator()(T* ptr) {
|
||||
@ -130,6 +137,7 @@ struct allocation {
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
@ -140,23 +148,55 @@ struct allocation {
|
||||
/// Smart pointer
|
||||
platform::unique_ptr<T, deleter> smart_ptr;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Static methods
|
||||
//
|
||||
|
||||
/// Static member to compute the number of bytes needed for a given number of elements
|
||||
static size_t bytes(size_t elements) {
|
||||
if (sizeof_bits<T>::value < 8) {
|
||||
size_t const kElementsPerByte = 8 / sizeof_bits<T>::value;
|
||||
return elements / kElementsPerByte;
|
||||
}
|
||||
else {
|
||||
size_t const kBytesPerElement = sizeof_bits<T>::value / 8;
|
||||
return elements * kBytesPerElement;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor: allocates no memory
|
||||
allocation() : capacity(0) {}
|
||||
DeviceAllocation() : capacity(0) {}
|
||||
|
||||
/// Constructor: allocates \p capacity elements on the current CUDA device
|
||||
allocation(size_t _capacity) : smart_ptr(allocate<T>(_capacity)), capacity(_capacity) {}
|
||||
DeviceAllocation(size_t _capacity) :
|
||||
smart_ptr(device_memory::allocate<T>(_capacity)), capacity(_capacity) {}
|
||||
|
||||
/// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation
|
||||
DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {}
|
||||
|
||||
/// Copy constructor
|
||||
allocation(allocation const &p): smart_ptr(allocate<T>(p.capacity)), capacity(p.capacity) {
|
||||
copy_device_to_device(smart_ptr.get(), p.get(), capacity);
|
||||
DeviceAllocation(DeviceAllocation const &p):
|
||||
smart_ptr(device_memory::allocate<T>(p.capacity)), capacity(p.capacity) {
|
||||
|
||||
device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity);
|
||||
}
|
||||
|
||||
/// Move constructor
|
||||
DeviceAllocation(DeviceAllocation &&p): capacity(0) {
|
||||
std::swap(smart_ptr, p.smart_ptr);
|
||||
std::swap(capacity, p.capacity);
|
||||
}
|
||||
|
||||
/// Destructor
|
||||
~allocation() { reset(); }
|
||||
~DeviceAllocation() { reset(); }
|
||||
|
||||
/// Returns a pointer to the managed object
|
||||
T* get() const { return smart_ptr.get(); }
|
||||
@ -173,12 +213,41 @@ struct allocation {
|
||||
smart_ptr.reset();
|
||||
}
|
||||
|
||||
/// Deletes managed object, if owned, and allocates a new object
|
||||
void reset(size_t _capacity) {
|
||||
reset(device_memory::allocate<T>(_capacity), _capacity);
|
||||
}
|
||||
|
||||
/// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity
|
||||
void reset(T* _ptr, size_t _capacity) {
|
||||
smart_ptr.reset(_ptr);
|
||||
capacity = _capacity;
|
||||
}
|
||||
|
||||
/// Allocates a new buffer and copies the old buffer into it. The old buffer is then released.
|
||||
void reallocate(size_t new_capacity) {
|
||||
|
||||
platform::unique_ptr<T, deleter> new_allocation(device_memory::allocate<T>(new_capacity));
|
||||
|
||||
device_memory::copy_device_to_device(
|
||||
new_allocation.get(),
|
||||
smart_ptr.get(),
|
||||
std::min(new_capacity, capacity));
|
||||
|
||||
std::swap(smart_ptr, new_allocation);
|
||||
std::swap(new_capacity, capacity);
|
||||
}
|
||||
|
||||
/// Returns the number of elements
|
||||
size_t size() const {
|
||||
return capacity;
|
||||
}
|
||||
|
||||
/// Returns the number of bytes needed to store the allocation
|
||||
size_t bytes() const {
|
||||
return bytes(capacity);
|
||||
}
|
||||
|
||||
/// Returns a pointer to the object owned by *this
|
||||
T* operator->() const { return smart_ptr.get(); }
|
||||
|
||||
@ -189,15 +258,69 @@ struct allocation {
|
||||
const deleter& get_deleter() const { return smart_ptr.get_deleter(); }
|
||||
|
||||
/// Copies a device-side memory allocation
|
||||
allocation & operator=(allocation const &p) {
|
||||
DeviceAllocation & operator=(DeviceAllocation const &p) {
|
||||
if (capacity != p.capacity) {
|
||||
smart_ptr.reset(allocate<T>(p.capacity));
|
||||
smart_ptr.reset(device_memory::allocate<T>(p.capacity));
|
||||
capacity = p.capacity;
|
||||
}
|
||||
copy_device_to_device(smart_ptr.get(), p.get(), capacity);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Move assignment
|
||||
DeviceAllocation & operator=(DeviceAllocation && p) {
|
||||
std::swap(smart_ptr, p.smart_ptr);
|
||||
std::swap(capacity, p.capacity);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Copies the entire allocation from another location in device memory.
|
||||
void copy_from_device(T const *ptr) const {
|
||||
copy_from_device(ptr, capacity);
|
||||
}
|
||||
|
||||
/// Copies a given number of elements from device memory
|
||||
void copy_from_device(T const *ptr, size_t elements) const {
|
||||
device_memory::copy_device_to_device(get(), ptr, elements);
|
||||
}
|
||||
|
||||
void copy_to_device(T *ptr) const {
|
||||
copy_to_device(ptr, capacity);
|
||||
}
|
||||
|
||||
void copy_to_device(T *ptr, size_t elements) const {
|
||||
device_memory::copy_device_to_device(ptr, get(), elements);
|
||||
}
|
||||
|
||||
void copy_from_host(T const *ptr) const {
|
||||
copy_from_host(ptr, capacity);
|
||||
}
|
||||
|
||||
void copy_from_host(T const *ptr, size_t elements) const {
|
||||
device_memory::copy_to_device(get(), ptr, elements);
|
||||
}
|
||||
|
||||
void copy_to_host(T *ptr) const {
|
||||
copy_to_host(ptr, capacity);
|
||||
}
|
||||
|
||||
void copy_to_host(T *ptr, size_t elements) const {
|
||||
device_memory::copy_to_host(ptr, get(), elements);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace device_memory {
|
||||
|
||||
/// Device allocation abstraction that tracks size and capacity
|
||||
template <typename T>
|
||||
using allocation = cutlass::DeviceAllocation<T>;
|
||||
|
||||
} // namespace device_memory
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -99,7 +99,7 @@ public:
|
||||
using ConstReference = typename ConstTensorRef::Reference;
|
||||
|
||||
/// Used to handle packing of subbyte elements
|
||||
static int const kElementsPerStoredItem = (sizeof_bits<Element>::value < 8 ? sizeof(Element) * 8 / sizeof_bits<Element>::value : 1);
|
||||
static int const kElementsPerStoredItem = (sizeof_bits<Element>::value < 8 ? (8 / sizeof_bits<Element>::value) : 1);
|
||||
|
||||
private:
|
||||
|
||||
@ -232,7 +232,7 @@ public:
|
||||
|
||||
/// Returns the logical capacity based on extent and layout. May differ from size().
|
||||
LongIndex capacity() const {
|
||||
return layout_.capacity(extent_) * kElementsPerStoredItem;
|
||||
return layout_.capacity(extent_);
|
||||
}
|
||||
|
||||
/// Gets pointer to host data
|
||||
|
||||
423
tools/util/include/cutlass/util/host_tensor_planar_complex.h
Normal file
423
tools/util/include/cutlass/util/host_tensor_planar_complex.h
Normal file
@ -0,0 +1,423 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/*! \file
|
||||
\brief HostTensor contributes management for both host and device memory.
|
||||
|
||||
HostTensor allocates host and device memory upon construction. Basic element-wise operations on
|
||||
host memory synchronize device memory automatically. Explicit copy operations provide abstractions
|
||||
for CUDA memcpy operations.
|
||||
|
||||
Call {host, device}_{data, ref, view}() for accessing host or device memory.
|
||||
|
||||
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
|
||||
#include "cutlass/tensor_ref_planar_complex.h"
|
||||
#include "cutlass/tensor_view_planar_complex.h"
|
||||
|
||||
#include "device_memory.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Host tensor
|
||||
template <
|
||||
/// Data type of element stored within tensor (concept: NumericType)
|
||||
typename Element_,
|
||||
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
||||
typename Layout_
|
||||
>
|
||||
class HostTensorPlanarComplex {
|
||||
public:
|
||||
|
||||
/// Data type of individual access
|
||||
using Element = Element_;
|
||||
|
||||
/// Mapping function from logical coordinate to linear memory
|
||||
using Layout = Layout_;
|
||||
|
||||
/// Logical rank of tensor index space
|
||||
static int const kRank = Layout::kRank;
|
||||
|
||||
/// Index type
|
||||
using Index = typename Layout::Index;
|
||||
|
||||
/// Long index used for pointer offsets
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
/// Layout's stride vector
|
||||
using Stride = typename Layout::Stride;
|
||||
|
||||
/// Tensor reference to device memory
|
||||
using TensorRef = TensorRefPlanarComplex<Element, Layout>;
|
||||
|
||||
/// Tensor reference to constant device memory
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
|
||||
/// Tensor reference to device memory
|
||||
using TensorView = TensorViewPlanarComplex<Element, Layout>;
|
||||
|
||||
/// Tensor reference to constant device memory
|
||||
using ConstTensorView = typename TensorView::ConstTensorView;
|
||||
|
||||
/// Reference to element in tensor
|
||||
using Reference = typename TensorRef::Reference;
|
||||
|
||||
/// Constant reference to element in tensor
|
||||
using ConstReference = typename ConstTensorRef::Reference;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Extent of tensor in logical dimensions
|
||||
TensorCoord extent_;
|
||||
|
||||
/// Layout object
|
||||
Layout layout_;
|
||||
|
||||
/// Host-side memory allocation
|
||||
std::vector<Element> host_;
|
||||
|
||||
/// Device-side memory
|
||||
device_memory::allocation<Element> device_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Device and Host Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
HostTensorPlanarComplex() {}
|
||||
|
||||
/// Constructs a tensor given an extent. Assumes a packed layout
|
||||
HostTensorPlanarComplex(
|
||||
TensorCoord const &extent,
|
||||
bool device_backed = true
|
||||
) {
|
||||
|
||||
this->reset(extent, Layout::packed(extent), device_backed);
|
||||
}
|
||||
|
||||
/// Constructs a tensor given an extent and layout
|
||||
HostTensorPlanarComplex(
|
||||
TensorCoord const &extent,
|
||||
Layout const &layout,
|
||||
bool device_backed = true
|
||||
) {
|
||||
|
||||
this->reset(extent, layout, device_backed);
|
||||
}
|
||||
|
||||
~HostTensorPlanarComplex() { }
|
||||
|
||||
/// Clears the HostTensor allocation to size/capacity = 0
|
||||
void reset() {
|
||||
extent_ = TensorCoord();
|
||||
layout_ = Layout::packed(extent_);
|
||||
|
||||
host_.clear();
|
||||
device_.reset();
|
||||
}
|
||||
|
||||
/// Resizes internal memory allocations without affecting layout or extent
|
||||
void reserve(
|
||||
size_t count, ///< size of tensor in elements
|
||||
bool device_backed_ = true) { ///< if true, device memory is also allocated
|
||||
|
||||
device_.reset();
|
||||
host_.clear();
|
||||
|
||||
host_.resize(count * 2);
|
||||
|
||||
// Allocate memory
|
||||
Element* device_memory = nullptr;
|
||||
if (device_backed_) {
|
||||
device_memory = device_memory::allocate<Element>(count * 2);
|
||||
}
|
||||
device_.reset(device_memory, device_backed_ ? count * 2 : 0);
|
||||
}
|
||||
|
||||
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
||||
/// extent and layout.
|
||||
void reset(
|
||||
TensorCoord const &extent, ///< extent of logical tensor
|
||||
Layout const &layout, ///< layout object of tensor
|
||||
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
||||
|
||||
extent_ = extent;
|
||||
layout_ = layout;
|
||||
|
||||
reserve(size_t(layout_.capacity(extent_)), device_backed_);
|
||||
}
|
||||
|
||||
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
||||
/// extent and layout. Assumes a packed tensor configuration.
|
||||
void reset(
|
||||
TensorCoord const &extent, ///< extent of logical tensor
|
||||
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
||||
|
||||
reset(extent, Layout::packed(extent), device_backed_);
|
||||
}
|
||||
|
||||
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
||||
/// To force allocation, call reset().
|
||||
void resize(
|
||||
TensorCoord const &extent, ///< extent of logical tensor
|
||||
Layout const &layout, ///< layout object of tensor
|
||||
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
||||
|
||||
extent_ = extent;
|
||||
layout_ = layout;
|
||||
|
||||
LongIndex new_size = size_t(layout_.capacity(extent_));
|
||||
|
||||
if (static_cast<decltype(host_.size())>(new_size * 2) > host_.size()) {
|
||||
reserve(new_size);
|
||||
}
|
||||
}
|
||||
|
||||
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
||||
/// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
|
||||
void resize(
|
||||
TensorCoord const &extent, ///< extent of logical tensor
|
||||
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
||||
|
||||
resize(extent, Layout::packed(extent), device_backed_);
|
||||
}
|
||||
|
||||
/// Returns the number of elements stored in the host tensor
|
||||
size_t size() const {
|
||||
return host_.size() / 2;
|
||||
}
|
||||
|
||||
/// Returns the logical capacity based on extent and layout. May differ from size().
|
||||
LongIndex capacity() const {
|
||||
return layout_.capacity(extent_);
|
||||
}
|
||||
|
||||
/// Stride between real and imaginary parts
|
||||
LongIndex imaginary_stride() const {
|
||||
return host_.size() / 2;
|
||||
}
|
||||
|
||||
/// Gets pointer to host data
|
||||
Element * host_data() { return host_.data(); }
|
||||
|
||||
/// Gets pointer to host data imaginary part
|
||||
Element * host_data_imag() { return host_.data() + imaginary_stride(); }
|
||||
|
||||
/// Gets pointer to host data with a pointer offset
|
||||
Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; }
|
||||
|
||||
/// Gets pointer to host data with a pointer offset
|
||||
Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; }
|
||||
|
||||
/// Gets a reference to an element in host memory
|
||||
Reference host_data(LongIndex idx) {
|
||||
return PlanarComplexReference<Element>(host_data() + idx, host_data_imag() + idx);
|
||||
}
|
||||
|
||||
/// Gets pointer to host data
|
||||
Element const * host_data() const { return host_.data(); }
|
||||
|
||||
/// Gets pointer to host data imaginary part
|
||||
Element const * host_data_imag() const { return host_.data() + imaginary_stride(); }
|
||||
|
||||
/// Gets a constant reference to an element in host memory
|
||||
ConstReference host_data(LongIndex idx) const {
|
||||
return PlanarComplexReference<Element const>(host_data() + idx, host_data_imag() + idx);
|
||||
}
|
||||
|
||||
/// Gets pointer to device data
|
||||
Element * device_data() { return device_.get(); }
|
||||
|
||||
/// Gets pointer to device data with a pointer offset
|
||||
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; }
|
||||
|
||||
/// Gets pointer to device data
|
||||
Element const * device_data() const { return device_.get(); }
|
||||
|
||||
/// Gets pointer to device data with a pointer offset
|
||||
Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
TensorRef host_ref(LongIndex ptr_element_offset=0) {
|
||||
return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
||||
}
|
||||
|
||||
/// Returns a tensor reference to the real part of the tensor
|
||||
cutlass::TensorRef<Element, Layout> host_ref_real() {
|
||||
return cutlass::TensorRef<Element, Layout>(host_data(), layout_);
|
||||
}
|
||||
|
||||
/// Returns a tensor reference to the real part of the tensor
|
||||
cutlass::TensorRef<Element, Layout> host_ref_imag() {
|
||||
return cutlass::TensorRef<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const {
|
||||
return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
TensorRef device_ref(LongIndex ptr_element_offset=0) {
|
||||
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
|
||||
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
||||
}
|
||||
|
||||
/// Returns a tensor reference to the real part of the tensor
|
||||
cutlass::TensorRef<Element, Layout> device_ref_real() {
|
||||
return cutlass::TensorRef<Element, Layout>(device_data(), layout_);
|
||||
}
|
||||
|
||||
/// Returns a tensor reference to the real part of the tensor
|
||||
cutlass::TensorRef<Element, Layout> device_ref_imag() {
|
||||
return cutlass::TensorRef<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
TensorView host_view(LongIndex ptr_element_offset=0) {
|
||||
return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
|
||||
return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
cutlass::TensorView<Element, Layout> host_view_real() {
|
||||
return cutlass::TensorView<Element, Layout>(host_data(), layout_, extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
cutlass::TensorView<Element, Layout> host_view_imag() {
|
||||
return cutlass::TensorView<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_, extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
TensorView device_view(LongIndex ptr_element_offset=0) {
|
||||
return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
|
||||
return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
cutlass::TensorView<Element, Layout> device_view_real() {
|
||||
return cutlass::TensorView<Element, Layout>(device_data(), layout_, extent_);
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
cutlass::TensorView<Element, Layout> device_view_imag() {
|
||||
return cutlass::TensorView<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_, extent_);
|
||||
}
|
||||
|
||||
/// Returns true if device memory is allocated
|
||||
bool device_backed() const {
|
||||
return (device_.get() == nullptr) ? false : true;
|
||||
}
|
||||
|
||||
/// Returns the layout object
|
||||
Layout layout() const {
|
||||
return layout_;
|
||||
}
|
||||
|
||||
/// Returns the layout object's stride vector
|
||||
Stride stride() const {
|
||||
return layout_.stride();
|
||||
}
|
||||
|
||||
/// Returns the layout object's stride in a given physical dimension
|
||||
Index stride(int dim) const {
|
||||
return layout_.stride().at(dim);
|
||||
}
|
||||
|
||||
/// Computes the offset of an index from the origin of the tensor
|
||||
LongIndex offset(TensorCoord const& coord) const {
|
||||
return layout_(coord);
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at the logical Coord in host memory
|
||||
Reference at(TensorCoord const& coord) {
|
||||
return host_data(offset(coord));
|
||||
}
|
||||
|
||||
/// Returns a const reference to the element at the logical Coord in host memory
|
||||
ConstReference at(TensorCoord const& coord) const {
|
||||
return host_data(offset(coord));
|
||||
}
|
||||
|
||||
/// Returns the extent of the tensor
|
||||
TensorCoord extent() const {
|
||||
return extent_;
|
||||
}
|
||||
|
||||
/// Returns the extent of the tensor
|
||||
TensorCoord & extent() {
|
||||
return extent_;
|
||||
}
|
||||
|
||||
/// Copies data from device to host
|
||||
void sync_host() {
|
||||
if (device_backed()) {
|
||||
device_memory::copy_to_host(
|
||||
host_data(), device_data(), imaginary_stride() * 2);
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies data from host to device
|
||||
void sync_device() {
|
||||
if (device_backed()) {
|
||||
device_memory::copy_to_device(
|
||||
device_data(), host_data(), imaginary_stride() * 2);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -0,0 +1,306 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued GEMM in device code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_ref_planar_complex.h"
|
||||
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static int const kGemmPlanarComplexBlockSize = 4;
|
||||
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
||||
>
|
||||
__global__ void GemmPlanarComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
complex<ScalarType> alpha,
|
||||
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
complex<ScalarType> beta,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
||||
complex<ComputeType> initial_accum) {
|
||||
|
||||
int const kMblock = kGemmPlanarComplexBlockSize;
|
||||
int const kNblock = kGemmPlanarComplexBlockSize;
|
||||
|
||||
using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
|
||||
using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
|
||||
using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
|
||||
|
||||
// Note: batch is ignored.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
|
||||
complex<ComputeType> accum[kMblock][kNblock];
|
||||
|
||||
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
||||
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
|
||||
ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
|
||||
|
||||
complex<ComputeType> a = complex<ComputeType>{
|
||||
ComputeType(a_ik.real()),
|
||||
ComputeType(a_ik.imag())
|
||||
};
|
||||
|
||||
complex<ComputeType> b = complex<ComputeType>{
|
||||
ComputeType(b_kj.real()),
|
||||
ComputeType(b_kj.imag())
|
||||
};
|
||||
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a = conj(a);
|
||||
}
|
||||
|
||||
if (transform_b == ComplexTransform::kConjugate) {
|
||||
b = conj(b);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(a, b, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
complex<ScalarType> acc{
|
||||
ScalarType(accum[i][j].real()),
|
||||
ScalarType(accum[i][j].imag())
|
||||
};
|
||||
|
||||
ComplexC c_ij = ComplexC();
|
||||
|
||||
if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
|
||||
c_ij = tensor_c.at(coord);
|
||||
}
|
||||
|
||||
complex<ScalarType> src{
|
||||
ScalarType(c_ij.real()),
|
||||
ScalarType(c_ij.imag())
|
||||
};
|
||||
|
||||
complex<ScalarType> result = alpha * acc + beta * src;
|
||||
|
||||
ComplexC d_ij;
|
||||
|
||||
d_ij.real() = convert_op(result.real());
|
||||
d_ij.imag() = convert_op(result.imag());;
|
||||
|
||||
tensor_d.at(coord) = d_ij;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
||||
>
|
||||
void GemmPlanarComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
complex<ScalarType> alpha,
|
||||
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
complex<ScalarType> beta,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
||||
complex<ComputeType> initial_accum) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
int const kMblock = kernel::kGemmPlanarComplexBlockSize;
|
||||
int const kNblock = kernel::kGemmPlanarComplexBlockSize;
|
||||
|
||||
dim3 block(16, 8);
|
||||
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
||||
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
||||
1);
|
||||
|
||||
kernel::GemmPlanarComplex<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ScalarType,
|
||||
ComputeType,
|
||||
ConvertOp,
|
||||
InnerProductOp
|
||||
><<< grid, block >>>(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a,
|
||||
transform_a,
|
||||
tensor_b,
|
||||
transform_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
tensor_d,
|
||||
initial_accum
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// This assumes the accumulator type is the same type as the scalars.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType
|
||||
>
|
||||
void GemmPlanarComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
complex<ScalarType> alpha,
|
||||
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
complex<ScalarType> beta,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
|
||||
|
||||
GemmPlanarComplex(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a, transform_a,
|
||||
tensor_b, transform_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
tensor_d,
|
||||
complex<ScalarType>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -43,12 +43,12 @@
|
||||
#endif
|
||||
|
||||
// CUDA includes
|
||||
#include <cublas_v2.h>
|
||||
#include <curand_kernel.h>
|
||||
|
||||
// Cutlass includes
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
|
||||
#include "cutlass/util/reference/device/tensor_foreach.h"
|
||||
@ -169,6 +169,95 @@ struct RandomGaussianFunc {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Real>
|
||||
struct RandomGaussianFunc<complex<Real>> {
|
||||
|
||||
using Element = complex<Real>;
|
||||
using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type;
|
||||
using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
uint64_t seed;
|
||||
FloatType mean;
|
||||
FloatType stddev;
|
||||
int int_scale;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construction of Gaussian RNG functor.
|
||||
Params(
|
||||
uint64_t seed_ = 0,
|
||||
Real mean_ = 0,
|
||||
Real stddev_ = 1,
|
||||
int int_scale_ = -1
|
||||
):
|
||||
seed(seed_),
|
||||
mean(static_cast<FloatType>(mean_)),
|
||||
stddev(static_cast<FloatType>(stddev_)),
|
||||
int_scale(int_scale_) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
/// RNG state object
|
||||
curandState_t rng_state;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Device-side initialization of RNG
|
||||
CUTLASS_DEVICE
|
||||
RandomGaussianFunc(Params const ¶ms): params(params) {
|
||||
|
||||
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
curand_init(params.seed, gtid, 0, &rng_state);
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
CUTLASS_DEVICE
|
||||
Element operator()() {
|
||||
|
||||
FloatType rnd_r = random_normal_float<FloatType>(&rng_state);
|
||||
FloatType rnd_i = random_normal_float<FloatType>(&rng_state);
|
||||
rnd_r = params.mean + params.stddev * rnd_r;
|
||||
rnd_i = params.mean + params.stddev * rnd_i;
|
||||
|
||||
Element result;
|
||||
if (params.int_scale >= 0) {
|
||||
rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale)));
|
||||
rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale)));
|
||||
|
||||
result = {
|
||||
Real(rnd_r / FloatType(IntType(1) << params.int_scale)),
|
||||
Real(rnd_i / FloatType(IntType(1) << params.int_scale))
|
||||
};
|
||||
}
|
||||
else {
|
||||
result = Element(Real(rnd_r), Real(rnd_i));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes a random Gaussian distribution
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
@ -269,12 +358,12 @@ template <typename Element> ///< Element type
|
||||
void BlockFillRandomGaussian(
|
||||
Element *ptr,
|
||||
size_t capacity,
|
||||
uint64_t seed, ///< seed for RNG
|
||||
Element mean = Element(0), ///< Gaussian distribution's mean
|
||||
Element stddev = Element(1), ///< Gaussian distribution's standard deviation
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
uint64_t seed, ///< seed for RNG
|
||||
typename RealType<Element>::Type mean, ///< Gaussian distribution's mean
|
||||
typename RealType<Element>::Type stddev, ///< Gaussian distribution's standard deviation
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
|
||||
using RandomFunc = detail::RandomGaussianFunc<Element>;
|
||||
|
||||
@ -383,6 +472,111 @@ struct RandomUniformFunc {
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes a random Gaussian distribution
|
||||
template <typename Real> ///< Layout function
|
||||
struct RandomUniformFunc<complex<Real>> {
|
||||
|
||||
using Element = complex<Real>;
|
||||
|
||||
using FloatType = typename std::conditional<
|
||||
(sizeof(Real) > 4),
|
||||
double,
|
||||
float>::type;
|
||||
|
||||
using IntType = typename std::conditional<
|
||||
(sizeof(Real) > 4),
|
||||
int64_t,
|
||||
int>::type;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
uint64_t seed;
|
||||
FloatType range;
|
||||
FloatType min;
|
||||
int int_scale;
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construction of Gaussian RNG functor.
|
||||
Params(
|
||||
uint64_t seed_ = 0,
|
||||
FloatType max = 1,
|
||||
FloatType min_ = 0,
|
||||
int int_scale_ = -1
|
||||
):
|
||||
seed(seed_),
|
||||
range(static_cast<FloatType>(max - min_)),
|
||||
min(static_cast<FloatType>(min_)),
|
||||
int_scale(int_scale_) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
/// RNG state object
|
||||
curandState_t rng_state;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Device-side initialization of RNG
|
||||
CUTLASS_DEVICE
|
||||
RandomUniformFunc(Params const ¶ms): params(params) {
|
||||
|
||||
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
curand_init(params.seed, gtid, 0, &rng_state);
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
CUTLASS_DEVICE
|
||||
Element operator()() {
|
||||
|
||||
FloatType rnd_r = random_uniform_float<FloatType>(&rng_state);
|
||||
FloatType rnd_i = random_uniform_float<FloatType>(&rng_state);
|
||||
|
||||
rnd_r = params.min + params.range * rnd_r;
|
||||
rnd_i = params.min + params.range * rnd_i;
|
||||
|
||||
// Random values are cast to integer after scaling by a power of two to facilitate error
|
||||
// testing
|
||||
Element result;
|
||||
|
||||
if (params.int_scale >= 0) {
|
||||
rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale)));
|
||||
rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale)));
|
||||
|
||||
result = {
|
||||
Real(rnd_r / FloatType(IntType(1) << params.int_scale)),
|
||||
Real(rnd_i / FloatType(IntType(1) << params.int_scale))
|
||||
};
|
||||
}
|
||||
else {
|
||||
result = Element(Real(rnd_r), Real(rnd_i));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes a random Gaussian distribution
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
@ -489,8 +683,8 @@ void BlockFillRandomUniform(
|
||||
Element *ptr,
|
||||
size_t capacity,
|
||||
uint64_t seed, ///< seed for RNG
|
||||
Element max = Element(1), ///< upper bound of distribution
|
||||
Element min = Element(0), ///< lower bound for distribution
|
||||
typename RealType<Element>::Type max, ///< upper bound of distribution
|
||||
typename RealType<Element>::Type min, ///< lower bound for distribution
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
@ -976,13 +1170,15 @@ void BlockFillRandom(
|
||||
uint64_t seed,
|
||||
Distribution dist) {
|
||||
|
||||
using Real = typename RealType<Element>::Type;
|
||||
|
||||
if (dist.kind == Distribution::Gaussian) {
|
||||
BlockFillRandomGaussian<Element>(
|
||||
ptr,
|
||||
capacity,
|
||||
seed,
|
||||
static_cast<Element>(dist.gaussian.mean),
|
||||
static_cast<Element>(dist.gaussian.stddev),
|
||||
static_cast<Real>(dist.gaussian.mean),
|
||||
static_cast<Real>(dist.gaussian.stddev),
|
||||
dist.int_scale);
|
||||
}
|
||||
else if (dist.kind == Distribution::Uniform) {
|
||||
@ -990,8 +1186,8 @@ void BlockFillRandom(
|
||||
ptr,
|
||||
capacity,
|
||||
seed,
|
||||
static_cast<Element>(dist.uniform.max),
|
||||
static_cast<Element>(dist.uniform.min),
|
||||
static_cast<Real>(dist.uniform.max),
|
||||
static_cast<Real>(dist.uniform.min),
|
||||
dist.int_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,6 +72,7 @@ void GemmComplex(
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d,
|
||||
ComputeType initial_accum) {
|
||||
|
||||
static_assert(
|
||||
@ -138,7 +139,7 @@ void GemmComplex(
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
tensor_c.at(coord) = convert_op(
|
||||
tensor_d.at(coord) = convert_op(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * ScalarType(tensor_c.at(coord)));
|
||||
}
|
||||
@ -171,9 +172,10 @@ void GemmComplex(
|
||||
TensorRef<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
ScalarType beta,
|
||||
TensorRef<ElementC, LayoutC> tensor_c) {
|
||||
TensorRef<ElementC, LayoutC> tensor_c,
|
||||
TensorRef<ElementC, LayoutC> tensor_d) {
|
||||
|
||||
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0));
|
||||
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,223 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for complex-valued GEMM in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/tensor_ref_planar_complex.h"
|
||||
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType,
|
||||
typename ComputeType,
|
||||
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
||||
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
||||
>
|
||||
void GemmPlanarComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
complex<ScalarType> alpha,
|
||||
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
complex<ScalarType> beta,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
||||
complex<ComputeType> initial_accum) {
|
||||
|
||||
static_assert(
|
||||
LayoutA::kRank == 2 &&
|
||||
LayoutB::kRank == 2 &&
|
||||
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
|
||||
using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
|
||||
using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
|
||||
|
||||
// Note: batch is ignored.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking necessary to speedup reference implementation
|
||||
int const Mblock = 16;
|
||||
int const Nblock = 16;
|
||||
|
||||
ConvertOp convert_op;
|
||||
InnerProductOp inner_product_op;
|
||||
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
complex<ComputeType> accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
|
||||
ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
|
||||
|
||||
complex<ComputeType> a = complex<ComputeType>{
|
||||
ComputeType(a_ik.real()),
|
||||
ComputeType(a_ik.imag())
|
||||
};
|
||||
|
||||
complex<ComputeType> b = complex<ComputeType>{
|
||||
ComputeType(b_kj.real()),
|
||||
ComputeType(b_kj.imag())
|
||||
};
|
||||
|
||||
if (transform_a == ComplexTransform::kConjugate) {
|
||||
a = conj(a);
|
||||
}
|
||||
|
||||
if (transform_b == ComplexTransform::kConjugate) {
|
||||
b = conj(b);
|
||||
}
|
||||
|
||||
accum[i][j] = inner_product_op(a, b, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
complex<ScalarType> acc{
|
||||
ScalarType(accum[i][j].real()),
|
||||
ScalarType(accum[i][j].imag())
|
||||
};
|
||||
|
||||
ComplexC d_ij = tensor_c.at(coord);
|
||||
|
||||
complex<ScalarType> src{
|
||||
ScalarType(d_ij.real()),
|
||||
ScalarType(d_ij.imag())
|
||||
};
|
||||
|
||||
complex<ScalarType> result = alpha * acc + beta * src;
|
||||
|
||||
d_ij.real() = convert_op(result.real());
|
||||
d_ij.imag() = convert_op(result.imag());;
|
||||
|
||||
tensor_d.at(coord) = d_ij;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// This assumes the accumulator type is the same type as the scalars.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ScalarType
|
||||
>
|
||||
void GemmPlanarComplex(
|
||||
gemm::GemmCoord problem_size,
|
||||
complex<ScalarType> alpha,
|
||||
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
||||
ComplexTransform transform_a,
|
||||
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
||||
ComplexTransform transform_b,
|
||||
complex<ScalarType> beta,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
||||
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
|
||||
|
||||
GemmPlanarComplex(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a, transform_a,
|
||||
tensor_b, transform_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
tensor_d,
|
||||
complex<ScalarType>());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -33,6 +33,9 @@
|
||||
|
||||
// Cutlass includes
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/tensor_view_planar_complex.h"
|
||||
|
||||
#include "cutlass/util/distribution.h"
|
||||
//#include "cutlass/util/type_traits.h"
|
||||
#include "tensor_foreach.h"
|
||||
@ -112,6 +115,46 @@ bool TensorEquals(
|
||||
return bool(func);
|
||||
}
|
||||
|
||||
/// Returns true if two tensor views are equal.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
bool TensorEquals(
|
||||
TensorViewPlanarComplex<Element, Layout> const &lhs,
|
||||
TensorViewPlanarComplex<Element, Layout> const &rhs) {
|
||||
|
||||
// Extents must be identical
|
||||
if (lhs.extent() != rhs.extent()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
detail::TensorEqualsFunc<Element, Layout> real_func(
|
||||
{lhs.data(), lhs.layout(), lhs.extent()},
|
||||
{rhs.data(), rhs.layout(), rhs.extent()}
|
||||
);
|
||||
|
||||
TensorForEach(
|
||||
lhs.extent(),
|
||||
real_func
|
||||
);
|
||||
|
||||
if (!bool(real_func)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
detail::TensorEqualsFunc<Element, Layout> imag_func(
|
||||
{lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
|
||||
{rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}
|
||||
);
|
||||
|
||||
TensorForEach(
|
||||
lhs.extent(),
|
||||
imag_func
|
||||
);
|
||||
|
||||
return bool(imag_func);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -137,6 +180,17 @@ bool TensorNotEquals(
|
||||
return !bool(func);
|
||||
}
|
||||
|
||||
/// Returns true if two tensor views are equal.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
bool TensorNotEquals(
|
||||
TensorViewPlanarComplex<Element, Layout> const &lhs,
|
||||
TensorViewPlanarComplex<Element, Layout> const &rhs) {
|
||||
|
||||
return !TensorEquals(lhs, rhs);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -38,6 +38,8 @@
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/tensor_view_planar_complex.h"
|
||||
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "tensor_foreach.h"
|
||||
@ -101,6 +103,18 @@ void TensorFill(
|
||||
);
|
||||
}
|
||||
|
||||
/// Fills a tensor with a uniform value
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFill(
|
||||
TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
|
||||
cutlass::complex<Element> val = cutlass::complex<Element>(0)) { ///< value to uniformly fill it with
|
||||
|
||||
TensorFill(dst.view_real(), val.real());
|
||||
TensorFill(dst.view_imag(), val.imag());
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -268,6 +282,23 @@ void TensorFillRandomGaussian(
|
||||
);
|
||||
}
|
||||
|
||||
/// Fills a tensor with random values with a Gaussian distribution.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFillRandomGaussian(
|
||||
TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
|
||||
uint64_t seed, ///< seed for RNG
|
||||
double mean = 0, ///< Gaussian distribution's mean
|
||||
double stddev = 1, ///< Gaussian distribution's standard deviation
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
|
||||
TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits);
|
||||
TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with random values with a Gaussian distribution.
|
||||
@ -461,6 +492,23 @@ void TensorFillRandomUniform(
|
||||
);
|
||||
}
|
||||
|
||||
/// Fills a tensor with random values with a uniform random distribution.
|
||||
template <
|
||||
typename Element, ///< Element type
|
||||
typename Layout> ///< Layout function
|
||||
void TensorFillRandomUniform(
|
||||
TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
|
||||
uint64_t seed, ///< seed for RNG
|
||||
double max = 1, ///< upper bound of distribution
|
||||
double min = 0, ///< lower bound for distribution
|
||||
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
|
||||
TensorFillRandomUniform(dst.view_real(), seed, max, min, bits);
|
||||
TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with random values with a uniform random distribution.
|
||||
@ -774,6 +822,27 @@ void BlockFillSequential(
|
||||
}
|
||||
}
|
||||
|
||||
/// Fills a block of data with sequential elements
|
||||
template <
|
||||
typename Element
|
||||
>
|
||||
void BlockFillSequentialModN(
|
||||
Element *ptr,
|
||||
int64_t capacity,
|
||||
int64_t mod,
|
||||
int64_t v = int64_t(1),
|
||||
int64_t s = int64_t(0)) {
|
||||
int i = 0;
|
||||
|
||||
while (i < capacity) {
|
||||
cutlass::ReferenceFactory<Element, (cutlass::sizeof_bits<Element>::value <
|
||||
8)>::get(ptr, i) = Element(s);
|
||||
|
||||
s = int64_t(s + v) % mod;
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -26,6 +26,8 @@
|
||||
|
||||
#include "cutlass/core_io.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/tensor_view_planar_complex.h"
|
||||
#include "cutlass/complex.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -87,13 +89,13 @@ inline std::ostream & TensorView_WriteRank(
|
||||
coord[rank] = idx;
|
||||
|
||||
if (rank + 2 == Layout::kRank) {
|
||||
// Write least significant ranks asa matrix with rows delimited by ";\n"
|
||||
out << (idx ? ";\n" : "");
|
||||
// Write least significant ranks asa matrix with rows delimited by "\n"
|
||||
out << (idx ? ",\n" : "");
|
||||
TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width);
|
||||
}
|
||||
else {
|
||||
// Higher ranks are separated by newlines
|
||||
out << (idx ? "\n" : "");
|
||||
out << (idx ? ",\n\n" : "");
|
||||
TensorView_WriteRank(out, view, coord, rank + 1, width);
|
||||
}
|
||||
}
|
||||
@ -101,6 +103,76 @@ inline std::ostream & TensorView_WriteRank(
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Helper to write the least significant rank of a TensorView
|
||||
template <
|
||||
typename Element,
|
||||
typename Layout
|
||||
>
|
||||
inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank(
|
||||
std::ostream& out,
|
||||
TensorViewPlanarComplex<Element, Layout> const& view,
|
||||
Coord<Layout::kRank> const &start_coord,
|
||||
int rank,
|
||||
std::streamsize width) {
|
||||
|
||||
for (int idx = 0; idx < view.extent(rank); ++idx) {
|
||||
|
||||
Coord<Layout::kRank> coord(start_coord);
|
||||
coord[rank] = idx;
|
||||
|
||||
if (idx) {
|
||||
out.width(0);
|
||||
out << ", ";
|
||||
}
|
||||
if (idx || coord) {
|
||||
out.width(width);
|
||||
}
|
||||
|
||||
complex<Element> x = view.at(coord);
|
||||
out << x;
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Helper to write a rank of a TensorView
|
||||
template <
|
||||
typename Element,
|
||||
typename Layout
|
||||
>
|
||||
inline std::ostream & TensorViewPlanarComplex_WriteRank(
|
||||
std::ostream& out,
|
||||
TensorViewPlanarComplex<Element, Layout> const& view,
|
||||
Coord<Layout::kRank> const &start_coord,
|
||||
int rank,
|
||||
std::streamsize width) {
|
||||
|
||||
// If called on the least significant rank, write the result as a row
|
||||
if (rank + 1 == Layout::kRank) {
|
||||
return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width);
|
||||
}
|
||||
|
||||
// Otherwise, write a sequence of rows and newlines
|
||||
for (int idx = 0; idx < view.extent(rank); ++idx) {
|
||||
|
||||
Coord<Layout::kRank> coord(start_coord);
|
||||
coord[rank] = idx;
|
||||
|
||||
if (rank + 2 == Layout::kRank) {
|
||||
// Write least significant ranks asa matrix with rows delimited by ";\n"
|
||||
out << (idx ? ";\n" : "");
|
||||
TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width);
|
||||
}
|
||||
else {
|
||||
// Higher ranks are separated by newlines
|
||||
out << (idx ? "\n" : "");
|
||||
TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width);
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -143,4 +215,42 @@ inline std::ostream& operator<<(
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Prints human-readable representation of a TensorView to an ostream
|
||||
template <
|
||||
typename Element,
|
||||
typename Layout
|
||||
>
|
||||
inline std::ostream& TensorViewWrite(
|
||||
std::ostream& out,
|
||||
TensorViewPlanarComplex<Element, Layout> const& view) {
|
||||
|
||||
// Prints a TensorView according to the following conventions:
|
||||
// - least significant rank is printed as rows separated by ";\n"
|
||||
// - all greater ranks are delimited with newlines
|
||||
//
|
||||
// The result is effectively a whitespace-delimited series of 2D matrices.
|
||||
|
||||
return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord<Layout::kRank>(), 0, out.width());
|
||||
}
|
||||
|
||||
/// Prints human-readable representation of a TensorView to an ostream
|
||||
template <
|
||||
typename Element,
|
||||
typename Layout
|
||||
>
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
TensorViewPlanarComplex<Element, Layout> const& view) {
|
||||
|
||||
// Prints a TensorView according to the following conventions:
|
||||
// - least significant rank is printed as rows separated by ";\n"
|
||||
// - all greater ranks are delimited with newlines
|
||||
//
|
||||
// The result is effectively a whitespace-delimited series of 2D matrices.
|
||||
|
||||
return TensorViewWrite(out, view);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
Reference in New Issue
Block a user