cutlass 3.9 update (#2255)
* cutlass 3.9 update * rebase * fixes out of shared memory for blockwise Blackwell * doc format * fix issue 2253 * disable host ref by default * fix sm120 smem capacity --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -437,9 +437,11 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "k", problem_space, k);
|
||||
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
auto cluster_shape = operation_desc.tile_description.cluster_shape;
|
||||
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
|
||||
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
|
||||
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
|
||||
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
1299
tools/profiler/src/blockwise_gemm_operation_profiler.cu
Normal file
1299
tools/profiler/src/blockwise_gemm_operation_profiler.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -37,6 +37,7 @@
|
||||
|
||||
// Profiler includes
|
||||
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/blockwise_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/conv2d_operation_profiler.h"
|
||||
#include "cutlass/profiler/conv3d_operation_profiler.h"
|
||||
#include "cutlass/profiler/cutlass_profiler.h"
|
||||
@ -64,6 +65,8 @@ CutlassProfiler::CutlassProfiler(
|
||||
|
||||
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new Conv2dOperationProfiler(options));
|
||||
|
||||
@ -440,10 +440,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "n", problem_space, n);
|
||||
set_argument(result, "k", problem_space, k);
|
||||
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
auto cluster_shape = operation_desc.tile_description.cluster_shape;
|
||||
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
|
||||
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
|
||||
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
|
||||
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
@ -459,9 +460,11 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
|
||||
set_argument(result, "problem-sizes", problem_space, ss.str());
|
||||
}
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
auto cluster_shape = operation_desc.gemm.tile_description.cluster_shape;
|
||||
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
|
||||
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
|
||||
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
|
||||
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
@ -497,10 +500,22 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
|
||||
// We distinguish between block scaled and non-block scaled operations by looking at the kernel
|
||||
// name, which tells us what reference kernel to use, which arguments to pass to the operation
|
||||
// etc. This avoids creating yet another OperationProfiler with a lot of boilerplate in it.
|
||||
|
||||
std::string sf_tuple = "\\d+x\\d+";
|
||||
std::string datatypes_regex = "\\w?f\\d+|e\\dm\\d"; // bf16 | f16 | f32 | e4m3 | ...
|
||||
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
|
||||
datatypes_regex + ")_" + sf_tuple + "(" +
|
||||
datatypes_regex + ")x(" + datatypes_regex + ")";
|
||||
|
||||
|
||||
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
|
||||
is_block_scaled = true;
|
||||
gemm_workspace_.block_scales = BlockScalingWorkspace{};
|
||||
}
|
||||
else if (std::regex_search(operation_desc.gemm.name, std::regex(blockwise_regex_string))) {
|
||||
is_blockwise = true;
|
||||
gemm_workspace_.block_scales = BlockScalingWorkspace{};
|
||||
}
|
||||
else {
|
||||
is_block_scaled = false;
|
||||
gemm_workspace_.block_scales = std::nullopt;
|
||||
@ -605,6 +620,12 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
block_scaling_ws.SFD_ptr_array_host.resize(num_groups);
|
||||
block_scaling_ws.SFD_reference_ptr_array_host.resize(num_groups);
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto& block_scaling_ws = gemm_workspace_.block_scales.value();
|
||||
block_scaling_ws.SFA_ptr_array_host.resize(num_groups);
|
||||
block_scaling_ws.SFB_ptr_array_host.resize(num_groups);
|
||||
block_scaling_ws.SFC_ptr_array_host.resize(num_groups);
|
||||
}
|
||||
static_assert(sizeof(void*) == 8); // allocating blocks for pointers, so verify pointer size
|
||||
// ldx
|
||||
gemm_workspace_.lda_array_device =
|
||||
@ -698,7 +719,7 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
int sfa_m = round_up(int(problem_.m(group_idx)), 128);
|
||||
int sfb_n = round_up(int(problem_.n(group_idx)), 128);
|
||||
int sfa_sfb_k =
|
||||
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFVecSize), 4);
|
||||
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize), 4);
|
||||
|
||||
int sfd_m =
|
||||
block_scale_desc.SFD.layout == cutlass::library::LayoutTypeID::kRowMajor
|
||||
@ -760,6 +781,37 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
block_scale_ws.SFD_ptr_array_host[group_idx]->fill_device(0);
|
||||
}
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto const block_scale_desc = operation_desc.block_scales.value();
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
int sfa_m = ceil_div(int(problem_.m(group_idx)), block_scale_desc.SFMVecSize);
|
||||
int sfb_n = ceil_div(int(problem_.n(group_idx)), block_scale_desc.SFNVecSize);
|
||||
int sfa_sfb_k = ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize);
|
||||
|
||||
block_scale_ws.SFA_ptr_array_host[group_idx] =
|
||||
device_context.allocate_and_initialize_tensor(
|
||||
options,
|
||||
"SFA_" + std::to_string(group_idx),
|
||||
block_scale_desc.SFA.element,
|
||||
block_scale_desc.SFA.layout,
|
||||
{sfa_m, sfa_sfb_k},
|
||||
{sfa_m},
|
||||
gemm_workspace_.problem_count,
|
||||
seed_shift++,
|
||||
0);
|
||||
|
||||
block_scale_ws.SFB_ptr_array_host[group_idx] =
|
||||
device_context.allocate_and_initialize_tensor(
|
||||
options,
|
||||
"SFB_" + std::to_string(group_idx),
|
||||
block_scale_desc.SFB.element,
|
||||
block_scale_desc.SFB.layout,
|
||||
{sfa_sfb_k, sfb_n},
|
||||
{sfb_n},
|
||||
gemm_workspace_.problem_count,
|
||||
seed_shift++,
|
||||
0);
|
||||
}
|
||||
}
|
||||
|
||||
// takes the allocated tensors and initializes an array of pointers per problem in the workspace
|
||||
@ -825,6 +877,18 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
0 // device_index
|
||||
);
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
create_dev_ptr_array_all_workspace(
|
||||
block_scale_ws.SFA_ptr_array_device,
|
||||
block_scale_ws.SFA_ptr_array_host,
|
||||
"SFA");
|
||||
create_dev_ptr_array_all_workspace(
|
||||
block_scale_ws.SFB_ptr_array_device,
|
||||
block_scale_ws.SFB_ptr_array_host,
|
||||
"SFB");
|
||||
}
|
||||
|
||||
init_arguments(options);
|
||||
}
|
||||
|
||||
@ -896,6 +960,11 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
|
||||
init_arguments(options);
|
||||
|
||||
library::Operation const* underlying_operation = operation;
|
||||
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
|
||||
if (results_.back().status != Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
results_.back().status = underlying_operation->run(
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
@ -998,7 +1067,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
}
|
||||
|
||||
// block scaled and blockwise kernels are only verified against the host reference provider
|
||||
if (is_block_scaled && provider != library::Provider::kReferenceHost) {
|
||||
if ((is_block_scaled || is_blockwise) && provider != library::Provider::kReferenceHost) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1064,12 +1133,22 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
ptr_norm_constant = host_data_norm_constant.data();
|
||||
ws.norm_constant->copy_to_host(ptr_norm_constant);
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto const& ws = gemm_workspace_.block_scales.value();
|
||||
|
||||
host_data_SFA.resize(ws.SFA_ptr_array_host[group_idx]->bytes());
|
||||
ptr_SFA = host_data_SFA.data();
|
||||
ws.SFA_ptr_array_host[group_idx]->copy_to_host(ptr_SFA);
|
||||
host_data_SFB.resize(ws.SFB_ptr_array_host[group_idx]->bytes());
|
||||
ptr_SFB = host_data_SFB.data();
|
||||
ws.SFB_ptr_array_host[group_idx]->copy_to_host(ptr_SFB);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &desc = static_cast<library::GroupedGemmDescription const &>(operation->description());
|
||||
const auto& gemm_desc = desc.gemm;
|
||||
|
||||
if (!is_block_scaled) {
|
||||
if (!is_block_scaled and !is_blockwise) {
|
||||
library::Handle handle;
|
||||
handle.set_provider(provider);
|
||||
|
||||
@ -1112,7 +1191,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride());
|
||||
}
|
||||
else {
|
||||
else if (is_block_scaled) {
|
||||
auto const& block_scale_desc = desc.block_scales.value();
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
|
||||
@ -1134,7 +1213,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFD.element,
|
||||
block_scale_desc.SFD.layout,
|
||||
block_scale_desc.SFVecSize,
|
||||
block_scale_desc.SFKVecSize,
|
||||
block_scale_desc.EpilogueSFVecSize);
|
||||
|
||||
auto operators_it =
|
||||
@ -1208,6 +1287,100 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
|
||||
block_scale_ws.SFD_reference_ptr_array_host[group_idx]->copy_from_host(ptr_SFD);
|
||||
}
|
||||
else {
|
||||
// Blockwise
|
||||
auto const& block_scale_desc = desc.block_scales.value();
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
|
||||
library::BlockwiseGemmFunctionalKey blockwiseGemm_key(
|
||||
library::Provider::kReferenceHost,
|
||||
library::GemmKind::kUniversal,
|
||||
library::OperationKind::kBlockwiseGemm,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
element_A,
|
||||
gemm_desc.A.layout,
|
||||
block_scale_desc.SFA.element,
|
||||
element_B,
|
||||
gemm_desc.B.layout,
|
||||
block_scale_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFMVecSize,
|
||||
block_scale_desc.SFNVecSize,
|
||||
block_scale_desc.SFKVecSize
|
||||
);
|
||||
|
||||
auto operators_it = library::Singleton::get().operation_table.blockwise_gemm_operations.find(blockwiseGemm_key);
|
||||
if (
|
||||
operators_it ==
|
||||
library::Singleton::get().operation_table.blockwise_gemm_operations.end()) {
|
||||
disposition = Disposition::kNotSupported;
|
||||
break;
|
||||
}
|
||||
|
||||
if (operators_it->second.empty()) {
|
||||
disposition = Disposition::kNotSupported;
|
||||
break;
|
||||
}
|
||||
|
||||
auto cc_it = operators_it->second.begin();
|
||||
if (cc_it == operators_it->second.end()) {
|
||||
disposition = Disposition::kNotSupported;
|
||||
break;
|
||||
}
|
||||
|
||||
// host reference has only one instance in BlockScaledOperationVectorMap
|
||||
library::Operation const* reference_op = cc_it->second[0];
|
||||
|
||||
library::BlockwiseGemmArguments arguments {
|
||||
{int(problem_.m(group_idx)), int(problem_.n(group_idx)), int(problem_.k(group_idx))},
|
||||
{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)},
|
||||
{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)},
|
||||
1, // batch_count
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
ptr_SFA,
|
||||
ptr_SFB,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
problem_.alpha.data(),
|
||||
problem_.beta.data(),
|
||||
library::ScalarPointerMode::kHost,
|
||||
problem_.lda[group_idx],
|
||||
problem_.ldb[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
gemm_workspace_.A_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.B_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride(),
|
||||
};
|
||||
|
||||
library::GemmUniversalConfiguration configuration{
|
||||
library::GemmUniversalMode::kGemm,
|
||||
problem_.problem_sizes[group_idx],
|
||||
{problem_.cluster_m, problem_.cluster_n, problem_.cluster_k},
|
||||
{problem_.cluster_m_fallback, problem_.cluster_n_fallback, problem_.cluster_k_fallback},
|
||||
1,
|
||||
problem_.lda[group_idx],
|
||||
problem_.ldb[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
1,
|
||||
};
|
||||
uint64_t host_workspace_size_needed = reference_op->get_host_workspace_size(&gemm_workspace_.configuration);
|
||||
std::vector<char> host_workspace(host_workspace_size_needed);
|
||||
status = reference_op->initialize(&configuration, host_workspace.data());
|
||||
if (status != Status::kSuccess) {
|
||||
break;
|
||||
}
|
||||
|
||||
status = reference_op->run(&arguments, host_workspace.data());
|
||||
}
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
break;
|
||||
}
|
||||
@ -1292,6 +1465,10 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
|
||||
void* device_workspace) {
|
||||
|
||||
library::Operation const* underlying_operation = operation;
|
||||
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
|
||||
if (results_.back().status != Status::kSuccess) {
|
||||
return results_.back().status;
|
||||
}
|
||||
|
||||
auto func = [&](cudaStream_t stream, int iteration) {
|
||||
// Iterate over copies of the problem in memory
|
||||
|
||||
@ -301,6 +301,9 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) {
|
||||
else if (op_kind == library::OperationKind::kBlockScaledGemm) {
|
||||
out << "kBlockScaledGemm";
|
||||
}
|
||||
else if (op_kind == library::OperationKind::kBlockwiseGemm) {
|
||||
out << "kBlockwiseGemm";
|
||||
}
|
||||
else if (op_kind == library::OperationKind::kRankK) {
|
||||
out << "kRankK";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user