v4.0 update. (#2371)
This commit is contained in:
@ -576,6 +576,11 @@ struct GemmGroupedArguments {
|
||||
gemm::GemmCoord cluster_shape{};
|
||||
gemm::GemmCoord cluster_shape_fallback{};
|
||||
|
||||
library::RasterOrder raster_order{};
|
||||
library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
|
||||
library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};
|
||||
int swizzle_size{1};
|
||||
|
||||
// these should really be in the configuration but staying consistent with GEMM
|
||||
int sm_count{0};
|
||||
int max_active_clusters{0};
|
||||
|
||||
@ -64,6 +64,13 @@ public:
|
||||
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
|
||||
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
|
||||
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
|
||||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
|
||||
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
|
||||
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
|
||||
|
||||
GroupedGemmOperation3xBase(char const* name = "unknown_gemm")
|
||||
: GemmOperation3xBase<Operator_>(name, GemmKind::kGrouped) {
|
||||
this->description_.kind = OperationKind::kGroupedGemm;
|
||||
@ -152,8 +159,65 @@ protected:
|
||||
arguments.problem_sizes_3x,
|
||||
arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host
|
||||
: nullptr};
|
||||
operator_args.mainloop.ptr_A = static_cast<ElementA const**>(arguments.ptr_A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ElementB const**>(arguments.ptr_B);
|
||||
|
||||
if constexpr (IsRuntimeDataType) {
|
||||
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const**>(arguments.ptr_A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const**>(arguments.ptr_B);
|
||||
|
||||
using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA;
|
||||
using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB;
|
||||
|
||||
static_assert(cute::is_same_v<RuntimeDataTypeA, RuntimeDataTypeB>,
|
||||
"RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format");
|
||||
using RuntimeDatatypeArg = RuntimeDataTypeA;
|
||||
|
||||
auto mapping = [](RuntimeDatatype type) {
|
||||
if constexpr (cute::is_same_v<RuntimeDatatypeArg, cute::UMMA::MXF8F6F4Format>) {
|
||||
if (type == RuntimeDatatype::kE5M2) {
|
||||
return cute::UMMA::MXF8F6F4Format::E5M2;
|
||||
}
|
||||
else if (type == RuntimeDatatype::kE4M3) {
|
||||
return cute::UMMA::MXF8F6F4Format::E4M3;
|
||||
}
|
||||
else if (type == RuntimeDatatype::kE3M2) {
|
||||
return cute::UMMA::MXF8F6F4Format::E3M2;
|
||||
}
|
||||
else if (type == RuntimeDatatype::kE2M3) {
|
||||
return cute::UMMA::MXF8F6F4Format::E2M3;
|
||||
}
|
||||
else if (type == RuntimeDatatype::kE2M1) {
|
||||
return cute::UMMA::MXF8F6F4Format::E2M1;
|
||||
}
|
||||
else {
|
||||
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1
|
||||
std::cerr << "Invalid input datatype specified. Running with e4m3." << std::endl;
|
||||
#endif
|
||||
return cute::UMMA::MXF8F6F4Format::E4M3;
|
||||
}
|
||||
}
|
||||
else if constexpr (cute::is_same_v<RuntimeDatatypeArg, cute::UMMA::MXF4Format>) {
|
||||
if (type == RuntimeDatatype::kE2M1) {
|
||||
return cute::UMMA::MXF4Format::E2M1;
|
||||
}
|
||||
else {
|
||||
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1
|
||||
std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl;
|
||||
#endif
|
||||
return cute::UMMA::MXF4Format::E2M1;
|
||||
}
|
||||
}
|
||||
// BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
};
|
||||
operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a);
|
||||
operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b);
|
||||
}
|
||||
else {
|
||||
operator_args.mainloop.ptr_A = static_cast<ElementA const**>(arguments.ptr_A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ElementB const**>(arguments.ptr_B);
|
||||
}
|
||||
operator_args.epilogue.ptr_C = static_cast<ElementC const**>(arguments.ptr_C);
|
||||
operator_args.epilogue.ptr_D = static_cast<ElementD**>(arguments.ptr_D);
|
||||
|
||||
@ -166,10 +230,29 @@ protected:
|
||||
operator_args.epilogue.dD =
|
||||
static_cast<typename Operator::GemmKernel::InternalStrideD*>(this->strideD_device.data());
|
||||
|
||||
/* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */
|
||||
operator_args.hw_info.sm_count = arguments.sm_count;
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
operator_args.hw_info.max_active_clusters = arguments.max_active_clusters;
|
||||
}
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
|
||||
operator_args.scheduler.max_swizzle_size = arguments.swizzle_size;
|
||||
}
|
||||
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.raster_order)>) {
|
||||
using Enum_t = decltype(operator_args.scheduler.raster_order);
|
||||
switch (arguments.raster_order) {
|
||||
case RasterOrder::kAlongN:
|
||||
operator_args.scheduler.raster_order = Enum_t::AlongN;
|
||||
break;
|
||||
case RasterOrder::kAlongM:
|
||||
operator_args.scheduler.raster_order = Enum_t::AlongM;
|
||||
break;
|
||||
default:
|
||||
operator_args.scheduler.raster_order = Enum_t::Heuristic;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
|
||||
operator_args.hw_info.cluster_shape =
|
||||
dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k());
|
||||
@ -330,7 +413,6 @@ public:
|
||||
return status;
|
||||
}
|
||||
|
||||
|
||||
// Set arguments that should only be set once before verifying or profiling the kernel.
|
||||
// This should encompass any expensive operations that don't vary from run to run
|
||||
// (e.g., max_active_clusters).
|
||||
@ -363,9 +445,10 @@ public:
|
||||
cluster_dims,
|
||||
threads_per_block,
|
||||
kernel_ptr);
|
||||
|
||||
|
||||
if (args->max_active_clusters == 0) {
|
||||
return Status::kErrorInternal;
|
||||
std::cerr << "Max Active Clusters could not be queried. "
|
||||
<< "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n";
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -69,6 +69,12 @@ public:
|
||||
std::vector<gemm::GemmCoord> problem_sizes;
|
||||
std::vector<cute::Shape<int, int, int>> problem_sizes_3x;
|
||||
|
||||
/// For exploration purposes
|
||||
std::vector<std::array<int64_t, 3>> preferred_clusters;
|
||||
std::vector<std::array<int64_t, 3>> fallback_clusters;
|
||||
std::vector<cutlass::library::RasterOrder> raster_orders;
|
||||
std::vector<int> swizzle_sizes;
|
||||
|
||||
int cluster_m{1};
|
||||
int cluster_n{1};
|
||||
int cluster_k{1};
|
||||
@ -83,6 +89,14 @@ public:
|
||||
std::vector<uint8_t> alpha;
|
||||
std::vector<uint8_t> beta;
|
||||
|
||||
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
|
||||
int swizzle_size{1};
|
||||
|
||||
cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
|
||||
cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
|
||||
|
||||
bool use_pdl{false};
|
||||
|
||||
/// Parses the problem
|
||||
Status parse(
|
||||
library::GroupedGemmDescription const& operation_desc,
|
||||
@ -190,7 +204,7 @@ private:
|
||||
gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
arguments.sm_count = options.device.properties[0].multiProcessorCount;
|
||||
arguments.sm_count = options.device.get_sm_count(0);
|
||||
if (is_block_scaled) {
|
||||
auto& block_scaled_ws = gemm_workspace_.block_scales.value();
|
||||
arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data();
|
||||
@ -272,6 +286,15 @@ protected:
|
||||
library::GroupedGemmDescription const& operation_desc,
|
||||
ProblemSpace const& problem_space);
|
||||
|
||||
/// Update performance result configuration for exploration parameters
|
||||
void update_result_(
|
||||
PerformanceResult &result,
|
||||
ProblemSpace const &problem_space,
|
||||
cutlass::library::RasterOrder const &raster_order,
|
||||
std::array<int64_t, 3> const &preferred_cluster,
|
||||
std::array<int64_t, 3> const &fallback_cluster,
|
||||
int swizzle_size);
|
||||
|
||||
/// Verifies CUTLASS against host and device references
|
||||
bool verify_with_reference_(
|
||||
Options const& options,
|
||||
@ -292,6 +315,12 @@ protected:
|
||||
void* host_workspace,
|
||||
void* device_workspace) override;
|
||||
|
||||
/// Method to profile a CUTLASS Operation for the best configuration for a fixed shape
|
||||
bool profile_cutlass_for_fixed_shape_(
|
||||
Options const& options,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space);
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -94,10 +94,15 @@ public:
|
||||
/// Total memory allocation on each device
|
||||
size_t maximum_capacity;
|
||||
|
||||
private:
|
||||
/// SM Count
|
||||
/// Limits the number of SMs to use on each device
|
||||
int sm_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
public:
|
||||
explicit Device(CommandLine const &cmdline);
|
||||
|
||||
void print_usage(std::ostream &out) const;
|
||||
@ -107,7 +112,10 @@ public:
|
||||
/// Returns the device ID from a device index
|
||||
int device_id(size_t device_index) const;
|
||||
|
||||
/// Returns the compute capability of the listed devices (e.g. 61, 60, 70, 75)
|
||||
/// Returns the sm_count if set, otherwise returns the number of SMs on the device
|
||||
int get_sm_count(int device_index) const;
|
||||
|
||||
/// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.)
|
||||
int compute_capability(int device_index) const;
|
||||
};
|
||||
|
||||
|
||||
@ -907,7 +907,7 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace(
|
||||
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount;
|
||||
gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -749,7 +749,7 @@ Status BlockwiseGemmOperationProfiler::initialize_workspace(
|
||||
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount;
|
||||
gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -977,7 +977,7 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride();
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
gemm_workspace_[i].arguments.sm_count = options.device.properties[i].multiProcessorCount;
|
||||
gemm_workspace_[i].arguments.sm_count = options.device.get_sm_count(i);
|
||||
gemm_workspace_[i].arguments.device_index = static_cast<int>(i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,6 +108,14 @@ GroupedGemmOperationProfiler::GroupedGemmOperationProfiler(Options const& option
|
||||
{ArgumentTypeID::kScalar,
|
||||
{"beta", "epilogue::beta"},
|
||||
"Epilogue scalar beta (applied to all GEMMs in group)."},
|
||||
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"},
|
||||
"Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
||||
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"},
|
||||
"Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
||||
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"},
|
||||
"Raster order (heuristic, along_n, along_m)"},
|
||||
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
|
||||
{ArgumentTypeID::kEnumerated, {"use_pdl", "use_pdl"}, "Use PDL (true, false)"},
|
||||
{ArgumentTypeID::kScalar,
|
||||
{"problem-sizes"},
|
||||
"MxNxK Problem sizes for the grouped GEMM, where a group is enclosed by `[]`. E.g. "
|
||||
@ -236,6 +244,9 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
||||
if (!file.good()) {
|
||||
throw std::runtime_error("Failed to open file: " + problem_file);
|
||||
}
|
||||
// clear the problem sizes and 3x problem sizes from previous operation
|
||||
problem_sizes.clear();
|
||||
problem_sizes_3x.clear();
|
||||
|
||||
for (std::string line; std::getline(file, line);) {
|
||||
std::istringstream iss(line);
|
||||
@ -257,7 +268,7 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
||||
|
||||
if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_m = 1;
|
||||
this->cluster_m = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) {
|
||||
@ -272,17 +283,17 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
||||
|
||||
if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_m_fallback = 0;
|
||||
this->cluster_m_fallback = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_n_fallback = 0;
|
||||
this->cluster_n_fallback = 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_k_fallback = 0;
|
||||
this->cluster_k_fallback = 1;
|
||||
}
|
||||
|
||||
this->mode = library::GemmUniversalMode::kGrouped;
|
||||
@ -303,6 +314,31 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) {
|
||||
// default value
|
||||
this->use_pdl = false;
|
||||
}
|
||||
|
||||
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) {
|
||||
// default value
|
||||
this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic;
|
||||
}
|
||||
|
||||
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) {
|
||||
// default value
|
||||
this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) {
|
||||
// default value
|
||||
this->swizzle_size = 1;
|
||||
}
|
||||
|
||||
if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) {
|
||||
// default value
|
||||
this->raster_order = library::RasterOrder::kHeuristic;
|
||||
}
|
||||
|
||||
if (!arg_as_scalar(
|
||||
this->alpha,
|
||||
operation_desc.gemm.element_epilogue,
|
||||
@ -348,6 +384,19 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse(
|
||||
.front();
|
||||
}
|
||||
|
||||
// instantiation for exploration profiling
|
||||
this->raster_orders = {
|
||||
cutlass::library::RasterOrder::kAlongN,
|
||||
cutlass::library::RasterOrder::kAlongM
|
||||
};
|
||||
this->swizzle_sizes = {1, 2, 4, 8};
|
||||
this->preferred_clusters = {
|
||||
{1, 1, 1}, {2, 1, 1}, {2, 2, 1}, {4, 1, 1}, {4, 2, 1}, {4, 4, 1}, {8, 2, 1}
|
||||
};
|
||||
this->fallback_clusters = {
|
||||
{1, 1, 1}, {2, 1, 1}, {2, 2, 1}
|
||||
};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
@ -469,6 +518,13 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
||||
set_argument(result, "swizzle_size", problem_space, swizzle_size);
|
||||
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
|
||||
|
||||
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
|
||||
set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b));
|
||||
|
||||
set_argument(
|
||||
result,
|
||||
"alpha",
|
||||
@ -482,6 +538,25 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
|
||||
library::lexical_cast(beta, operation_desc.gemm.element_epilogue));
|
||||
}
|
||||
|
||||
void GroupedGemmOperationProfiler::update_result_(
|
||||
PerformanceResult &result,
|
||||
ProblemSpace const &problem_space,
|
||||
cutlass::library::RasterOrder const &raster_order,
|
||||
std::array<int64_t, 3> const &preferred_cluster,
|
||||
std::array<int64_t, 3> const &fallback_cluster,
|
||||
int swizzle_size
|
||||
) {
|
||||
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
||||
set_argument(result, "swizzle_size", problem_space, swizzle_size);
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, preferred_cluster[0]);
|
||||
set_argument(result, "cluster_n", problem_space, preferred_cluster[1]);
|
||||
set_argument(result, "cluster_k", problem_space, preferred_cluster[2]);
|
||||
set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Extracts the problem dimensions
|
||||
@ -506,7 +581,6 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
|
||||
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
|
||||
datatypes_regex + ")_" + sf_tuple + "(" +
|
||||
datatypes_regex + ")x(" + datatypes_regex + ")";
|
||||
|
||||
|
||||
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
|
||||
is_block_scaled = true;
|
||||
@ -538,6 +612,14 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
|
||||
config.ldc = problem_.ldc.data();
|
||||
config.problem_sizes_3x_host = problem_.problem_sizes_3x.data();
|
||||
|
||||
gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size;
|
||||
gemm_workspace_.arguments.raster_order = problem_.raster_order;
|
||||
|
||||
gemm_workspace_.arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a;
|
||||
gemm_workspace_.arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b;
|
||||
|
||||
gemm_workspace_.arguments.use_pdl = problem_.use_pdl;
|
||||
|
||||
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
||||
|
||||
return status;
|
||||
@ -1000,6 +1082,25 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
|
||||
auto const& desc =
|
||||
static_cast<library::GroupedGemmDescription const&>(operation->description());
|
||||
|
||||
cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.arguments.runtime_input_datatype_a;
|
||||
cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.arguments.runtime_input_datatype_b;
|
||||
|
||||
bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic;
|
||||
bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic;
|
||||
|
||||
assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatype should be both dynamic or static.");
|
||||
|
||||
cutlass::library::NumericTypeID element_A = desc.gemm.A.element;
|
||||
cutlass::library::NumericTypeID element_B = desc.gemm.B.element;
|
||||
|
||||
if (is_runtime_datatype_a) {
|
||||
element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a);
|
||||
}
|
||||
|
||||
if (is_runtime_datatype_b) {
|
||||
element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b);
|
||||
}
|
||||
|
||||
bool verification_status = verify_with_reference_(
|
||||
options,
|
||||
report,
|
||||
@ -1007,8 +1108,8 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
|
||||
operation,
|
||||
problem_space,
|
||||
problem,
|
||||
desc.gemm.A.element,
|
||||
desc.gemm.B.element);
|
||||
element_A,
|
||||
element_B);
|
||||
|
||||
// Update disposition to worst case verification outcome among all
|
||||
// verification providers which are supported
|
||||
@ -1442,13 +1543,23 @@ bool GroupedGemmOperationProfiler::profile(
|
||||
ProblemSpace::Problem const& problem) {
|
||||
|
||||
if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) {
|
||||
results_.back().status = profile_cutlass_(
|
||||
results_.back(),
|
||||
options,
|
||||
operation,
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
gemm_workspace_.device_workspace.data());
|
||||
if (options.profiling.enable_kernel_performance_search) {
|
||||
std::cerr << "Exhaustive performance search is not available for Grouped GEMMs. "
|
||||
<< "Please use --enable-best-kernel-for-fixed-shape to profile a specific problem size "
|
||||
<< "with --problem-sizes or --problem-sizes-file.\n";
|
||||
}
|
||||
else if (options.profiling.enable_best_kernel_for_fixed_shape) {
|
||||
return profile_cutlass_for_fixed_shape_(options, operation, problem_space);
|
||||
}
|
||||
else {
|
||||
results_.back().status = profile_cutlass_(
|
||||
results_.back(),
|
||||
options,
|
||||
operation,
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
gemm_workspace_.device_workspace.data());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -1463,7 +1574,6 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
|
||||
void* arguments,
|
||||
void* host_workspace,
|
||||
void* device_workspace) {
|
||||
|
||||
library::Operation const* underlying_operation = operation;
|
||||
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
|
||||
if (results_.back().status != Status::kSuccess) {
|
||||
@ -1487,6 +1597,97 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Method to profile a CUTLASS Operation for the best configuration for a fixed shape
|
||||
bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_(
|
||||
Options const& options,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space) {
|
||||
library::GroupedGemmDescription const &operation_desc =
|
||||
static_cast<library::GroupedGemmDescription const &>(operation->description());
|
||||
|
||||
auto min_cc = operation_desc.tile_description.minimum_compute_capability;
|
||||
|
||||
bool is_dynamic_cluster_enabled = (min_cc >= 100);
|
||||
|
||||
// Helper function to test validity of fallback cluster shapes and preferred cluster shapes.
|
||||
auto is_valid_dynamic_cluster_shape = [](const std::array<int64_t, 3>& preferred_cluster, const std::array<int64_t, 3>& fallback_cluster) {
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
if (preferred_cluster[i] % fallback_cluster[i] != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Helper function to select the best performance number among a list.
|
||||
auto select_best_candidate = [&](std::vector<PerformanceResult> &candidates) {
|
||||
assert(!candidates.empty() && "Candidates vector should not be empty");
|
||||
auto best_iter = std::max_element(
|
||||
candidates.begin(), candidates.end(),
|
||||
[](PerformanceResult const &a, PerformanceResult const &b) {
|
||||
return a.gflops_per_sec() < b.gflops_per_sec();
|
||||
}
|
||||
);
|
||||
assert(best_iter != candidates.end() && "No candidate found despite non-empty candidates vector");
|
||||
results_.push_back(std::move(*best_iter));
|
||||
};
|
||||
|
||||
std::vector<PerformanceResult> candidates;
|
||||
PerformanceResult result_base = results_.back();
|
||||
results_.pop_back();
|
||||
|
||||
bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 ||
|
||||
int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 ||
|
||||
int64_t(operation_desc.tile_description.cluster_shape.k()) == 0;
|
||||
|
||||
std::vector<std::array<int64_t, 3>> preferred_clusters;
|
||||
std::vector<std::array<int64_t, 3>> fallback_clusters;
|
||||
|
||||
// Only loop over built-in cluster shape lists for dynamic cluster kernels
|
||||
// and for kernels that can leverage the dynamic cluster feature.
|
||||
if (dynamic_cluster && is_dynamic_cluster_enabled) {
|
||||
preferred_clusters = this->problem_.preferred_clusters;
|
||||
fallback_clusters = this->problem_.fallback_clusters;
|
||||
}
|
||||
else {
|
||||
preferred_clusters = {{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}};
|
||||
fallback_clusters = {{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}};
|
||||
}
|
||||
|
||||
for (auto preferred_cluster : preferred_clusters) {
|
||||
for (auto fallback_cluster : fallback_clusters) {
|
||||
if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) {
|
||||
continue;
|
||||
}
|
||||
for (auto swizzle_size : this->problem_.swizzle_sizes) {
|
||||
for (auto raster_order : this->problem_.raster_orders) {
|
||||
PerformanceResult curr_result(result_base);
|
||||
update_result_(curr_result, problem_space, raster_order, preferred_cluster, fallback_cluster, swizzle_size);
|
||||
curr_result.status = profile_cutlass_(
|
||||
curr_result,
|
||||
options,
|
||||
operation,
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
gemm_workspace_.device_workspace.data()
|
||||
);
|
||||
if (curr_result.status == Status::kSuccess) { // Only add valid results
|
||||
candidates.push_back(curr_result);
|
||||
}
|
||||
}// for raster_order
|
||||
}// for swizzle_size
|
||||
}// for fallback_cluster
|
||||
}// for preferred_clusters
|
||||
|
||||
if (candidates.empty()) {
|
||||
return false;
|
||||
}
|
||||
select_best_candidate(candidates);
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -141,9 +141,18 @@ Options::Device::Device(cutlass::CommandLine const &cmdline) {
|
||||
}
|
||||
}
|
||||
|
||||
// Permit overriding the sm_count
|
||||
cmdline.get_cmd_line_argument("sm-count", sm_count, 0);
|
||||
}
|
||||
}
|
||||
|
||||
int Options::Device::get_sm_count(int device_index) const {
|
||||
if (sm_count <= 0) {
|
||||
return properties[device_index].multiProcessorCount;
|
||||
}
|
||||
return sm_count;
|
||||
}
|
||||
|
||||
void Options::Device::print_usage(std::ostream &out) const {
|
||||
|
||||
out << "Device:\n"
|
||||
@ -185,7 +194,12 @@ void Options::Device::print_usage(std::ostream &out) const {
|
||||
<< " --llc-capacity=<capacity in KiB> "
|
||||
<< " Capacity of last-level cache in kilobytes. If this is non-zero," << end_of_line
|
||||
<< " profiling phases cycle through different input tensors to induce" << end_of_line
|
||||
<< " capacity misses in the L2.\n\n";
|
||||
<< " capacity misses in the L2.\n\n"
|
||||
|
||||
<< " --sm-count=<int> "
|
||||
<< " Override the number of SMs. This is used to limit the number of " << end_of_line
|
||||
<< " during profiling. If this is set, profiling attempts to limit the sm_count " << end_of_line
|
||||
<< " to user-set value. This is not possible on all architectures and all kernel types. \n\n";
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user