v4.3 update. (#2709)

* v4.3 update.
* Update the cute_dsl_api changelog's doc link
* Update version to 4.3.0
* Update the example link
* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
@@ -35,6 +35,9 @@ include(GNUInstallDirs)
 set(CUTLASS_BUILD_MONO_LIBRARY OFF CACHE BOOL
   "Determines whether the cutlass library is generated as a single file or multiple files.")
 
+option(CUTLASS_BUILD_SHARED_LIBS "Build shared libraries" ON)
+option(CUTLASS_BUILD_STATIC_LIBS "Build static libraries" ON)
+
 ################################################################################
 
 add_library(cutlass_library_includes INTERFACE)
@@ -62,7 +65,7 @@ install(
 
 install(
   DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
-  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/
+  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
 )
 
 add_library(cutlass_library_internal_interface INTERFACE)
@@ -123,88 +126,98 @@ function(cutlass_add_cutlass_library)
   if (CUTLASS_BUILD_MONO_LIBRARY AND __SUFFIX)
 
     # If we're only building a single monolithic library then we
     # simply link the generated object files to the default library.
-    target_link_libraries(${DEFAULT_NAME} PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
-    target_link_libraries(${DEFAULT_NAME}_static PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
+    if(CUTLASS_BUILD_SHARED_LIBS)
+      target_link_libraries(${DEFAULT_NAME} PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
+    endif()
+
+    if(CUTLASS_BUILD_STATIC_LIBS)
+      target_link_libraries(${DEFAULT_NAME}_static PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
+    endif()
 
   else()
 
-    cutlass_add_library(
-      ${__NAME}
-      SHARED
-      EXPORT_NAME ${__EXPORT_NAME}
-      ""
-      )
-
-    target_compile_features(${__NAME} INTERFACE cxx_std_17)
-
-    set_target_properties(
-      ${__NAME}
-      PROPERTIES
-      OUTPUT_NAME ${__OUTPUT_NAME}
-      WINDOWS_EXPORT_ALL_SYMBOLS 1
-      )
-
-    target_link_libraries(
-      ${__NAME}
-      PUBLIC cutlass_library_includes
-      PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
-      cuda_driver
-      )
-
-    set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}")
-
-    cutlass_add_library(
-      ${__NAME}_static
-      STATIC
-      EXPORT_NAME ${__EXPORT_NAME}_static
-      ""
-      )
-
-    target_compile_features(${__NAME}_static INTERFACE cxx_std_17)
-
-    if (WIN32)
-      set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static)
-    else()
-      set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME})
-    endif()
-
-    set_target_properties(
-      ${__NAME}_static
-      PROPERTIES
-      OUTPUT_NAME ${STATIC_OUTPUT_NAME}
-      WINDOWS_EXPORT_ALL_SYMBOLS 1
-      )
-
-    target_link_libraries(
-      ${__NAME}_static
-      PUBLIC cutlass_library_includes
-      PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
-      cuda_driver
-      )
-
-    set_target_properties(${__NAME}_static PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}")
-
-    install(
-      TARGETS ${__NAME} ${__NAME}_static
-      EXPORT NvidiaCutlass
-      RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-      LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
-      ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
-      )
-
-    if (__SUFFIX)
-
-      # The partial libraries generated will be registered as linked libraries
-      # to the main cutlass library so users automatically get the necessary link
-      # commands to pull in all kernels by default.
-
-      target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME})
-      target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static)
-
-    endif()
+    # Shared library (honors CMake's standard CUTLASS_BUILD_SHARED_LIBS)
+    if(CUTLASS_BUILD_SHARED_LIBS)
+      cutlass_add_library(
+        ${__NAME}
+        SHARED
+        EXPORT_NAME ${__EXPORT_NAME}
+        ""
+        )
+
+      target_compile_features(${__NAME} INTERFACE cxx_std_17)
+
+      set_target_properties(${__NAME}
+        PROPERTIES
+        OUTPUT_NAME ${__OUTPUT_NAME}
+        WINDOWS_EXPORT_ALL_SYMBOLS 1
+        DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}"
+        )
+
+      target_link_libraries(${__NAME}
+        PUBLIC cutlass_library_includes
+        PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
+        cuda_driver
+        )
+
+      install(
+        TARGETS ${__NAME}
+        EXPORT NvidiaCutlass
+        RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+        LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+        ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+        )
+
+      # The partial libraries generated will be registered as linked libraries
+      # to the main cutlass library so users automatically get the necessary link
+      # commands to pull in all kernels by default.
+      if (__SUFFIX)
+        target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME})
+      endif()
+    endif()
+
+    # Static library
+    if(CUTLASS_BUILD_STATIC_LIBS)
+      cutlass_add_library(
+        ${__NAME}_static
+        STATIC
+        EXPORT_NAME ${__EXPORT_NAME}_static
+        ""
+        )
+
+      target_compile_features(${__NAME}_static INTERFACE cxx_std_17)
+
+      if (WIN32)
+        set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static)
+      else()
+        set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME})
+      endif()
+
+      set_target_properties(
+        ${__NAME}_static
+        PROPERTIES
+        OUTPUT_NAME ${STATIC_OUTPUT_NAME}
+        WINDOWS_EXPORT_ALL_SYMBOLS 1
+        DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}"
+        )
+
+      target_link_libraries(
+        ${__NAME}_static
+        PUBLIC cutlass_library_includes
+        PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
+        cuda_driver
+        )
+
+      install(
+        TARGETS ${__NAME}_static
+        EXPORT NvidiaCutlass
+        RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+        LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+        ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+        )
+
+      if (__SUFFIX)
+        target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static)
+      endif()
+    endif()
 
   endif()
 
 endfunction()
@@ -268,8 +281,13 @@ cutlass_add_cutlass_library(
 )
 
 # For backward compatibility with the old name
-add_library(cutlass_lib ALIAS cutlass_library)
-add_library(cutlass_lib_static ALIAS cutlass_library_static)
+if(CUTLASS_BUILD_SHARED_LIBS)
+  add_library(cutlass_lib ALIAS cutlass_library)
+endif()
+
+if(CUTLASS_BUILD_STATIC_LIBS)
+  add_library(cutlass_lib_static ALIAS cutlass_library_static)
+endif()
 
 ################################################################################
@@ -348,6 +348,9 @@ struct BlockScaledGemmDescription : public OperationDescription {
   /// Describes the destination matrix
   TensorDescription D;
 
+  /// Describes the sparse meta matrices
+  TensorDescription E;
+
   /// Describes the SFA operand
   TensorDescription SFA;
@@ -392,10 +392,11 @@ struct BlockScaledGemmArguments {
 
   library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
   library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};
 
+  int device_index{0};
   bool use_pdl{false};
 };
 
 
 /// Blockwise GEMM
 //
 // OperationKind: kBlockwiseGemm
@@ -209,6 +209,7 @@ enum class GemmKind {
   kPlanarComplex,
   kPlanarComplexArray,
   kGrouped,
+  kBlockScaledSparseGemm,
   kInvalid
 };
@@ -41,6 +41,8 @@
 #include "cutlass/library/library.h"
 #include "library_internal.h"
 #include "gemm_operation_3x.hpp"
+#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
+#include "cutlass/transform/device/transform_universal_adapter.hpp"
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
 namespace cutlass::library {
@@ -48,7 +50,7 @@ namespace cutlass::library {
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename Operator_>
-class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
+class BlockScaledGemmUniversal3xOperationBase : public GemmOperation3xBase<Operator_> {
 public:
   using Operator = Operator_;
   using OperatorArguments = typename Operator::Arguments;
@@ -92,15 +94,9 @@ public:
   static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
   using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA;
   using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB;
 
-private:
-  BlockScaledGemmDescription description_;
-
-public:
-
   /// Constructor
-  BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"):
+  BlockScaledGemmUniversal3xOperationBase(char const *name = "unknown_gemm"):
     GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
     description_.kind = OperationKind::kBlockScaledGemm;
     description_.SFA.element = NumericTypeMap<ElementSFA>::kId;
@@ -182,38 +178,14 @@ public:
   BlockScaledGemmDescription const& get_gemm_description() const {
     return description_;
   }
 
 protected:
 
-  /// Constructs the arguments structure given the configuration and arguments
-  static Status construct_arguments_(
-    OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
-    // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
-    // Do nothing here and construct kernel arguments in update_arguments_ instead
-    // We also cannot construct TMA descriptors without all the arguments available
-
-    operator_args.mode = configuration->mode;
-    return Status::kSuccess;
-  }
-
-  template<class FusionArgs, class = void>
-  struct UpdateFusionArgs {
-    static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
-      // If a custom EVT is instantiated then it is the users's responsibility
-      // to ensure alpha and beta are updated appropriately
-      return Status::kSuccess;
-    }
-  };
-
-  template<class FusionArgs>
-  struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
-    static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
+  BlockScaledGemmDescription description_;
 
+  template <typename FusionArgs>
+  static Status update_fusion_args(FusionArgs& fusion_args, BlockScaledGemmArguments const& arguments) {
     if constexpr (epilogue_scalefactor_generation) {
       fusion_args.block_scale_factor_ptr = static_cast<ElementSFD*>(arguments.SFD);
       fusion_args.norm_constant_ptr = static_cast<ElementCompute const *>(arguments.norm_constant);
     }
 
     if (arguments.pointer_mode == ScalarPointerMode::kHost) {
       fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
@@ -234,21 +206,12 @@ protected:
     else {
       return Status::kErrorInvalidProblem;
     }
-    }
-  };
+  }
 
-  /// Constructs the arguments structure given the configuration and arguments
-  static Status update_arguments_(
-    OperatorArguments &operator_args,
-    BlockScaledGemmArguments const *arguments) {
-    Status status = Status::kSuccess;
-
-    status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
-      operator_args.epilogue.thread, *arguments);
-    if (status != Status::kSuccess) {
-      return status;
-    }
-
+  static Status update_arguments_base(
+    OperatorArguments& operator_args,
+    BlockScaledGemmArguments const* arguments) {
     operator_args.problem_shape = cute::make_shape(
       arguments->problem_size.m(),
       arguments->problem_size.n(),
@@ -256,11 +219,10 @@ protected:
       arguments->batch_count);
 
-    // update arguments
-
+
     if constexpr (IsRuntimeDataType) {
       using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
       using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
       operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
       operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
 
       using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA;
@@ -298,17 +260,12 @@ protected:
 
     }
     else {
-
-      operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
       operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
     }
     operator_args.mainloop.ptr_SFA = static_cast<ElementSFA const *>(arguments->SFA);
     operator_args.mainloop.ptr_SFB = static_cast<ElementSFB const *>(arguments->SFB);
     operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
     operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
 
-    operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
-      arguments->lda, arguments->batch_stride_A);
     operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
       arguments->ldb, arguments->batch_stride_B);
     operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
@@ -353,7 +310,74 @@ protected:
       arguments->cluster_shape_fallback.n(),
       arguments->cluster_shape_fallback.k());
     }
 
     return Status::kSuccess;
   }
+};
+
+template <typename Operator_>
+class BlockScaledGemmUniversal3xOperation : public BlockScaledGemmUniversal3xOperationBase<Operator_> {
+public:
+  using Base = BlockScaledGemmUniversal3xOperationBase<Operator_>;
+  using Operator = Operator_;
+  using OperatorArguments = typename Operator::Arguments;
+
+  /// Constructor
+  BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"):
+    BlockScaledGemmUniversal3xOperationBase<Operator_>(name) {}
+
+protected:
+
+  /// Constructs the arguments structure given the configuration and arguments
+  static Status construct_arguments_(
+    OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
+    // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
+    // Do nothing here and construct kernel arguments in update_arguments_ instead
+    // We also cannot construct TMA descriptors without all the arguments available
+
+    operator_args.mode = configuration->mode;
+    return Status::kSuccess;
+  }
+
+  template<class FusionArgs, class = void>
+  struct UpdateFusionArgs {
+    static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
+      // If a custom EVT is instantiated then it is the user's responsibility
+      // to ensure alpha and beta are updated appropriately
+      return Status::kSuccess;
+    }
+  };
+
+  template<class FusionArgs>
+  struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
+    static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
+      return Base::update_fusion_args(fusion_args, arguments);
+    }
+  };
+
+  /// Constructs the arguments structure given the configuration and arguments
+  static Status update_arguments_(
+    OperatorArguments &operator_args,
+    BlockScaledGemmArguments const *arguments) {
+    Status status = Status::kSuccess;
+
+    status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
+      operator_args.epilogue.thread, *arguments);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+    if constexpr (Base::IsRuntimeDataType) {
+      using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
+      operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
+    } else {
+      operator_args.mainloop.ptr_A = static_cast<typename Base::ElementA const *>(arguments->A);
+    }
+    operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
+      arguments->lda, arguments->batch_stride_A);
+    status = Base::update_arguments_base(operator_args, arguments);
+    return status;
+  }
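The hunks above split the former monolithic operation into a shared BlockScaledGemmUniversal3xOperationBase plus a thin dense derived class, so that the sparse variant added below can reuse the common plumbing. A minimal compilable sketch of that split, with invented names (the real code keeps the shared work in a static update_arguments_base helper, as shown):

#include <cstdio>

struct Args { const void* A; const void* A_compressed; };

struct OperationBase {
  // Shared updates (problem shape, B/C/D pointers, ...) would live here.
  static int update_arguments_base(const Args&) { return 0; /* kSuccess */ }
};

struct DenseOperation : OperationBase {
  static int update_arguments(const Args& args) {
    std::printf("dense: ptr_A = %p\n", args.A);  // dense path reads A directly
    return update_arguments_base(args);          // then defers to the base
  }
};

struct SparseOperation : OperationBase {
  static int update_arguments(const Args& args) {
    std::printf("sparse: ptr_A = %p\n", args.A_compressed); // compressed A instead
    return update_arguments_base(args);
  }
};

int main() {
  int x = 0;
  Args args{&x, &x};
  DenseOperation::update_arguments(args);
  SparseOperation::update_arguments(args);
  return 0;
}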
@@ -443,6 +467,306 @@ public:
     return status;
   }
 };
 
+template <typename Operator_>
+class BlockScaledSparseGemmUniversal3xOperation : public BlockScaledGemmUniversal3xOperationBase<Operator_> {
+public:
+  using Base = BlockScaledGemmUniversal3xOperationBase<Operator_>;
+  using Operator = Operator_;
+  using OperatorArguments = typename Operator::Arguments;
+  using ArchTag = typename Operator::ArchTag;
+  using StrideA = cutlass::gemm::TagToStrideA_t<typename Base::LayoutA>;
+  using ElementE = typename Operator::CollectiveMainloop::ElementE;
+  using LayoutE = typename Operator::CollectiveMainloop::LayoutE;
+  using SparseConfig = typename Operator::CollectiveMainloop::SparseConfig;
+  using ProblemShape = typename Operator::GemmKernel::ProblemShape;
+  using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
+    ProblemShape,
+    typename Base::ElementA,
+    typename Base::LayoutA,
+    SparseConfig>;
+  using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
+    ProblemShape,
+    typename Base::ElementA,
+    typename Base::LayoutA,
+    SparseConfig,
+    ArchTag>;
+  using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
+
+private:
+  // Variables that must change in the const functions.
+  mutable CompressorUtility compressor_utility;
+  mutable int problem_count = 1;
+  mutable std::vector<int> iter_idx;
+
+  mutable uint64_t tensor_ac_size = 0;
+  mutable uint64_t tensor_e_size = 0;
+  mutable uint64_t tensor_a_size = 0;
+  mutable uint64_t host_op_workspace_size = 0;
+  mutable uint64_t device_compress_workspace_size = 0;
+  mutable uint64_t device_op_workspace_size = 0;
+  mutable uint64_t device_per_iter_workspace_size = 0;
+
+public:
+  BlockScaledSparseGemmUniversal3xOperation(char const *name = "unknown_gemm") : Base(name) {
+    this->description_.E = make_TensorDescription<ElementE, typename Base::LayoutA>(typename SparseConfig::TensorEAlignmentK{});
+  }
+
+protected:
+
+  /// Constructs the arguments structure given the configuration and arguments
+  static Status construct_arguments_(
+    OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
+    // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
+    // Do nothing here and construct kernel arguments in update_arguments_ instead
+    // We also cannot construct TMA descriptors without all the arguments available
+
+    operator_args.mode = configuration->mode;
+    return Status::kSuccess;
+  }
+
+  template<class FusionArgs, class = void>
+  struct UpdateFusionArgs {
+    static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
+      // If a custom EVT is instantiated then it is the user's responsibility
+      // to ensure alpha and beta are updated appropriately
+      return Status::kSuccess;
+    }
+  };
+
+  template<class FusionArgs>
+  struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
+    static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
+      return Base::update_fusion_args(fusion_args, arguments);
+    }
+  };
+
+  static Status update_arguments_(
+    OperatorArguments &operator_args,
+    BlockScaledGemmArguments const *arguments,
+    CompressorUtility const& compressor_utility,
+    void* device_a_compressed_ptr = nullptr,
+    void* device_e_ptr = nullptr) {
+    Status status = Status::kSuccess;
+
+    status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
+      operator_args.epilogue.thread, *arguments);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    // update arguments
+    if constexpr (Base::IsRuntimeDataType) {
+      using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
+      operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(device_a_compressed_ptr);
+    } else {
+      operator_args.mainloop.ptr_A = static_cast<typename Base::ElementA const *>(device_a_compressed_ptr);
+    }
+    operator_args.mainloop.ptr_E = static_cast<ElementE const *>(device_e_ptr);
+
+    operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor();
+    operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor();
+
+    status = Base::update_arguments_base(operator_args, arguments);
+    return status;
+  }
+
+public:
+  /// Gets the device-side workspace
+  uint64_t get_device_workspace_size(
+    void const *configuration_ptr, void const *arguments_ptr) const override {
+
+    OperatorArguments args;
+    auto status = update_arguments_(
+      args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr), compressor_utility);
+    if (status != Status::kSuccess) {
+      return 0;
+    }
+    typename Compressor::Arguments compress_arguments {
+      {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L},
+      {/*Empty Not Use*/},
+      {/*Empty Not Use*/} };
+
+    // Size for one iteration
+    // For multi-iteration, will need to multiply result of this function w/ actual problem_count
+    tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes();
+    tensor_e_size = compressor_utility.get_tensor_E_bytes();
+    device_op_workspace_size = Operator::get_workspace_size(args);
+    device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments);
+    // NOTE: order here is the order of workspace partition
+    device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size;
+
+    return device_per_iter_workspace_size;
+  }
+
+  /// Gets the host-side workspace
+  uint64_t get_host_workspace_size(void const *configuration) const override {
+    // Memory to hold operator
+    host_op_workspace_size = sizeof(Operator);
+
+    // Memory to hold result of `.structure_sparse_zero_mask_fill()`
+    tensor_a_size = compressor_utility.get_raw_tensor_A_bytes();
+
+    // NOTE: order here is the order of workspace partition
+    const uint64_t size = host_op_workspace_size + tensor_a_size;
+    return size;
+  }
+
+  /// Returns success if the operation can proceed
+  Status can_implement(
+    void const *configuration_ptr, void const *arguments_ptr) const override {
+
+    GemmUniversalConfiguration const *configuration =
+      static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
+    BlockScaledGemmArguments const *arguments =
+      static_cast<BlockScaledGemmArguments const *>(arguments_ptr);
+
+    OperatorArguments args;
+    auto problem_shape_MNKL = cute::make_shape(
+      configuration->problem_size.m(),
+      configuration->problem_size.n(),
+      configuration->problem_size.k(),
+      configuration->batch_count);
+    const int M = configuration->problem_size.m();
+    const int N = configuration->problem_size.n();
+    const int K = configuration->problem_size.k();
+    const int L = configuration->batch_count;
+    using StrideA = typename CompressorUtility::StrideA;
+    auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+    compressor_utility.set_problem_size(problem_shape_MNKL, dA);
+
+    auto status = update_arguments_(args, arguments, compressor_utility);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+    // can_implement rules may need access to problem shape
+    args.problem_shape = problem_shape_MNKL;
+
+    return Operator::can_implement(args);
+  }
+
+  /// Initializes the workspace
+  Status initialize(
+    void const *configuration_ptr,
+    void *host_workspace,
+    void *device_workspace,
+    cudaStream_t stream = nullptr) const override {
+    Operator *op = new (host_workspace) Operator;
+    return Status::kSuccess;
+  }
+
+  Status initialize_with_profiler_workspace(
+    void const *configuration,
+    void *host_workspace,
+    void *device_workspace,
+    uint8_t **profiler_workspaces,
+    int problem_count_from_profiler,
+    cudaStream_t stream = nullptr) {
+    iter_idx.resize(static_cast<GemmUniversalConfiguration const*>(configuration)->device_count, 0);
+    // Set problem_count.
+    problem_count = problem_count_from_profiler;
+
+    // * Host Ptr
+    auto* host_op_workspace_ptr = reinterpret_cast<uint8_t*>(host_workspace);
+    auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size;
+
+    // * Construct Op
+    Operator *op = new (host_op_workspace_ptr) Operator;
+
+    // * Device Ptr (1st iteration)
+    // Device workspace : | iter1 | iter2 | iter3 | .. | iterx |
+    // iteri : op_workspace | tensor_ac | tensor_e
+    auto* device_ptr_iter1 = static_cast<uint8_t*>(device_workspace);
+    auto* device_op_workspace_ptr_iter1 = device_ptr_iter1;
+    auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size;
+    auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size;
+    auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size;
+    // * Device A Raw Ptr
+    auto* device_a_raw_ptr = profiler_workspaces[0];
+    // * Random fill 50% of TensorA w/ zero following the structured sparse requirement
+    CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream));
+    compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000);
+    CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream));
+
+    CUDA_CHECK(cudaGetLastError());
+
+    // * Compress DTensorA and get DTensorAC & DTensorE
+    cutlass::KernelHardwareInfo hw_info;
+    CUDA_CHECK(cudaGetDevice(&hw_info.device_id));
+    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    typename Compressor::Arguments arguments{
+      {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L},
+      {device_a_raw_ptr,
+       compressor_utility.dA,
+       device_a_compressed_ptr_iter1,
+       device_e_ptr_iter1},
+      {hw_info}
+    };
+
+    cutlass::Status status {cutlass::Status::kSuccess};
+
+    Compressor compressor_op;
+    status = compressor_op.can_implement(arguments);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    status = compressor_op.run(stream);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE
+    for (int iter_i = 1; iter_i < problem_count; iter_i++) {
+      // * Device AC E Ptr per iteration
+      // Device workspace : | iter1 | iter2 | iter3 | .. | iterx |
+      // iteri : op_workspace | tensor_ac | tensor_e
+      auto* device_ptr_iteri = static_cast<uint8_t*>(device_workspace) + device_per_iter_workspace_size * iter_i;
+      auto* device_op_workspace_ptr = device_ptr_iteri;
+      auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size;
+      auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size;
+      auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size;
+
+      CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream));
+      CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream));
+    }
+
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    CUDA_CHECK(cudaGetLastError());
+    return Status::kSuccess;
+  }
+
+  /// Runs the kernel
+  Status run(
+    void const *arguments_ptr,
+    void *host_workspace,
+    void *device_workspace = nullptr,
+    cudaStream_t stream = nullptr) const override {
+
+    OperatorArguments operator_args;
+    const auto device_index = static_cast<BlockScaledGemmArguments const *>(arguments_ptr)->device_index;
+
+    auto* device_ptr_iteri = static_cast<uint8_t*>(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index];
+    auto* device_op_workspace_ptr = device_ptr_iteri;
+    auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size;
+    auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size;
+    auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size;
+    iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count;
+
+    Status status = update_arguments_(operator_args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr);
+
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    Operator *op = static_cast<Operator *>(host_workspace);
+    // We need to call initialize() since we have to rebuild TMA desc for every new set of args
+    status = op->run(operator_args, device_workspace, stream, nullptr, static_cast<BlockScaledGemmArguments const *>(arguments_ptr)->use_pdl);
+    return status;
+  }
+};
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace cutlass::library
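The initialize_with_profiler_workspace/run pair above carves one flat device allocation into identical per-iteration slices (operator workspace, compressor workspace, compressed tensor A, metadata tensor E). A self-contained sketch of that offset arithmetic, with made-up sizes (the real ones come from Operator::get_workspace_size, Compressor::get_workspace_size and the compressor utility):

#include <cstdint>
#include <cstdio>

int main() {
  // Per-slice layout, mirroring the comments above:
  // | iter1 | iter2 | ... |, each slice = op_ws | compress_ws | tensor_ac | tensor_e
  const uint64_t op_ws = 1024, compress_ws = 512, ac_bytes = 4096, e_bytes = 256;
  const uint64_t per_iter = op_ws + compress_ws + ac_bytes + e_bytes;

  for (int iter = 0; iter < 3; ++iter) {
    const uint64_t base = per_iter * iter;              // start of this iteration's slice
    const uint64_t ac_off = base + op_ws + compress_ws; // compressed tensor A
    const uint64_t e_off  = ac_off + ac_bytes;          // metadata tensor E
    std::printf("iter %d: ac at %llu, e at %llu\n", iter,
                (unsigned long long)ac_off, (unsigned long long)e_off);
  }
  return 0;
}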
@@ -185,23 +185,11 @@ public:
 
   /// Constructor
   GemmUniversal3xOperation(char const *name = "unknown_gemm"):
-    GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
-    if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
-      dim3 cluster_dims(
-        cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
-        cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
-        cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
-      uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
-      void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
-      max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
-        cluster_dims,
-        threads_per_block,
-        kernel_ptr);
-    }
-  }
+    GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {}
 
 private:
-  int max_active_clusters{};
+  // mutable because it needs to be set in initialize (see comment in initialize)
+  mutable int max_active_clusters{};
 
 protected:
@@ -683,6 +671,21 @@ public:
     void *host_workspace,
     void *device_workspace,
     cudaStream_t stream = nullptr) const override {
+    // this would ideally go in the constructor, but
+    // the constructor is called at profiler startup for EVERY kernel,
+    // REGARDLESS of whether the kernel is actually supported on the device
+    if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
+      dim3 cluster_dims(
+        cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
+        cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
+        cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
+      uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
+      void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
+      max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
+        cluster_dims,
+        threads_per_block,
+        kernel_ptr);
+    }
     Operator *op = new (host_workspace) Operator;
     return Status::kSuccess;
   }
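The hunk above defers the max_active_clusters occupancy query from the constructor (which runs for every registered kernel at profiler startup) to initialize(); since initialize() is const, the member becomes mutable. A minimal sketch of that pattern with invented names (query_device() stands in for the real CUDA occupancy query):

#include <cstdio>

struct Operation {
  Operation() = default;              // cheap: runs for every registered kernel
  mutable int max_active_clusters{0}; // set lazily in const initialize()

  void initialize() const {
    max_active_clusters = query_device(); // pay the query cost only when the kernel is used
  }
  static int query_device() { return 16; } // placeholder value
};

int main() {
  Operation op;
  op.initialize();
  std::printf("max_active_clusters = %d\n", op.max_active_clusters);
  return 0;
}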
@@ -99,6 +99,46 @@ void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest) {
     float // ElementD
   >(manifest);
 
+  // 1.
+  make_gemm_real_canonical_layouts<
+    float_e2m1_t, // ElementA
+    float_e5m2_t, // ElementB
+    half_t,       // ElementC
+    float,        // ElementScalar
+    float,        // ElementAccumulator
+    float_e4m3_t  // ElementD
+  >(manifest);
+
+  // 2.
+  make_gemm_real_canonical_layouts<
+    float_e2m1_t, // ElementA
+    float_e5m2_t, // ElementB
+    half_t,       // ElementC
+    float,        // ElementScalar
+    float,        // ElementAccumulator
+    float_e5m2_t  // ElementD
+  >(manifest);
+
+  // 3.
+  make_gemm_real_canonical_layouts<
+    float_e2m1_t, // ElementA
+    float_e5m2_t, // ElementB
+    half_t,       // ElementC
+    float,        // ElementScalar
+    float,        // ElementAccumulator
+    half_t        // ElementD
+  >(manifest);
+
+  // 4.
+  make_gemm_real_canonical_layouts<
+    float_e2m1_t, // ElementA
+    float_e5m2_t, // ElementB
+    float,        // ElementC
+    float,        // ElementScalar
+    float,        // ElementAccumulator
+    float         // ElementD
+  >(manifest);
+
 }
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -104,11 +104,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A
   set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
 else()
   set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
-  if (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED))
-    set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
-  else()
-    set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --mode=trace --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
-  endif()
+  set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
 endif()
 
 set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true)
@@ -355,11 +355,17 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse(
 int64_t BlockScaledGemmOperationProfiler::GemmProblem::bytes_with_problem_shape(
   library::BlockScaledGemmDescription const &operation_desc,
   gemm::GemmCoord const &problem_shape) const {
 
+  int sfa_m = round_up(problem_shape.m(), 128);
+  int sfb_n = round_up(problem_shape.n(), 128);
+  int sfa_sfb_k = round_up(ceil_div(problem_shape.k(), operation_desc.SFVecSize), 4);
   // Input bytes read and Output bytes written for the gemm problem
   int64_t bytes =
     int64_t(library::sizeof_bits(operation_desc.A.element) * problem_shape.m() / 8) * problem_shape.k() +
     int64_t(library::sizeof_bits(operation_desc.B.element) * problem_shape.n() / 8) * problem_shape.k() +
-    int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n();
+    int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n() +
+    int64_t(library::sizeof_bits(operation_desc.SFA.element) * sfa_m / 8) * sfa_sfb_k +
+    int64_t(library::sizeof_bits(operation_desc.SFB.element) * sfb_n / 8) * sfa_sfb_k;
 
   // Set is_beta_zero true if beta is zero
   bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; });
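A worked example of the scale-factor traffic terms this hunk adds to bytes_with_problem_shape(); SFVecSize=32 and 8-bit scale factors are assumptions chosen for the illustration, not values taken from the diff:

#include <cstdint>
#include <cstdio>

static int round_up(int x, int m) { return (x + m - 1) / m * m; }
static int ceil_div(int x, int d) { return (x + d - 1) / d; }

int main() {
  const int M = 512, N = 256, K = 1024, SFVecSize = 32, sf_bits = 8;
  const int sfa_m = round_up(M, 128);                        // 512
  const int sfb_n = round_up(N, 128);                        // 256
  const int sfa_sfb_k = round_up(ceil_div(K, SFVecSize), 4); // 32
  const int64_t sfa_bytes = int64_t(sf_bits) * sfa_m / 8 * sfa_sfb_k; // 16384
  const int64_t sfb_bytes = int64_t(sf_bits) * sfb_n / 8 * sfa_sfb_k; // 8192
  std::printf("SFA bytes = %lld, SFB bytes = %lld\n",
              (long long)sfa_bytes, (long long)sfb_bytes);
  return 0;
}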
@@ -726,6 +732,8 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
   library::BlockScaledGemmDescription const &operation_desc =
     static_cast<library::BlockScaledGemmDescription const &>(operation->description());
 
+  bool is_sparse = operation_desc.tile_description.math_instruction.opcode_class == cutlass::library::OpcodeClassID::kSparseTensorOp;
+
   // Compute the number of copies of the problem to avoid L2 camping.
   if (!options.profiling.workspace_count) {
     int64_t bytes = problem_.bytes(operation_desc);
@@ -917,6 +925,7 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
 
     /* Query device SM count to pass onto the kernel as an argument, where needed */
     gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0);
+    gemm_workspace_.arguments.device_index = static_cast<int>(0);
   }
 
   //
@@ -932,12 +941,34 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
 
     workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration,
                                                                     &gemm_workspace_.arguments);
+    if (is_sparse) {
+      // sparse gemm get_device_workspace_size() only returns the device workspace size per iteration;
+      // it needs to be multiplied by the number of iterations
+      workspace_size *= gemm_workspace_.problem_count;
+    }
     gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
 
-    status = underlying_operation->initialize(
-      &gemm_workspace_.configuration,
-      gemm_workspace_.host_workspace.data(),
-      gemm_workspace_.device_workspace.data());
+    // Convert to structure sparse contents here.
+    if (is_sparse) {
+      uint8_t* profiler_workspaces[1];
+      profiler_workspaces[0] = reinterpret_cast<uint8_t*>(gemm_workspace_.A->data());
+      // Sparse operations have a different initialize interface.
+      // initialize_with_profiler_workspace converts mxk tensorA to compressed mxk/sp tensorA and the tensorE
+      auto modifiable_underlying_op = const_cast<library::Operation*>(underlying_operation);
+      status = modifiable_underlying_op->initialize_with_profiler_workspace(
+        &gemm_workspace_.configuration,
+        gemm_workspace_.host_workspace.data(),
+        gemm_workspace_.device_workspace.data(),
+        profiler_workspaces,
+        gemm_workspace_.problem_count);
+    }
+    else {
+      status = underlying_operation->initialize(
+        &gemm_workspace_.configuration,
+        gemm_workspace_.host_workspace.data(),
+        gemm_workspace_.device_workspace.data());
+    }
 
     if (status != Status::kSuccess) {
       return status;
     }
@@ -63,7 +63,7 @@ CutlassProfiler::CutlassProfiler(
 
   operation_profilers_.emplace_back(new GemmOperationProfiler(options));
 
-  operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
+  operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
 
   operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options));
@@ -45,8 +45,8 @@ target_link_libraries(
 
 install(
   DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
-  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/
+  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
 )
 
 install(
   TARGETS cutlass_tools_util_includes