v4.3 update. (#2709)

* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update docs to encourage users to install the DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
Author: Junkai-Wu
Date: 2025-10-22 02:26:30 +08:00
Committed by: GitHub
Parent: e6e2cc29f5
Commit: b1d6e2c9b3
244 changed files with 59272 additions and 10455 deletions

View File

@@ -35,6 +3,9 @@ include(GNUInstallDirs)
set(CUTLASS_BUILD_MONO_LIBRARY OFF CACHE BOOL
"Determines whether the cutlass library is generated as a single file or multiple files.")
option(CUTLASS_BUILD_SHARED_LIBS "Build shared libraries" ON)
option(CUTLASS_BUILD_STATIC_LIBS "Build static libraries" ON)
################################################################################
add_library(cutlass_library_includes INTERFACE)
@@ -62,7 +65,7 @@ install(
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
add_library(cutlass_library_internal_interface INTERFACE)
@@ -123,88 +126,98 @@ function(cutlass_add_cutlass_library)
if (CUTLASS_BUILD_MONO_LIBRARY AND __SUFFIX)
# If we're only building a single monolithic library then we
# simply link the generated object files to the default library.
# simply link the generated object files to the default library.
if(CUTLASS_BUILD_SHARED_LIBS)
target_link_libraries(${DEFAULT_NAME} PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
endif()
target_link_libraries(${DEFAULT_NAME} PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
target_link_libraries(${DEFAULT_NAME}_static PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
if(CUTLASS_BUILD_STATIC_LIBS)
target_link_libraries(${DEFAULT_NAME}_static PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>)
endif()
else()
cutlass_add_library(
${__NAME}
SHARED
EXPORT_NAME ${__EXPORT_NAME}
""
# Shared library (controlled by CUTLASS_BUILD_SHARED_LIBS, analogous to CMake's standard BUILD_SHARED_LIBS)
if(CUTLASS_BUILD_SHARED_LIBS)
cutlass_add_library(
${__NAME}
SHARED
EXPORT_NAME ${__EXPORT_NAME}
""
)
target_compile_features(${__NAME} INTERFACE cxx_std_17)
set_target_properties(
${__NAME}
PROPERTIES
OUTPUT_NAME ${__OUTPUT_NAME}
WINDOWS_EXPORT_ALL_SYMBOLS 1
)
target_link_libraries(
${__NAME}
PUBLIC cutlass_library_includes
PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
cuda_driver
)
set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}")
cutlass_add_library(
${__NAME}_static
STATIC
EXPORT_NAME ${__EXPORT_NAME}_static
""
target_compile_features(${__NAME} INTERFACE cxx_std_17)
set_target_properties(${__NAME}
PROPERTIES
OUTPUT_NAME ${__OUTPUT_NAME}
WINDOWS_EXPORT_ALL_SYMBOLS 1
DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}"
)
target_compile_features(${__NAME}_static INTERFACE cxx_std_17)
if (WIN32)
set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static)
else()
set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME})
endif()
set_target_properties(
${__NAME}_static
PROPERTIES
OUTPUT_NAME ${STATIC_OUTPUT_NAME}
WINDOWS_EXPORT_ALL_SYMBOLS 1
target_link_libraries(${__NAME}
PUBLIC cutlass_library_includes
PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
cuda_driver
)
target_link_libraries(
${__NAME}_static
PUBLIC cutlass_library_includes
PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
cuda_driver
install(
TARGETS ${__NAME}
EXPORT NvidiaCutlass
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
set_target_properties(${__NAME}_static PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}")
install(
TARGETS ${__NAME} ${__NAME}_static
EXPORT NvidiaCutlass
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
if (__SUFFIX)
# The partial libraries generated will be registered as linked libraries
# to the main cutlass library so users automatically get the necessary link
# commands to pull in all kernels by default.
target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME})
target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static)
if (__SUFFIX)
target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME})
endif()
endif()
# Static library
if(CUTLASS_BUILD_STATIC_LIBS)
cutlass_add_library(
${__NAME}_static
STATIC
EXPORT_NAME ${__EXPORT_NAME}_static
""
)
target_compile_features(${__NAME}_static INTERFACE cxx_std_17)
if (WIN32)
set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static)
else()
set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME})
endif()
set_target_properties(
${__NAME}_static
PROPERTIES
OUTPUT_NAME ${STATIC_OUTPUT_NAME}
WINDOWS_EXPORT_ALL_SYMBOLS 1
DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}"
)
target_link_libraries(
${__NAME}_static
PUBLIC cutlass_library_includes
PRIVATE $<BUILD_INTERFACE:${__NAME}_objs>
cuda_driver
)
install(
TARGETS ${__NAME}_static
EXPORT NvidiaCutlass
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
if (__SUFFIX)
target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static)
endif()
endif()
endif()
endfunction()
@@ -268,8 +281,13 @@ cutlass_add_cutlass_library(
)
# For backward compatibility with the old name
add_library(cutlass_lib ALIAS cutlass_library)
add_library(cutlass_lib_static ALIAS cutlass_library_static)
if(CUTLASS_BUILD_SHARED_LIBS)
add_library(cutlass_lib ALIAS cutlass_library)
endif()
if(CUTLASS_BUILD_STATIC_LIBS)
add_library(cutlass_lib_static ALIAS cutlass_library_static)
endif()
################################################################################

View File

@@ -348,6 +348,9 @@ struct BlockScaledGemmDescription : public OperationDescription {
/// Describes the destination matrix
TensorDescription D;
/// Describes the sparse meta matrices
TensorDescription E;
/// Describes the SFA operand
TensorDescription SFA;

View File

@@ -392,10 +392,11 @@ struct BlockScaledGemmArguments {
library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};
int device_index{0};
bool use_pdl{false};
};
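Below is a minimal usage sketch for the use_pdl and device_index fields that appear in this hunk. The include path and the default-constructibility of the struct are assumptions made for illustration; all other members are left untouched.

#include "cutlass/library/library.h"   // assumed header for cutlass::library::BlockScaledGemmArguments

void configure_block_scaled_args_sketch() {
  cutlass::library::BlockScaledGemmArguments args{};
  // Index of the device whose workspace/iteration slot the profiler should use for this launch.
  args.device_index = 0;
  // Request programmatic dependent launch; it only takes effect if the selected kernel supports PDL.
  args.use_pdl = true;
}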
/// Blockwise GEMM
//
// OperationKind: kBlockwiseGemm

View File

@@ -209,6 +209,7 @@ enum class GemmKind {
kPlanarComplex,
kPlanarComplexArray,
kGrouped,
kBlockScaledSparseGemm,
kInvalid
};

View File

@@ -41,6 +41,8 @@
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "gemm_operation_3x.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::library {
@@ -48,7 +50,7 @@ namespace cutlass::library {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
class BlockScaledGemmUniversal3xOperationBase : public GemmOperation3xBase<Operator_> {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
@@ -92,15 +94,9 @@ public:
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA;
using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB;
private:
BlockScaledGemmDescription description_;
public:
/// Constructor
BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"):
BlockScaledGemmUniversal3xOperationBase(char const *name = "unknown_gemm"):
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
description_.kind = OperationKind::kBlockScaledGemm;
description_.SFA.element = NumericTypeMap<ElementSFA>::kId;
@@ -182,38 +178,14 @@ public:
BlockScaledGemmDescription const& get_gemm_description() const {
return description_;
}
protected:
/// Constructs the arguments structure given the configuration and arguments
static Status construct_arguments_(
OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
// NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
// Do nothing here and construct kernel arguments in update_arguments_ instead
// We also cannot construct TMA descriptors without all the arguments available
operator_args.mode = configuration->mode;
return Status::kSuccess;
}
template<class FusionArgs, class = void>
struct UpdateFusionArgs {
static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
// If a custom EVT is instantiated then it is the user's responsibility
// to ensure alpha and beta are updated appropriately
return Status::kSuccess;
}
};
template<class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
BlockScaledGemmDescription description_;
template <typename FusionArgs>
static Status update_fusion_args(FusionArgs& fusion_args, BlockScaledGemmArguments const& arguments) {
if constexpr (epilogue_scalefactor_generation) {
fusion_args.block_scale_factor_ptr = static_cast<ElementSFD*>(arguments.SFD);
fusion_args.norm_constant_ptr = static_cast<ElementCompute const *>(arguments.norm_constant);
}
if (arguments.pointer_mode == ScalarPointerMode::kHost) {
fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
@@ -234,21 +206,12 @@ protected:
else {
return Status::kErrorInvalidProblem;
}
}
};
}
/// Constructs the arguments structure given the configuration and arguments
static Status update_arguments_(
OperatorArguments &operator_args,
BlockScaledGemmArguments const *arguments) {
Status status = Status::kSuccess;
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
operator_args.epilogue.thread, *arguments);
if (status != Status::kSuccess) {
return status;
}
static Status update_arguments_base(
OperatorArguments& operator_args,
BlockScaledGemmArguments const* arguments) {
operator_args.problem_shape = cute::make_shape(
arguments->problem_size.m(),
arguments->problem_size.n(),
@@ -256,11 +219,10 @@ protected:
arguments->batch_count);
// update arguments
if constexpr (IsRuntimeDataType) {
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA;
@@ -298,17 +260,12 @@ protected:
}
else {
operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
}
operator_args.mainloop.ptr_SFA = static_cast<ElementSFA const *>(arguments->SFA);
operator_args.mainloop.ptr_SFB = static_cast<ElementSFB const *>(arguments->SFB);
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
arguments->lda, arguments->batch_stride_A);
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
arguments->ldb, arguments->batch_stride_B);
operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
@@ -353,7 +310,74 @@ protected:
arguments->cluster_shape_fallback.n(),
arguments->cluster_shape_fallback.k());
}
return Status::kSuccess;
}
};
template <typename Operator_>
class BlockScaledGemmUniversal3xOperation : public BlockScaledGemmUniversal3xOperationBase<Operator_> {
public:
using Base = BlockScaledGemmUniversal3xOperationBase<Operator_>;
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
public:
/// Constructor
BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"):
BlockScaledGemmUniversal3xOperationBase<Operator_>(name) {}
protected:
/// Constructs the arguments structure given the configuration and arguments
static Status construct_arguments_(
OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
// NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
// Do nothing here and construct kernel arguments in update_arguments_ instead
// We also cannot construct TMA descriptors without all the arguments available
operator_args.mode = configuration->mode;
return Status::kSuccess;
}
template<class FusionArgs, class = void>
struct UpdateFusionArgs {
static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
// If a custom EVT is instantiated then it is the user's responsibility
// to ensure alpha and beta are updated appropriately
return Status::kSuccess;
}
};
template<class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
return Base::update_fusion_args(fusion_args, arguments);
}
};
/// Constructs the arguments structure given the configuration and arguments
static Status update_arguments_(
OperatorArguments &operator_args,
BlockScaledGemmArguments const *arguments) {
Status status = Status::kSuccess;
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
operator_args.epilogue.thread, *arguments);
if (status != Status::kSuccess) {
return status;
}
if constexpr (Base::IsRuntimeDataType) {
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
} else {
operator_args.mainloop.ptr_A = static_cast<typename Base::ElementA const *>(arguments->A);
}
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
arguments->lda, arguments->batch_stride_A);
status = Base::update_arguments_base(operator_args, arguments);
return status;
}
@@ -443,6 +467,306 @@ public:
return status;
}
};
template <typename Operator_>
class BlockScaledSparseGemmUniversal3xOperation : public BlockScaledGemmUniversal3xOperationBase<Operator_> {
public:
using Base = BlockScaledGemmUniversal3xOperationBase<Operator_>;
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using ArchTag = typename Operator::ArchTag;
using StrideA = cutlass::gemm::TagToStrideA_t<typename Base::LayoutA>;
using ElementE = typename Operator::CollectiveMainloop::ElementE;
using LayoutE = typename Operator::CollectiveMainloop::LayoutE;
using SparseConfig = typename Operator::CollectiveMainloop::SparseConfig;
using ProblemShape = typename Operator::GemmKernel::ProblemShape;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape,
typename Base::ElementA,
typename Base::LayoutA,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape,
typename Base::ElementA,
typename Base::LayoutA,
SparseConfig,
ArchTag>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
private:
// State that must be updated inside const member functions (hence declared mutable).
mutable CompressorUtility compressor_utility;
mutable int problem_count = 1;
mutable std::vector<int> iter_idx;
mutable uint64_t tensor_ac_size = 0;
mutable uint64_t tensor_e_size = 0;
mutable uint64_t tensor_a_size = 0;
mutable uint64_t host_op_workspace_size = 0;
mutable uint64_t device_compress_workspace_size = 0;
mutable uint64_t device_op_workspace_size = 0;
mutable uint64_t device_per_iter_workspace_size = 0;
public:
BlockScaledSparseGemmUniversal3xOperation(char const *name = "unknown_gemm"):Base(name) {
this->description_.E = make_TensorDescription<ElementE, typename Base::LayoutA>(typename SparseConfig::TensorEAlignmentK{});
}
protected:
/// Constructs the arguments structure given the configuration and arguments
static Status construct_arguments_(
OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
// NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
// Do nothing here and construct kernel arguments in update_arguments_ instead
// We also cannot construct TMA descriptors without all the arguments available
operator_args.mode = configuration->mode;
return Status::kSuccess;
}
template<class FusionArgs, class = void>
struct UpdateFusionArgs {
static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
// If a custom EVT is instantiated then it is the user's responsibility
// to ensure alpha and beta are updated appropriately
return Status::kSuccess;
}
};
template<class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
return Base::update_fusion_args(fusion_args, arguments);
}
};
static Status update_arguments_(
OperatorArguments &operator_args,
BlockScaledGemmArguments const *arguments,
CompressorUtility const& compressor_utility,
void* device_a_compressed_ptr = nullptr,
void* device_e_ptr = nullptr) {
Status status = Status::kSuccess;
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
operator_args.epilogue.thread, *arguments);
if (status != Status::kSuccess) {
return status;
}
// update arguments
if constexpr (Base::IsRuntimeDataType) {
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(device_a_compressed_ptr);
} else {
operator_args.mainloop.ptr_A = static_cast<typename Base::ElementA const *>(device_a_compressed_ptr);
}
operator_args.mainloop.ptr_E = static_cast<ElementE const *>(device_e_ptr);
operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor();
operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor();
status = Base::update_arguments_base(operator_args, arguments);
return status;
}
public:
/// Gets the device-side workspace
uint64_t get_device_workspace_size(
void const *configuration_ptr,void const *arguments_ptr) const override {
OperatorArguments args;
auto status = update_arguments_(
args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr), compressor_utility);
if (status != Status::kSuccess) {
return 0;
}
typename Compressor::Arguments compress_arguments {
{compressor_utility.M, 0, compressor_utility.K, compressor_utility.L},
{/*Empty Not Use*/},
{/*Empty Not Use*/} };
// Size for one iteration.
// For multiple iterations, the caller must multiply the result of this function by the actual problem_count.
tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes();
tensor_e_size = compressor_utility.get_tensor_E_bytes();
device_op_workspace_size = Operator::get_workspace_size(args);
device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments);
// NOTE: order here is the order of workspace partition
device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size;
return device_per_iter_workspace_size;
}
/// Gets the host-side workspace
uint64_t get_host_workspace_size(void const *configuration) const override {
// Memory to hold operator
host_op_workspace_size = sizeof(Operator);
// Memory to hold result of `.structure_sparse_zero_mask_fill()`
tensor_a_size = compressor_utility.get_raw_tensor_A_bytes();
// NOTE: order here is the order of workspace partition
const uint64_t size = host_op_workspace_size + tensor_a_size;
return size;
}
/// Returns success if the operation can proceed
Status can_implement(
void const *configuration_ptr, void const *arguments_ptr) const override {
GemmUniversalConfiguration const *configuration =
static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
BlockScaledGemmArguments const *arguments =
static_cast<BlockScaledGemmArguments const *>(arguments_ptr);
OperatorArguments args;
auto problem_shape_MNKL = cute::make_shape(
configuration->problem_size.m(),
configuration->problem_size.n(),
configuration->problem_size.k(),
configuration->batch_count);
const int M = configuration->problem_size.m();
const int N = configuration->problem_size.n();
const int K = configuration->problem_size.k();
const int L = configuration->batch_count;
using StrideA = typename CompressorUtility::StrideA;
auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
compressor_utility.set_problem_size(problem_shape_MNKL, dA);
auto status = update_arguments_(args, arguments, compressor_utility);
if (status != Status::kSuccess) {
return status;
}
// can_implement rules may need access to problem shape
args.problem_shape = problem_shape_MNKL;
return Operator::can_implement(args);
}
/// Initializes the workspace
Status initialize(
void const *configuration_ptr,
void *host_workspace,
void *device_workspace,
cudaStream_t stream = nullptr) const override {
Operator *op = new (host_workspace) Operator;
return Status::kSuccess;
}
Status initialize_with_profiler_workspace(
void const *configuration,
void *host_workspace,
void *device_workspace,
uint8_t **profiler_workspaces,
int problem_count_from_profiler,
cudaStream_t stream = nullptr) {
iter_idx.resize(static_cast<GemmUniversalConfiguration const*>(configuration)->device_count, 0);
// Set problem_count.
problem_count = problem_count_from_profiler;
// * Host Ptr
auto* host_op_workspace_ptr = reinterpret_cast<uint8_t*>(host_workspace);
auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size;
// * Construct Op
Operator *op = new (host_op_workspace_ptr) Operator;
// * Device Ptr (1st iteration)
// Device workspace : | iter1 | iter2 | iter3 | .. | iterx |
// iteri : op_workspace | tensor_ac | tensor_e
auto* device_ptr_iter1 = static_cast<uint8_t*>(device_workspace);
auto* device_op_workspace_ptr_iter1 = device_ptr_iter1;
auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size;
auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size;
auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size;
// * Device A Raw Ptr
auto* device_a_raw_ptr = profiler_workspaces[0];
// * Randomly fill 50% of tensor A with zeros, following the structured-sparsity requirement
CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream));
compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000);
CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaGetLastError());
// * Compress DTensorA and get DTensorAC & DTensorE
cutlass::KernelHardwareInfo hw_info;
CUDA_CHECK(cudaGetDevice(&hw_info.device_id));
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
{compressor_utility.M, 0, compressor_utility.K, compressor_utility.L},
{device_a_raw_ptr,
compressor_utility.dA,
device_a_compressed_ptr_iter1,
device_e_ptr_iter1},
{hw_info}
};
cutlass::Status status {cutlass::Status::kSuccess};
Compressor compressor_op;
status = compressor_op.can_implement(arguments);
if (status != Status::kSuccess) {
return status;
}
status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream);
if (status != Status::kSuccess) {
return status;
}
status = compressor_op.run(stream);
if (status != Status::kSuccess) {
return status;
}
// * Copy iteration 1's DTensorAC and DTensorE into each remaining iteration's DTensorAC and DTensorE
for (int iter_i = 1; iter_i < problem_count; iter_i++) {
// * Device AC E Ptr per iteration
// Device workspace : | iter1 | iter2 | iter3 | .. | iterx |
// iteri : op_workspace | tensor_ac | tensor_e
auto* device_ptr_iteri = static_cast<uint8_t*>(device_workspace) + device_per_iter_workspace_size * iter_i;
auto* device_op_workspace_ptr = device_ptr_iteri;
auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size;
auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size;
auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size;
CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream));
}
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaGetLastError());
return Status::kSuccess;
}
/// Runs the kernel
Status run(
void const *arguments_ptr,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const override {
OperatorArguments operator_args;
const auto device_index = static_cast<BlockScaledGemmArguments const *>(arguments_ptr)->device_index;
auto* device_ptr_iteri = static_cast<uint8_t*>(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index];
auto* device_op_workspace_ptr = device_ptr_iteri;
auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size;
auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size;
auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size;
iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count;
Status status = update_arguments_(operator_args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr);
if (status != Status::kSuccess) {
return status;
}
Operator *op = static_cast<Operator *>(host_workspace);
// We need to call initialize() since we have to rebuild TMA desc for every new set of args
status = op->run(operator_args, device_workspace, stream, nullptr, static_cast<BlockScaledGemmArguments const *>(arguments_ptr)->use_pdl);
return status;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::library
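The UpdateFusionArgs helpers above select between a no-op primary template and a real implementation via the void_t member-detection idiom. Here is a minimal, standalone illustration (plain std::void_t instead of cute::void_t, with hypothetical struct names) of how the partial specialization is chosen only when the fusion arguments expose an alpha member:

#include <type_traits>

template <class FusionArgs, class = void>
struct HasAlpha : std::false_type {};            // primary template: no alpha member detected

template <class FusionArgs>
struct HasAlpha<FusionArgs, std::void_t<decltype(FusionArgs{}.alpha)>> : std::true_type {};

struct WithAlpha    { float alpha; float beta; };
struct WithoutAlpha { int unrelated; };

static_assert(HasAlpha<WithAlpha>::value, "specialization selected: alpha member exists");
static_assert(!HasAlpha<WithoutAlpha>::value, "primary template selected: no alpha member");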

View File

@@ -185,23 +185,11 @@ public:
/// Constructor
GemmUniversal3xOperation(char const *name = "unknown_gemm"):
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
dim3 cluster_dims(
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
cluster_dims,
threads_per_block,
kernel_ptr);
}
}
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {}
private:
int max_active_clusters{};
// mutable because it needs to be set in initialize (see comment in initialize)
mutable int max_active_clusters{};
protected:
@@ -683,6 +671,21 @@ public:
void *host_workspace,
void *device_workspace,
cudaStream_t stream = nullptr) const override {
// this would ideally go in the constructor, but
// the constructor is called at profiler startup for EVERY kernel,
// REGARDLESS of whether the kernel is actually supported on the device
if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
dim3 cluster_dims(
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
cluster_dims,
threads_per_block,
kernel_ptr);
}
Operator *op = new (host_workspace) Operator;
return Status::kSuccess;
}
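A compact sketch of the pattern this hunk adopts: the occupancy query is deferred from the constructor (which the profiler runs for every registered kernel at startup) into the const initialize() call, with the cached value declared mutable. Names below are simplified placeholders, not library types.

class LazyOccupancySketch {
public:
  LazyOccupancySketch() = default;        // cheap: no device queries at registration time

  int initialize() const {                // const, mirroring the library's Operation::initialize
    if (max_active_clusters_ == 0) {
      // Placeholder for the real query shown above
      // (cutlass::KernelHardwareInfo::query_device_max_active_clusters).
      max_active_clusters_ = 1;
    }
    return max_active_clusters_;
  }

private:
  mutable int max_active_clusters_{0};    // mutable so the const initialize() can cache the result
};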

View File

@@ -99,6 +99,46 @@ void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest) {
float // ElementD
>(manifest);
// 1.
make_gemm_real_canonical_layouts<
float_e2m1_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
// 2.
make_gemm_real_canonical_layouts<
float_e2m1_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
// 3.
make_gemm_real_canonical_layouts<
float_e2m1_t, // ElementA
float_e5m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
// 4.
make_gemm_real_canonical_layouts<
float_e2m1_t, // ElementA
float_e5m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -104,11 +104,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
else()
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
if (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED))
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
else()
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --mode=trace --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
endif()
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
endif()
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true)

View File

@@ -355,11 +355,17 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse(
int64_t BlockScaledGemmOperationProfiler::GemmProblem::bytes_with_problem_shape(
library::BlockScaledGemmDescription const &operation_desc,
gemm::GemmCoord const &problem_shape) const {
int sfa_m = round_up(problem_shape.m(), 128);
int sfb_n = round_up(problem_shape.n(), 128);
int sfa_sfb_k = round_up(ceil_div(problem_shape.k(), operation_desc.SFVecSize), 4);
// Input bytes read and Output bytes written for the gemm problem
int64_t bytes =
int64_t(library::sizeof_bits(operation_desc.A.element) * problem_shape.m() / 8) * problem_shape.k() +
int64_t(library::sizeof_bits(operation_desc.B.element) * problem_shape.n() / 8) * problem_shape.k() +
int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n();
int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n() +
int64_t(library::sizeof_bits(operation_desc.SFA.element) * sfa_m / 8) * sfa_sfb_k +
int64_t(library::sizeof_bits(operation_desc.SFB.element) * sfb_n / 8) * sfa_sfb_k;
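Written out, the updated traffic estimate above is approximately the following (with \(b_X\) the bit width that library::sizeof_bits reports for operand \(X\), and SFV the scale-factor vector size operation_desc.SFVecSize; the code performs per-term integer division, hence the approximation):

\[
\text{bytes} \approx \frac{ b_A M K \;+\; b_B N K \;+\; b_C M N \;+\; \bigl(b_{SFA}\,\tilde M + b_{SFB}\,\tilde N\bigr)\,\tilde K }{8},
\qquad
\tilde M = 128\left\lceil \tfrac{M}{128} \right\rceil,\;
\tilde N = 128\left\lceil \tfrac{N}{128} \right\rceil,\;
\tilde K = 4\left\lceil \tfrac{1}{4}\left\lceil \tfrac{K}{\mathrm{SFV}} \right\rceil \right\rceil .
\]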
// Set is_beta_zero true if beta is zero
bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; });
@@ -726,6 +732,8 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
library::BlockScaledGemmDescription const &operation_desc =
static_cast<library::BlockScaledGemmDescription const &>(operation->description());
bool is_sparse = operation_desc.tile_description.math_instruction.opcode_class == cutlass::library::OpcodeClassID::kSparseTensorOp;
// Compute the number of copies of the problem to avoid L2 camping.
if (!options.profiling.workspace_count) {
int64_t bytes = problem_.bytes(operation_desc);
@@ -917,6 +925,7 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
/* Query device SM count to pass onto the kernel as an argument, where needed */
gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0);
gemm_workspace_.arguments.device_index = static_cast<int>(0);
}
//
@@ -932,12 +941,34 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration,
&gemm_workspace_.arguments);
if (is_sparse) {
// Sparse GEMM's get_device_workspace_size() only returns the device workspace size per iteration,
// so it needs to be multiplied by the number of iterations.
workspace_size *= gemm_workspace_.problem_count;
}
gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
// Convert to structured-sparse contents here.
if (is_sparse) {
uint8_t* profiler_workspaces[1];
profiler_workspaces[0] = reinterpret_cast<uint8_t*>(gemm_workspace_.A->data());
// Sparse operations have a different initialization interface:
// initialize_with_profiler_workspace converts the m-by-k tensor A into the compressed m-by-(k/sparsity) tensor A and the metadata tensor E.
auto modifiable_underlying_op = const_cast<library::Operation*>(underlying_operation);
status = modifiable_underlying_op->initialize_with_profiler_workspace(
&gemm_workspace_.configuration,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data(),
profiler_workspaces,
gemm_workspace_.problem_count);
}
else {
status = underlying_operation->initialize(
&gemm_workspace_.configuration,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());
}
if (status != Status::kSuccess) {
return status;
}
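As a small worked example of the sparse-workspace scaling in the branch above (illustrative byte counts only; the real sizes come from get_device_workspace_size and the compressor):

#include <cstdint>

int main() {
  // Per-iteration bytes reported by the sparse operation: op workspace, compressor
  // workspace, compressed tensor A, and metadata tensor E (made-up values).
  std::uint64_t per_iteration_bytes = 1024 + 512 + 4096 + 1024;   // 6656
  int problem_count = 4;              // copies of the problem kept to avoid L2 reuse
  std::uint64_t total_device_workspace = per_iteration_bytes * problem_count;  // 26624
  return total_device_workspace == 26624 ? 0 : 1;
}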

View File

@@ -63,7 +63,7 @@ CutlassProfiler::CutlassProfiler(
operation_profilers_.emplace_back(new GemmOperationProfiler(options));
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options));

View File

@@ -45,8 +45,8 @@ target_link_libraries(
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
install(
TARGETS cutlass_tools_util_includes