|
|
|
|
@ -108,6 +108,14 @@ GroupedGemmOperationProfiler::GroupedGemmOperationProfiler(Options const& option
|
|
|
|
|
{ArgumentTypeID::kScalar,
|
|
|
|
|
{"beta", "epilogue::beta"},
|
|
|
|
|
"Epilogue scalar beta (applied to all GEMMs in group)."},
|
|
|
|
|
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"},
|
|
|
|
|
"Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
|
|
|
|
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"},
|
|
|
|
|
"Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
|
|
|
|
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"},
|
|
|
|
|
"Raster order (heuristic, along_n, along_m)"},
|
|
|
|
|
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
|
|
|
|
|
{ArgumentTypeID::kEnumerated, {"use_pdl", "use_pdl"}, "Use PDL (true, false)"},
|
|
|
|
|
{ArgumentTypeID::kScalar,
|
|
|
|
|
{"problem-sizes"},
|
|
|
|
|
"MxNxK Problem sizes for the grouped GEMM, where a group is enclosed by `[]`. E.g. "
|
|
|
|
|
@ -236,6 +244,9 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
|
|
|
|
if (!file.good()) {
|
|
|
|
|
throw std::runtime_error("Failed to open file: " + problem_file);
|
|
|
|
|
}
|
|
|
|
|
// clear the problem sizes and 3x problem sizes from previous operation
|
|
|
|
|
problem_sizes.clear();
|
|
|
|
|
problem_sizes_3x.clear();
|
|
|
|
|
|
|
|
|
|
for (std::string line; std::getline(file, line);) {
|
|
|
|
|
std::istringstream iss(line);
|
|
|
|
|
@ -257,7 +268,7 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
|
|
|
|
|
|
|
|
|
if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->cluster_m = 1;
|
|
|
|
|
this->cluster_m = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) {
|
|
|
|
|
@ -272,17 +283,17 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
|
|
|
|
|
|
|
|
|
if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->cluster_m_fallback = 0;
|
|
|
|
|
this->cluster_m_fallback = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->cluster_n_fallback = 0;
|
|
|
|
|
this->cluster_n_fallback = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->cluster_k_fallback = 0;
|
|
|
|
|
this->cluster_k_fallback = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
this->mode = library::GemmUniversalMode::kGrouped;
|
|
|
|
|
@ -303,6 +314,31 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
|
|
|
|
return Status::kErrorInvalidProblem;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->use_pdl = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->swizzle_size = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) {
|
|
|
|
|
// default value
|
|
|
|
|
this->raster_order = library::RasterOrder::kHeuristic;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!arg_as_scalar(
|
|
|
|
|
this->alpha,
|
|
|
|
|
operation_desc.gemm.element_epilogue,
|
|
|
|
|
@ -348,6 +384,19 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
|
|
|
|
.front();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// instantiation for exploration profiling
|
|
|
|
|
this->raster_orders = {
|
|
|
|
|
cutlass::library::RasterOrder::kAlongN,
|
|
|
|
|
cutlass::library::RasterOrder::kAlongM
|
|
|
|
|
};
|
|
|
|
|
this->swizzle_sizes = {1, 2, 4, 8};
|
|
|
|
|
this->preferred_clusters = {
|
|
|
|
|
{1, 1, 1}, {2, 1, 1}, {2, 2, 1}, {4, 1, 1}, {4, 2, 1}, {4, 4, 1}, {8, 2, 1}
|
|
|
|
|
};
|
|
|
|
|
this->fallback_clusters = {
|
|
|
|
|
{1, 1, 1}, {2, 1, 1}, {2, 2, 1}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return Status::kSuccess;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -469,6 +518,13 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
|
|
|
|
|
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
|
|
|
|
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
|
|
|
|
|
|
|
|
|
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
|
|
|
|
set_argument(result, "swizzle_size", problem_space, swizzle_size);
|
|
|
|
|
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
|
|
|
|
|
|
|
|
|
|
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
|
|
|
|
|
set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b));
|
|
|
|
|
|
|
|
|
|
set_argument(
|
|
|
|
|
result,
|
|
|
|
|
"alpha",
|
|
|
|
|
@ -482,6 +538,25 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
|
|
|
|
|
library::lexical_cast(beta, operation_desc.gemm.element_epilogue));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GroupedGemmOperationProfiler::update_result_(
|
|
|
|
|
PerformanceResult &result,
|
|
|
|
|
ProblemSpace const &problem_space,
|
|
|
|
|
cutlass::library::RasterOrder const &raster_order,
|
|
|
|
|
std::array<int64_t, 3> const &preferred_cluster,
|
|
|
|
|
std::array<int64_t, 3> const &fallback_cluster,
|
|
|
|
|
int swizzle_size
|
|
|
|
|
) {
|
|
|
|
|
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
|
|
|
|
set_argument(result, "swizzle_size", problem_space, swizzle_size);
|
|
|
|
|
|
|
|
|
|
set_argument(result, "cluster_m", problem_space, preferred_cluster[0]);
|
|
|
|
|
set_argument(result, "cluster_n", problem_space, preferred_cluster[1]);
|
|
|
|
|
set_argument(result, "cluster_k", problem_space, preferred_cluster[2]);
|
|
|
|
|
set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]);
|
|
|
|
|
set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]);
|
|
|
|
|
set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
/// Extracts the problem dimensions
|
|
|
|
|
@ -506,7 +581,6 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
|
|
|
|
|
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
|
|
|
|
|
datatypes_regex + ")_" + sf_tuple + "(" +
|
|
|
|
|
datatypes_regex + ")x(" + datatypes_regex + ")";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
|
|
|
|
|
is_block_scaled = true;
|
|
|
|
|
@ -538,6 +612,14 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
|
|
|
|
|
config.ldc = problem_.ldc.data();
|
|
|
|
|
config.problem_sizes_3x_host = problem_.problem_sizes_3x.data();
|
|
|
|
|
|
|
|
|
|
gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size;
|
|
|
|
|
gemm_workspace_.arguments.raster_order = problem_.raster_order;
|
|
|
|
|
|
|
|
|
|
gemm_workspace_.arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a;
|
|
|
|
|
gemm_workspace_.arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b;
|
|
|
|
|
|
|
|
|
|
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
|
|
|
|
|
|
|
|
|
|
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
|
|
|
|
|
|
|
|
|
return status;
|
|
|
|
|
@ -1000,6 +1082,25 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
|
|
|
|
|
auto const& desc =
|
|
|
|
|
static_cast<library::GroupedGemmDescription const&>(operation->description());
|
|
|
|
|
|
|
|
|
|
cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.arguments.runtime_input_datatype_a;
|
|
|
|
|
cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.arguments.runtime_input_datatype_b;
|
|
|
|
|
|
|
|
|
|
bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic;
|
|
|
|
|
bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic;
|
|
|
|
|
|
|
|
|
|
assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatype should be both dynamic or static.");
|
|
|
|
|
|
|
|
|
|
cutlass::library::NumericTypeID element_A = desc.gemm.A.element;
|
|
|
|
|
cutlass::library::NumericTypeID element_B = desc.gemm.B.element;
|
|
|
|
|
|
|
|
|
|
if (is_runtime_datatype_a) {
|
|
|
|
|
element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_runtime_datatype_b) {
|
|
|
|
|
element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool verification_status = verify_with_reference_(
|
|
|
|
|
options,
|
|
|
|
|
report,
|
|
|
|
|
@ -1007,8 +1108,8 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
|
|
|
|
|
operation,
|
|
|
|
|
problem_space,
|
|
|
|
|
problem,
|
|
|
|
|
desc.gemm.A.element,
|
|
|
|
|
desc.gemm.B.element);
|
|
|
|
|
element_A,
|
|
|
|
|
element_B);
|
|
|
|
|
|
|
|
|
|
// Update disposition to worst case verification outcome among all
|
|
|
|
|
// verification providers which are supported
|
|
|
|
|
@ -1442,13 +1543,23 @@ bool GroupedGemmOperationProfiler::profile(
|
|
|
|
|
ProblemSpace::Problem const& problem) {
|
|
|
|
|
|
|
|
|
|
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
|
|
|
|
|
results_.back().status = profile_cutlass_(
|
|
|
|
|
results_.back(),
|
|
|
|
|
options,
|
|
|
|
|
operation,
|
|
|
|
|
&gemm_workspace_.arguments,
|
|
|
|
|
gemm_workspace_.host_workspace.data(),
|
|
|
|
|
gemm_workspace_.device_workspace.data());
|
|
|
|
|
if (options.profiling.enable_kernel_performance_search) {
|
|
|
|
|
std::cerr << "Exhaustive performance search is not available for Grouped GEMMs. "
|
|
|
|
|
<< "Please use --enable-best-kernel-for-fixed-shape to profile a specific problem size "
|
|
|
|
|
<< "with --problem-sizes or --problem-sizes-file.\n";
|
|
|
|
|
}
|
|
|
|
|
else if (options.profiling.enable_best_kernel_for_fixed_shape) {
|
|
|
|
|
return profile_cutlass_for_fixed_shape_(options, operation, problem_space);
|
|
|
|
|
}
|
|
|
|
|
else {
|
|
|
|
|
results_.back().status = profile_cutlass_(
|
|
|
|
|
results_.back(),
|
|
|
|
|
options,
|
|
|
|
|
operation,
|
|
|
|
|
&gemm_workspace_.arguments,
|
|
|
|
|
gemm_workspace_.host_workspace.data(),
|
|
|
|
|
gemm_workspace_.device_workspace.data());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
@ -1463,7 +1574,6 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
|
|
|
|
|
void* arguments,
|
|
|
|
|
void* host_workspace,
|
|
|
|
|
void* device_workspace) {
|
|
|
|
|
|
|
|
|
|
library::Operation const* underlying_operation = operation;
|
|
|
|
|
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
|
|
|
|
|
if (results_.back().status != Status::kSuccess) {
|
|
|
|
|
@ -1487,6 +1597,97 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
|
|
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
/// Method to profile a CUTLASS Operation for the best configuration for a fixed shape
|
|
|
|
|
bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_(
|
|
|
|
|
Options const& options,
|
|
|
|
|
library::Operation const* operation,
|
|
|
|
|
ProblemSpace const& problem_space) {
|
|
|
|
|
library::GroupedGemmDescription const &operation_desc =
|
|
|
|
|
static_cast<library::GroupedGemmDescription const &>(operation->description());
|
|
|
|
|
|
|
|
|
|
auto min_cc = operation_desc.tile_description.minimum_compute_capability;
|
|
|
|
|
|
|
|
|
|
bool is_dynamic_cluster_enabled = (min_cc >= 100);
|
|
|
|
|
|
|
|
|
|
// Helper function to test validity of fallback cluster shapes and preferred cluster shapes.
|
|
|
|
|
auto is_valid_dynamic_cluster_shape = [](const std::array<int64_t, 3>& preferred_cluster, const std::array<int64_t, 3>& fallback_cluster) {
|
|
|
|
|
for (size_t i = 0; i < 3; ++i) {
|
|
|
|
|
if (preferred_cluster[i] % fallback_cluster[i] != 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Helper function to select the best performance number among a list.
|
|
|
|
|
auto select_best_candidate = [&](std::vector<PerformanceResult> &candidates) {
|
|
|
|
|
assert(!candidates.empty() && "Candidates vector should not be empty");
|
|
|
|
|
auto best_iter = std::max_element(
|
|
|
|
|
candidates.begin(), candidates.end(),
|
|
|
|
|
[](PerformanceResult const &a, PerformanceResult const &b) {
|
|
|
|
|
return a.gflops_per_sec() < b.gflops_per_sec();
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
assert(best_iter != candidates.end() && "No candidate found despite non-empty candidates vector");
|
|
|
|
|
results_.push_back(std::move(*best_iter));
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::vector<PerformanceResult> candidates;
|
|
|
|
|
PerformanceResult result_base = results_.back();
|
|
|
|
|
results_.pop_back();
|
|
|
|
|
|
|
|
|
|
bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 ||
|
|
|
|
|
int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 ||
|
|
|
|
|
int64_t(operation_desc.tile_description.cluster_shape.k()) == 0;
|
|
|
|
|
|
|
|
|
|
std::vector<std::array<int64_t, 3>> preferred_clusters;
|
|
|
|
|
std::vector<std::array<int64_t, 3>> fallback_clusters;
|
|
|
|
|
|
|
|
|
|
// Only loop over built-in cluster shape lists for dynamic cluster kernels
|
|
|
|
|
// and for kernels that can leverage the dynamic cluster feature.
|
|
|
|
|
if (dynamic_cluster && is_dynamic_cluster_enabled) {
|
|
|
|
|
preferred_clusters = this->problem_.preferred_clusters;
|
|
|
|
|
fallback_clusters = this->problem_.fallback_clusters;
|
|
|
|
|
}
|
|
|
|
|
else {
|
|
|
|
|
preferred_clusters = {{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}};
|
|
|
|
|
fallback_clusters = {{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto preferred_cluster : preferred_clusters) {
|
|
|
|
|
for (auto fallback_cluster : fallback_clusters) {
|
|
|
|
|
if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto swizzle_size : this->problem_.swizzle_sizes) {
|
|
|
|
|
for (auto raster_order : this->problem_.raster_orders) {
|
|
|
|
|
PerformanceResult curr_result(result_base);
|
|
|
|
|
update_result_(curr_result, problem_space, raster_order, preferred_cluster, fallback_cluster, swizzle_size);
|
|
|
|
|
curr_result.status = profile_cutlass_(
|
|
|
|
|
curr_result,
|
|
|
|
|
options,
|
|
|
|
|
operation,
|
|
|
|
|
&gemm_workspace_.arguments,
|
|
|
|
|
gemm_workspace_.host_workspace.data(),
|
|
|
|
|
gemm_workspace_.device_workspace.data()
|
|
|
|
|
);
|
|
|
|
|
if (curr_result.status == Status::kSuccess) { // Only add valid results
|
|
|
|
|
candidates.push_back(curr_result);
|
|
|
|
|
}
|
|
|
|
|
}// for raster_order
|
|
|
|
|
}// for swizzle_size
|
|
|
|
|
}// for fallback_cluster
|
|
|
|
|
}// for preferred_clusters
|
|
|
|
|
|
|
|
|
|
if (candidates.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
select_best_candidate(candidates);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
} // namespace profiler
|
|
|
|
|
} // namespace cutlass
|
|
|
|
|
|
|
|
|
|
|