cutlass 3.9 update (#2255)

* cutlass 3.9 update

* rebase

* fixes out of shared memory for blockwise Blackwell

* doc format

* fix issue 2253

* disable host ref by default

* fix sm120 smem capacity

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-24 12:42:40 -07:00
committed by GitHub
parent 8e345c5c5b
commit 331a1f5b3f
143 changed files with 18089 additions and 5935 deletions

View File

@ -437,9 +437,11 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "k", problem_space, k);
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
auto cluster_shape = operation_desc.tile_description.cluster_shape;
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);

File diff suppressed because it is too large Load Diff

View File

@ -37,6 +37,7 @@
// Profiler includes
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
#include "cutlass/profiler/blockwise_gemm_operation_profiler.h"
#include "cutlass/profiler/conv2d_operation_profiler.h"
#include "cutlass/profiler/conv3d_operation_profiler.h"
#include "cutlass/profiler/cutlass_profiler.h"
@ -64,6 +65,8 @@ CutlassProfiler::CutlassProfiler(
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options));
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options));
operation_profilers_.emplace_back(new Conv2dOperationProfiler(options));

View File

@ -440,10 +440,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "n", problem_space, n);
set_argument(result, "k", problem_space, k);
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
auto cluster_shape = operation_desc.tile_description.cluster_shape;
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);

View File

@ -40,6 +40,7 @@
#include <stdexcept>
#include <string>
#include <vector>
#include <regex>
#include <cuda_runtime_api.h>
@ -459,9 +460,11 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
set_argument(result, "problem-sizes", problem_space, ss.str());
}
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
auto cluster_shape = operation_desc.gemm.tile_description.cluster_shape;
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
@ -497,10 +500,22 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
// We distinguish between block scaled and non-block scaled operations by looking at the kernel
// name, which tells us what reference kernel to use, which arguments to pass to the operation
// etc. This avoids creating yet another OperationProfiler with a lot of boilerplate in it.
std::string sf_tuple = "\\d+x\\d+";
std::string datatypes_regex = "\\w?f\\d+|e\\dm\\d"; // bf16 | f16 | f32 | e4m3 | ...
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
datatypes_regex + ")_" + sf_tuple + "(" +
datatypes_regex + ")x(" + datatypes_regex + ")";
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
is_block_scaled = true;
gemm_workspace_.block_scales = BlockScalingWorkspace{};
}
else if (std::regex_search(operation_desc.gemm.name, std::regex(blockwise_regex_string))) {
is_blockwise = true;
gemm_workspace_.block_scales = BlockScalingWorkspace{};
}
else {
is_block_scaled = false;
gemm_workspace_.block_scales = std::nullopt;
@ -605,6 +620,12 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
block_scaling_ws.SFD_ptr_array_host.resize(num_groups);
block_scaling_ws.SFD_reference_ptr_array_host.resize(num_groups);
}
else if (is_blockwise) {
auto& block_scaling_ws = gemm_workspace_.block_scales.value();
block_scaling_ws.SFA_ptr_array_host.resize(num_groups);
block_scaling_ws.SFB_ptr_array_host.resize(num_groups);
block_scaling_ws.SFC_ptr_array_host.resize(num_groups);
}
static_assert(sizeof(void*) == 8); // allocating blocks for pointers, so verify pointer size
// ldx
gemm_workspace_.lda_array_device =
@ -698,7 +719,7 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
int sfa_m = round_up(int(problem_.m(group_idx)), 128);
int sfb_n = round_up(int(problem_.n(group_idx)), 128);
int sfa_sfb_k =
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFVecSize), 4);
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize), 4);
int sfd_m =
block_scale_desc.SFD.layout == cutlass::library::LayoutTypeID::kRowMajor
@ -760,6 +781,37 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
block_scale_ws.SFD_ptr_array_host[group_idx]->fill_device(0);
}
}
else if (is_blockwise) {
auto const block_scale_desc = operation_desc.block_scales.value();
auto& block_scale_ws = gemm_workspace_.block_scales.value();
int sfa_m = ceil_div(int(problem_.m(group_idx)), block_scale_desc.SFMVecSize);
int sfb_n = ceil_div(int(problem_.n(group_idx)), block_scale_desc.SFNVecSize);
int sfa_sfb_k = ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize);
block_scale_ws.SFA_ptr_array_host[group_idx] =
device_context.allocate_and_initialize_tensor(
options,
"SFA_" + std::to_string(group_idx),
block_scale_desc.SFA.element,
block_scale_desc.SFA.layout,
{sfa_m, sfa_sfb_k},
{sfa_m},
gemm_workspace_.problem_count,
seed_shift++,
0);
block_scale_ws.SFB_ptr_array_host[group_idx] =
device_context.allocate_and_initialize_tensor(
options,
"SFB_" + std::to_string(group_idx),
block_scale_desc.SFB.element,
block_scale_desc.SFB.layout,
{sfa_sfb_k, sfb_n},
{sfb_n},
gemm_workspace_.problem_count,
seed_shift++,
0);
}
}
// takes the allocated tensors and initializes an array of pointers per problem in the workspace
@ -825,6 +877,18 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
0 // device_index
);
}
else if (is_blockwise) {
auto& block_scale_ws = gemm_workspace_.block_scales.value();
create_dev_ptr_array_all_workspace(
block_scale_ws.SFA_ptr_array_device,
block_scale_ws.SFA_ptr_array_host,
"SFA");
create_dev_ptr_array_all_workspace(
block_scale_ws.SFB_ptr_array_device,
block_scale_ws.SFB_ptr_array_host,
"SFB");
}
init_arguments(options);
}
@ -896,6 +960,11 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
init_arguments(options);
library::Operation const* underlying_operation = operation;
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
if (results_.back().status != Status::kSuccess) {
return false;
}
results_.back().status = underlying_operation->run(
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
@ -998,7 +1067,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
}
// we only have a block scaled reference kernel implemented on the host
if (is_block_scaled && provider != library::Provider::kReferenceHost) {
if ((is_block_scaled || is_blockwise) && provider != library::Provider::kReferenceHost) {
continue;
}
@ -1064,12 +1133,22 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
ptr_norm_constant = host_data_norm_constant.data();
ws.norm_constant->copy_to_host(ptr_norm_constant);
}
else if (is_blockwise) {
auto const& ws = gemm_workspace_.block_scales.value();
host_data_SFA.resize(ws.SFA_ptr_array_host[group_idx]->bytes());
ptr_SFA = host_data_SFA.data();
ws.SFA_ptr_array_host[group_idx]->copy_to_host(ptr_SFA);
host_data_SFB.resize(ws.SFB_ptr_array_host[group_idx]->bytes());
ptr_SFB = host_data_SFB.data();
ws.SFB_ptr_array_host[group_idx]->copy_to_host(ptr_SFB);
}
}
const auto &desc = static_cast<library::GroupedGemmDescription const &>(operation->description());
const auto& gemm_desc = desc.gemm;
if (!is_block_scaled) {
if (!is_block_scaled and !is_blockwise) {
library::Handle handle;
handle.set_provider(provider);
@ -1112,7 +1191,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride());
}
else {
else if (is_block_scaled) {
auto const& block_scale_desc = desc.block_scales.value();
auto& block_scale_ws = gemm_workspace_.block_scales.value();
@ -1134,7 +1213,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
gemm_desc.D.layout,
block_scale_desc.SFD.element,
block_scale_desc.SFD.layout,
block_scale_desc.SFVecSize,
block_scale_desc.SFKVecSize,
block_scale_desc.EpilogueSFVecSize);
auto operators_it =
@ -1208,6 +1287,100 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
block_scale_ws.SFD_reference_ptr_array_host[group_idx]->copy_from_host(ptr_SFD);
}
else {
// Blockwise
auto const& block_scale_desc = desc.block_scales.value();
auto& block_scale_ws = gemm_workspace_.block_scales.value();
library::BlockwiseGemmFunctionalKey blockwiseGemm_key(
library::Provider::kReferenceHost,
library::GemmKind::kUniversal,
library::OperationKind::kBlockwiseGemm,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
element_A,
gemm_desc.A.layout,
block_scale_desc.SFA.element,
element_B,
gemm_desc.B.layout,
block_scale_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
block_scale_desc.SFMVecSize,
block_scale_desc.SFNVecSize,
block_scale_desc.SFKVecSize
);
auto operators_it = library::Singleton::get().operation_table.blockwise_gemm_operations.find(blockwiseGemm_key);
if (
operators_it ==
library::Singleton::get().operation_table.blockwise_gemm_operations.end()) {
disposition = Disposition::kNotSupported;
break;
}
if (operators_it->second.empty()) {
disposition = Disposition::kNotSupported;
break;
}
auto cc_it = operators_it->second.begin();
if (cc_it == operators_it->second.end()) {
disposition = Disposition::kNotSupported;
break;
}
// host reference has only one instances in BlockScaledOperationVectorMap
library::Operation const* reference_op = cc_it->second[0];
library::BlockwiseGemmArguments arguments {
{int(problem_.m(group_idx)), int(problem_.n(group_idx)), int(problem_.k(group_idx))},
{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)},
{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)},
1, // batch_count
ptr_A,
ptr_B,
ptr_SFA,
ptr_SFB,
ptr_C,
ptr_D,
problem_.alpha.data(),
problem_.beta.data(),
library::ScalarPointerMode::kHost,
problem_.lda[group_idx],
problem_.ldb[group_idx],
problem_.ldc[group_idx],
problem_.ldc[group_idx],
gemm_workspace_.A_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.B_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride(),
};
library::GemmUniversalConfiguration configuration{
library::GemmUniversalMode::kGemm,
problem_.problem_sizes[group_idx],
{problem_.cluster_m, problem_.cluster_n, problem_.cluster_k},
{problem_.cluster_m_fallback, problem_.cluster_n_fallback, problem_.cluster_k_fallback},
1,
problem_.lda[group_idx],
problem_.ldb[group_idx],
problem_.ldc[group_idx],
problem_.ldc[group_idx],
1,
};
uint64_t host_workspace_size_needed = reference_op->get_host_workspace_size(&gemm_workspace_.configuration);
std::vector<char> host_workspace(host_workspace_size_needed);
status = reference_op->initialize(&configuration, host_workspace.data());
if (status != Status::kSuccess) {
break;
}
status = reference_op->run(&arguments, host_workspace.data());
}
if (status != Status::kSuccess) {
break;
}
@ -1292,6 +1465,10 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
void* device_workspace) {
library::Operation const* underlying_operation = operation;
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
if (results_.back().status != Status::kSuccess) {
return results_.back().status;
}
auto func = [&](cudaStream_t stream, int iteration) {
// Iterate over copies of the problem in memory

View File

@ -301,6 +301,9 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) {
else if (op_kind == library::OperationKind::kBlockScaledGemm) {
out << "kBlockScaledGemm";
}
else if (op_kind == library::OperationKind::kBlockwiseGemm) {
out << "kBlockwiseGemm";
}
else if (op_kind == library::OperationKind::kRankK) {
out << "kRankK";
}