v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -907,7 +907,7 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
/* Query device SM count to pass onto the kernel as an argument, where needed */
gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount;
gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0);
}
//

View File

@ -749,7 +749,7 @@ Status BlockwiseGemmOperationProfiler::initialize_workspace(
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
/* Query device SM count to pass onto the kernel as an argument, where needed */
gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount;
gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0);
}
//

View File

@ -977,7 +977,7 @@ Status GemmOperationProfiler::initialize_workspace(
gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride();
/* Query device SM count to pass onto the kernel as an argument, where needed */
gemm_workspace_[i].arguments.sm_count = options.device.properties[i].multiProcessorCount;
gemm_workspace_[i].arguments.sm_count = options.device.get_sm_count(i);
gemm_workspace_[i].arguments.device_index = static_cast<int>(i);
}
}

View File

@ -108,6 +108,14 @@ GroupedGemmOperationProfiler::GroupedGemmOperationProfiler(Options const& option
{ArgumentTypeID::kScalar,
{"beta", "epilogue::beta"},
"Epilogue scalar beta (applied to all GEMMs in group)."},
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"},
"Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"},
"Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"},
"Raster order (heuristic, along_n, along_m)"},
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
{ArgumentTypeID::kEnumerated, {"use_pdl", "use_pdl"}, "Use PDL (true, false)"},
{ArgumentTypeID::kScalar,
{"problem-sizes"},
"MxNxK Problem sizes for the grouped GEMM, where a group is enclosed by `[]`. E.g. "
@ -236,6 +244,9 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
if (!file.good()) {
throw std::runtime_error("Failed to open file: " + problem_file);
}
// clear the problem sizes and 3x problem sizes from previous operation
problem_sizes.clear();
problem_sizes_3x.clear();
for (std::string line; std::getline(file, line);) {
std::istringstream iss(line);
@ -257,7 +268,7 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) {
// default value
this->cluster_m = 1;
this->cluster_m = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1;
}
if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) {
@ -272,17 +283,17 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) {
// default value
this->cluster_m_fallback = 0;
this->cluster_m_fallback = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1;
}
if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) {
// default value
this->cluster_n_fallback = 0;
this->cluster_n_fallback = 1;
}
if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) {
// default value
this->cluster_k_fallback = 0;
this->cluster_k_fallback = 1;
}
this->mode = library::GemmUniversalMode::kGrouped;
@ -303,6 +314,31 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
return Status::kErrorInvalidProblem;
}
if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) {
// default value
this->use_pdl = false;
}
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) {
// default value
this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic;
}
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) {
// default value
this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic;
}
if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) {
// default value
this->swizzle_size = 1;
}
if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) {
// default value
this->raster_order = library::RasterOrder::kHeuristic;
}
if (!arg_as_scalar(
this->alpha,
operation_desc.gemm.element_epilogue,
@ -348,6 +384,19 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
.front();
}
// instantiation for exploration profiling
this->raster_orders = {
cutlass::library::RasterOrder::kAlongN,
cutlass::library::RasterOrder::kAlongM
};
this->swizzle_sizes = {1, 2, 4, 8};
this->preferred_clusters = {
{1, 1, 1}, {2, 1, 1}, {2, 2, 1}, {4, 1, 1}, {4, 2, 1}, {4, 4, 1}, {8, 2, 1}
};
this->fallback_clusters = {
{1, 1, 1}, {2, 1, 1}, {2, 2, 1}
};
return Status::kSuccess;
}
@ -469,6 +518,13 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
set_argument(result, "swizzle_size", problem_space, swizzle_size);
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b));
set_argument(
result,
"alpha",
@ -482,6 +538,25 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
library::lexical_cast(beta, operation_desc.gemm.element_epilogue));
}
void GroupedGemmOperationProfiler::update_result_(
  PerformanceResult &result,
  ProblemSpace const &problem_space,
  cutlass::library::RasterOrder const &raster_order,
  std::array<int64_t, 3> const &preferred_cluster,
  std::array<int64_t, 3> const &fallback_cluster,
  int swizzle_size
) {
  // Record the scheduling knobs chosen for this candidate so the reported
  // result row reflects the exact configuration that was profiled.
  set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
  set_argument(result, "swizzle_size", problem_space, swizzle_size);

  // Cluster shape components map one-to-one onto the m/n/k argument names.
  static constexpr const char *kPreferredKeys[3] = {"cluster_m", "cluster_n", "cluster_k"};
  static constexpr const char *kFallbackKeys[3] = {
    "cluster_m_fallback", "cluster_n_fallback", "cluster_k_fallback"};
  for (size_t axis = 0; axis < 3; ++axis) {
    set_argument(result, kPreferredKeys[axis], problem_space, preferred_cluster[axis]);
    set_argument(result, kFallbackKeys[axis], problem_space, fallback_cluster[axis]);
  }
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Extracts the problem dimensions
@ -506,7 +581,6 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
datatypes_regex + ")_" + sf_tuple + "(" +
datatypes_regex + ")x(" + datatypes_regex + ")";
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
is_block_scaled = true;
@ -538,6 +612,14 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
config.ldc = problem_.ldc.data();
config.problem_sizes_3x_host = problem_.problem_sizes_3x.data();
gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size;
gemm_workspace_.arguments.raster_order = problem_.raster_order;
gemm_workspace_.arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a;
gemm_workspace_.arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b;
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
initialize_result_(this->model_result_, options, operation_desc, problem_space);
return status;
@ -1000,6 +1082,25 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
auto const& desc =
static_cast<library::GroupedGemmDescription const&>(operation->description());
cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.arguments.runtime_input_datatype_a;
cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.arguments.runtime_input_datatype_b;
bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic;
bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic;
assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatype should be both dynamic or static.");
cutlass::library::NumericTypeID element_A = desc.gemm.A.element;
cutlass::library::NumericTypeID element_B = desc.gemm.B.element;
if (is_runtime_datatype_a) {
element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a);
}
if (is_runtime_datatype_b) {
element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b);
}
bool verification_status = verify_with_reference_(
options,
report,
@ -1007,8 +1108,8 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
operation,
problem_space,
problem,
desc.gemm.A.element,
desc.gemm.B.element);
element_A,
element_B);
// Update disposition to worst case verification outcome among all
// verification providers which are supported
@ -1442,13 +1543,23 @@ bool GroupedGemmOperationProfiler::profile(
ProblemSpace::Problem const& problem) {
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
results_.back().status = profile_cutlass_(
results_.back(),
options,
operation,
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());
if (options.profiling.enable_kernel_performance_search) {
std::cerr << "Exhaustive performance search is not available for Grouped GEMMs. "
<< "Please use --enable-best-kernel-for-fixed-shape to profile a specific problem size "
<< "with --problem-sizes or --problem-sizes-file.\n";
}
else if (options.profiling.enable_best_kernel_for_fixed_shape) {
return profile_cutlass_for_fixed_shape_(options, operation, problem_space);
}
else {
results_.back().status = profile_cutlass_(
results_.back(),
options,
operation,
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());
}
}
return true;
}
@ -1463,7 +1574,6 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
void* arguments,
void* host_workspace,
void* device_workspace) {
library::Operation const* underlying_operation = operation;
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
if (results_.back().status != Status::kSuccess) {
@ -1487,6 +1597,97 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Method to profile a CUTLASS Operation for the best configuration for a fixed shape
/// Profiles one fixed problem shape across a grid of scheduling parameters
/// (raster order, swizzle size, preferred/fallback cluster shapes) and keeps
/// only the best-performing candidate in results_.
/// Returns false when no candidate profiled successfully.
bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_(
Options const& options,
library::Operation const* operation,
ProblemSpace const& problem_space) {
library::GroupedGemmDescription const &operation_desc =
static_cast<library::GroupedGemmDescription const &>(operation->description());
auto min_cc = operation_desc.tile_description.minimum_compute_capability;
// The built-in dynamic-cluster sweep is only enabled for kernels targeting
// compute capability 100 (SM100) or newer.
bool is_dynamic_cluster_enabled = (min_cc >= 100);
// Helper function to test validity of fallback cluster shapes and preferred cluster shapes.
// A pairing is valid only when each preferred dimension is a multiple of the
// corresponding fallback dimension.
// NOTE(review): this divides by fallback_cluster[i]; callers must never pass
// a zero fallback component (the built-in lists below contain none).
auto is_valid_dynamic_cluster_shape = [](const std::array<int64_t, 3>& preferred_cluster, const std::array<int64_t, 3>& fallback_cluster) {
for (size_t i = 0; i < 3; ++i) {
if (preferred_cluster[i] % fallback_cluster[i] != 0) {
return false;
}
}
return true;
};
// Helper function to select the best performance number among a list.
// Moves the highest-gflops candidate into results_ (side effect).
auto select_best_candidate = [&](std::vector<PerformanceResult> &candidates) {
assert(!candidates.empty() && "Candidates vector should not be empty");
auto best_iter = std::max_element(
candidates.begin(), candidates.end(),
[](PerformanceResult const &a, PerformanceResult const &b) {
return a.gflops_per_sec() < b.gflops_per_sec();
}
);
assert(best_iter != candidates.end() && "No candidate found despite non-empty candidates vector");
results_.push_back(std::move(*best_iter));
};
std::vector<PerformanceResult> candidates;
// Use the last result as the template for every candidate; pop it so only
// the winning candidate is re-appended by select_best_candidate.
PerformanceResult result_base = results_.back();
results_.pop_back();
// A zero component in the compiled cluster shape marks a dynamic-cluster kernel.
bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 ||
int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 ||
int64_t(operation_desc.tile_description.cluster_shape.k()) == 0;
std::vector<std::array<int64_t, 3>> preferred_clusters;
std::vector<std::array<int64_t, 3>> fallback_clusters;
// Only loop over built-in cluster shape lists for dynamic cluster kernels
// and for kernels that can leverage the dynamic cluster feature.
if (dynamic_cluster && is_dynamic_cluster_enabled) {
preferred_clusters = this->problem_.preferred_clusters;
fallback_clusters = this->problem_.fallback_clusters;
}
else {
// Static-cluster kernels: profile only the single user/derived cluster shape.
preferred_clusters = {{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}};
fallback_clusters = {{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}};
}
// Sweep the full cross product of cluster shapes, swizzle sizes and raster orders.
for (auto preferred_cluster : preferred_clusters) {
for (auto fallback_cluster : fallback_clusters) {
if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) {
continue;
}
for (auto swizzle_size : this->problem_.swizzle_sizes) {
for (auto raster_order : this->problem_.raster_orders) {
PerformanceResult curr_result(result_base);
// Stamp this candidate's knobs into the result row.
// NOTE(review): only the result row is visibly updated here; confirm
// that gemm_workspace_.arguments is reconfigured per candidate (e.g.
// inside profile_cutlass_), otherwise every iteration profiles the
// same kernel configuration.
update_result_(curr_result, problem_space, raster_order, preferred_cluster, fallback_cluster, swizzle_size);
curr_result.status = profile_cutlass_(
curr_result,
options,
operation,
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data()
);
if (curr_result.status == Status::kSuccess) { // Only add valid results
candidates.push_back(curr_result);
}
}// for raster_order
}// for swizzle_size
}// for fallback_cluster
}// for preferred_clusters
if (candidates.empty()) {
// All candidates failed. Note the template result popped above is not
// restored, so results_ ends one entry shorter than on entry.
return false;
}
select_best_candidate(candidates);
return true;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler
} // namespace cutlass

View File

@ -141,9 +141,18 @@ Options::Device::Device(cutlass::CommandLine const &cmdline) {
}
}
// Permit overriding the sm_count
cmdline.get_cmd_line_argument("sm-count", sm_count, 0);
}
}
int Options::Device::get_sm_count(int device_index) const {
  // A positive --sm-count value overrides the hardware count; otherwise
  // report the multiprocessor count queried from the selected device.
  return (sm_count > 0) ? sm_count : properties[device_index].multiProcessorCount;
}
void Options::Device::print_usage(std::ostream &out) const {
out << "Device:\n"
@ -185,7 +194,12 @@ void Options::Device::print_usage(std::ostream &out) const {
<< " --llc-capacity=<capacity in KiB> "
<< " Capacity of last-level cache in kilobytes. If this is non-zero," << end_of_line
<< " profiling phases cycle through different input tensors to induce" << end_of_line
<< " capacity misses in the L2.\n\n";
<< "   capacity misses in the L2.\n\n"
<< "  --sm-count=<int>                           "
<< "    Override the number of SMs. This is used to limit the number of SMs used " << end_of_line
<< " during profiling. If this is set, profiling attempts to limit the sm_count " << end_of_line
<< " to the user-set value. This is not possible on all architectures and all kernel types.\n\n";
}