CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@ -67,11 +67,12 @@ public:
/// Problem structure obtained from problem space
struct GemmProblem {
cutlass::library::GemmUniversalMode mode;
cutlass::library::GemmUniversalMode mode;
int64_t m;
int64_t n;
int64_t k;
int64_t lda;
int64_t ldb;
int64_t ldc;
@ -93,9 +94,16 @@ public:
// Methods
//
GemmProblem():
GemmProblem():
mode(library::GemmUniversalMode::kGemm),
m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1),
m(16),
n(16),
k(16),
lda(0),
ldb(0),
ldc(0),
split_k_slices(1),
batch_count(1),
raster_order(cutlass::library::RasterOrder::kHeuristic){ }
/// Parses the problem
@ -117,7 +125,7 @@ public:
ProblemSpace const &problem_space);
};
/// Workspace used
/// Workspace used
struct GemmWorkspace {
DeviceAllocation *A;
@ -150,7 +158,7 @@ public:
// Methods
//
GemmWorkspace():
GemmWorkspace():
A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { }
};
@ -163,7 +171,7 @@ protected:
/// GEMM problem obtained from problem space
GemmProblem problem_;
/// Device memory allocations
/// Device memory allocations
GemmWorkspace gemm_workspace_;
/// CUTLASS parallel reduction operation to follow this* gemm operation
@ -190,8 +198,8 @@ public:
/// Extracts the problem dimensions
virtual Status initialize_configuration(
Options const &options,
PerformanceReport &report,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
@ -199,8 +207,8 @@ public:
/// Initializes workspace
virtual Status initialize_workspace(
Options const &options,
PerformanceReport &report,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
@ -208,7 +216,7 @@ public:
/// Verifies CUTLASS against references
virtual bool verify_cutlass(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
@ -217,8 +225,8 @@ public:
/// Measures performance results
virtual bool profile(
Options const &options,
PerformanceReport &report,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
@ -229,13 +237,13 @@ protected:
/// Initializes the performance result
void initialize_result_(
PerformanceResult &result,
Options const &options,
Options const &options,
library::GemmDescription const &operation_desc,
ProblemSpace const &problem_space);
/// Verifies CUTLASS against references
bool verify_with_cublas_(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
@ -244,7 +252,7 @@ protected:
/// Verifies CUTLASS against host and device references
bool verify_with_reference_(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,

View File

@ -1493,7 +1493,6 @@ bool DeviceAllocation::block_compare_equal(
reinterpret_cast<float_e5m2_t const *>(ptr_A),
reinterpret_cast<float_e5m2_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kF16:
return reference::device::BlockCompareEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
@ -1633,7 +1632,7 @@ bool DeviceAllocation::block_compare_equal(
capacity);
default:
throw std::runtime_error("Unsupported numeric type");
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type));
}
}
@ -1662,7 +1661,6 @@ bool DeviceAllocation::block_compare_relatively_equal(
capacity,
static_cast<float_e5m2_t>(epsilon),
static_cast<float_e5m2_t>(nonzero_floor));
case library::NumericTypeID::kF16:
return reference::device::BlockCompareRelativelyEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
@ -2089,8 +2087,12 @@ void DeviceAllocation::write_tensor_csv(
write_tensor_csv_static_type<cutlass::complex<double> >(out, *this);
break;
case library::NumericTypeID::kVoid:
// Not dump anything as it is a empty tensor.
break;
default:
throw std::runtime_error("Unsupported numeric type");
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()) ) ;
}
}
@ -2168,7 +2170,6 @@ void DeviceAllocation::fill(double val = 0.0) {
case library::NumericTypeID::kFE5M2:
tensor_fill<float_e5m2_t>(*this, static_cast<float_e5m2_t>(val));
break;
case library::NumericTypeID::kF16:
tensor_fill<half_t>(*this, static_cast<half_t>(val));
break;
@ -2254,7 +2255,7 @@ void DeviceAllocation::fill(double val = 0.0) {
break;
default:
throw std::runtime_error("Unsupported numeric type");
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()));
}
}

View File

@ -55,7 +55,7 @@ namespace profiler {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Ctor
GemmOperationProfiler::GemmOperationProfiler(Options const &options):
GemmOperationProfiler::GemmOperationProfiler(Options const &options):
OperationProfiler(
options,
library::OperationKind::kGemm,
@ -73,7 +73,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"},
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"},
},
{ library::Provider::kCUBLAS}
) {
@ -119,7 +119,7 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
<< "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n"
<< " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n"
<< "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n"
<< " $ cutlass_profiler --operation=Gemm \\ \n"
<< " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n"
@ -150,9 +150,9 @@ Status GemmOperationProfiler::GemmProblem::parse(
library::GemmDescription const &operation_desc,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
this->mode = library::GemmUniversalMode::kGemm;
if (!arg_as_int(this->m, "m", problem_space, problem)) {
// default value
this->m = 1024;
@ -162,17 +162,17 @@ Status GemmOperationProfiler::GemmProblem::parse(
// default value
this->n = 1024;
}
if (!arg_as_int(this->k, "k", problem_space, problem)) {
// default value
this->k = 1024;
}
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
// default value
this->split_k_mode = library::SplitKMode::kSerial;
}
this->mode = library::GemmUniversalMode::kGemm;
if (this->split_k_mode == library::SplitKMode::kParallel) {
this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
@ -182,7 +182,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
// default value
this->split_k_slices = 1;
}
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
// default value
this->batch_count = 1;
@ -194,7 +194,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
// default value
this->raster_order = library::RasterOrder::kHeuristic;
}
if (this->split_k_slices > 1 && this->batch_count > 1) {
// At least one of these must be one
return Status::kErrorInvalidProblem;
@ -217,24 +217,24 @@ Status GemmOperationProfiler::GemmProblem::parse(
}
if (!arg_as_scalar(
this->alpha,
operation_desc.element_epilogue,
"alpha",
problem_space,
this->alpha,
operation_desc.element_epilogue,
"alpha",
problem_space,
problem)) {
if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) {
return Status::kErrorInternal;
}
}
if (!arg_as_scalar(
this->beta,
operation_desc.element_epilogue,
"beta",
problem_space,
this->beta,
operation_desc.element_epilogue,
"beta",
problem_space,
problem)) {
if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) {
return Status::kErrorInternal;
}
@ -327,7 +327,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "split_k_slices", problem_space, split_k_slices);
set_argument(result, "batch_count", problem_space, batch_count);
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
set_argument(result, "alpha", problem_space,
library::lexical_cast(alpha, operation_desc.element_epilogue));
@ -339,14 +339,14 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
/// Extracts the problem dimensions
Status GemmOperationProfiler::initialize_configuration(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
library::GemmDescription const &operation_desc =
library::GemmDescription const &operation_desc =
static_cast<library::GemmDescription const &>(operation->description());
if (operation_desc.gemm_kind != library::GemmKind::kUniversal) {
@ -383,7 +383,6 @@ Status GemmOperationProfiler::initialize_configuration(
gemm_workspace_.arguments.beta = problem_.beta.data();
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
gemm_workspace_.arguments.raster_order = problem_.raster_order;
// initialize reduction operation for parallel splitKMode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
if (!initialize_reduction_configuration_(operation, problem)) {
@ -392,14 +391,14 @@ Status GemmOperationProfiler::initialize_configuration(
}
initialize_result_(this->model_result_, options, operation_desc, problem_space);
return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
}
/// Initializes the performance result
void GemmOperationProfiler::initialize_result_(
PerformanceResult &result,
Options const &options,
Options const &options,
library::GemmDescription const &operation_desc,
ProblemSpace const &problem_space) {
@ -451,7 +450,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_(
);
auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key);
if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) {
return false;
}
@ -465,7 +464,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_(
/// Initializes workspace
Status GemmOperationProfiler::initialize_workspace(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
@ -480,14 +479,14 @@ Status GemmOperationProfiler::initialize_workspace(
}
}
library::GemmDescription const &operation_desc =
library::GemmDescription const &operation_desc =
static_cast<library::GemmDescription const &>(operation->description());
// Compute the number of copies of the problem to avoid L2 camping.
if (!options.profiling.workspace_count) {
int64_t bytes = problem_.bytes(operation_desc);
if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) {
gemm_workspace_.problem_count =
gemm_workspace_.problem_count =
1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes);
}
else {
@ -629,7 +628,7 @@ Status GemmOperationProfiler::initialize_workspace(
/// Verifies CUTLASS against references
bool GemmOperationProfiler::verify_cutlass(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
@ -685,7 +684,7 @@ bool GemmOperationProfiler::verify_cutlass(
}
results_.back().status = underlying_operation->run(
&gemm_workspace_.arguments,
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());
@ -748,8 +747,8 @@ bool GemmOperationProfiler::verify_cutlass(
#endif // #if CUTLASS_ENABLE_CUBLAS
bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem);
// Update disposition to worst case verification outcome among all
// Update disposition to worst case verification outcome among all
// verification providers which are supported
bool is_any_verification_run_passed = false;
for (auto &m : results_.back().verification_map) {
@ -788,7 +787,7 @@ bool GemmOperationProfiler::verify_cutlass(
/// Verifies CUTLASS against references
bool GemmOperationProfiler::verify_with_cublas_(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
@ -798,13 +797,13 @@ bool GemmOperationProfiler::verify_with_cublas_(
#if CUTLASS_ENABLE_CUBLAS
library::GemmDescription const &gemm_desc =
library::GemmDescription const &gemm_desc =
static_cast<library::GemmDescription const &>(operation->description());
//
// Construct cuBLAS operators
//
CublasCreate handle;
cublasStatus_t status = handle.get_cublas_create_status();
@ -817,8 +816,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
std::vector<cublasGemmAlgo_t> algorithms;
detail::select_cublas_algorithms(
algorithms,
options,
algorithms,
options,
gemm_desc);
if (algorithms.empty()) {
@ -849,8 +848,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
gemm_workspace_.arguments.beta = problem_.beta.data();
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
detail::cublasGemmExDispatcher gemm_op(
gemm_desc,
detail::cublasGemmExDispatcher gemm_op(
gemm_desc,
gemm_workspace_.configuration,
gemm_workspace_.arguments,
algorithms.front()
@ -884,7 +883,7 @@ bool GemmOperationProfiler::verify_with_cublas_(
);
// Save workspace if incorrect
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) {
save_workspace(
@ -909,14 +908,14 @@ bool GemmOperationProfiler::verify_with_cublas_(
/// Verifies CUTLASS against host and device references
bool GemmOperationProfiler::verify_with_reference_(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
library::GemmDescription const &gemm_desc =
library::GemmDescription const &gemm_desc =
static_cast<library::GemmDescription const &>(operation->description());
//
@ -1016,7 +1015,7 @@ bool GemmOperationProfiler::verify_with_reference_(
results_.back().status = status;
if (provider == library::Provider::kReferenceHost) {
gemm_workspace_.Reference->copy_from_host(ptr_D);
gemm_workspace_.Reference->copy_from_host(ptr_D);
}
//
@ -1031,7 +1030,7 @@ bool GemmOperationProfiler::verify_with_reference_(
);
// Save workspace if incorrect
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
results_.back().verification_map[provider] == Disposition::kIncorrect) {
save_workspace(
@ -1050,7 +1049,7 @@ bool GemmOperationProfiler::verify_with_reference_(
/// Measures performance results
bool GemmOperationProfiler::profile(
Options const &options,
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
@ -1131,7 +1130,7 @@ Status GemmOperationProfiler::profile_cutlass_(
Status status;
for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) {
int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count;
gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx);
@ -1184,7 +1183,7 @@ Status GemmOperationProfiler::profile_cutlass_(
int iteration = 0;
for (; iteration < Iterations; ++iteration) {
// Iterate over copies of the problem in memory
int workspace_idx = options.profiling.warmup_iterations + iteration;
int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count;