CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@ -67,11 +67,12 @@ public:
|
||||
/// Problem structure obtained from problem space
|
||||
struct GemmProblem {
|
||||
|
||||
cutlass::library::GemmUniversalMode mode;
|
||||
cutlass::library::GemmUniversalMode mode;
|
||||
|
||||
int64_t m;
|
||||
int64_t n;
|
||||
int64_t k;
|
||||
|
||||
int64_t lda;
|
||||
int64_t ldb;
|
||||
int64_t ldc;
|
||||
@ -93,9 +94,16 @@ public:
|
||||
// Methods
|
||||
//
|
||||
|
||||
GemmProblem():
|
||||
GemmProblem():
|
||||
mode(library::GemmUniversalMode::kGemm),
|
||||
m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1),
|
||||
m(16),
|
||||
n(16),
|
||||
k(16),
|
||||
lda(0),
|
||||
ldb(0),
|
||||
ldc(0),
|
||||
split_k_slices(1),
|
||||
batch_count(1),
|
||||
raster_order(cutlass::library::RasterOrder::kHeuristic){ }
|
||||
|
||||
/// Parses the problem
|
||||
@ -117,7 +125,7 @@ public:
|
||||
ProblemSpace const &problem_space);
|
||||
};
|
||||
|
||||
/// Workspace used
|
||||
/// Workspace used
|
||||
struct GemmWorkspace {
|
||||
|
||||
DeviceAllocation *A;
|
||||
@ -150,7 +158,7 @@ public:
|
||||
// Methods
|
||||
//
|
||||
|
||||
GemmWorkspace():
|
||||
GemmWorkspace():
|
||||
A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { }
|
||||
};
|
||||
|
||||
@ -163,7 +171,7 @@ protected:
|
||||
/// GEMM problem obtained from problem space
|
||||
GemmProblem problem_;
|
||||
|
||||
/// Device memory allocations
|
||||
/// Device memory allocations
|
||||
GemmWorkspace gemm_workspace_;
|
||||
|
||||
/// CUTLASS parallel reduction operation to follow this* gemm operation
|
||||
@ -190,8 +198,8 @@ public:
|
||||
|
||||
/// Extracts the problem dimensions
|
||||
virtual Status initialize_configuration(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
@ -199,8 +207,8 @@ public:
|
||||
|
||||
/// Initializes workspace
|
||||
virtual Status initialize_workspace(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
@ -208,7 +216,7 @@ public:
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
virtual bool verify_cutlass(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
@ -217,8 +225,8 @@ public:
|
||||
|
||||
/// Measures performance results
|
||||
virtual bool profile(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
@ -229,13 +237,13 @@ protected:
|
||||
/// Initializes the performance result
|
||||
void initialize_result_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
library::GemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space);
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
bool verify_with_cublas_(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
@ -244,7 +252,7 @@ protected:
|
||||
|
||||
/// Verifies CUTLASS against host and device references
|
||||
bool verify_with_reference_(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
|
||||
@ -1493,7 +1493,6 @@ bool DeviceAllocation::block_compare_equal(
|
||||
reinterpret_cast<float_e5m2_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e5m2_t const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
return reference::device::BlockCompareEqual<half_t>(
|
||||
reinterpret_cast<half_t const *>(ptr_A),
|
||||
@ -1633,7 +1632,7 @@ bool DeviceAllocation::block_compare_equal(
|
||||
capacity);
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported numeric type");
|
||||
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1662,7 +1661,6 @@ bool DeviceAllocation::block_compare_relatively_equal(
|
||||
capacity,
|
||||
static_cast<float_e5m2_t>(epsilon),
|
||||
static_cast<float_e5m2_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
return reference::device::BlockCompareRelativelyEqual<half_t>(
|
||||
reinterpret_cast<half_t const *>(ptr_A),
|
||||
@ -2089,8 +2087,12 @@ void DeviceAllocation::write_tensor_csv(
|
||||
write_tensor_csv_static_type<cutlass::complex<double> >(out, *this);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kVoid:
|
||||
// Do not dump anything, as it is an empty tensor.
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported numeric type");
|
||||
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()) ) ;
|
||||
}
|
||||
}
|
||||
|
||||
@ -2168,7 +2170,6 @@ void DeviceAllocation::fill(double val = 0.0) {
|
||||
case library::NumericTypeID::kFE5M2:
|
||||
tensor_fill<float_e5m2_t>(*this, static_cast<float_e5m2_t>(val));
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
tensor_fill<half_t>(*this, static_cast<half_t>(val));
|
||||
break;
|
||||
@ -2254,7 +2255,7 @@ void DeviceAllocation::fill(double val = 0.0) {
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported numeric type");
|
||||
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ namespace profiler {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Ctor
|
||||
GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
OperationProfiler(
|
||||
options,
|
||||
library::OperationKind::kGemm,
|
||||
@ -73,7 +73,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
|
||||
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
|
||||
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
|
||||
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"},
|
||||
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"},
|
||||
},
|
||||
{ library::Provider::kCUBLAS}
|
||||
) {
|
||||
@ -119,7 +119,7 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
|
||||
|
||||
<< "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n"
|
||||
|
||||
|
||||
<< "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm \\ \n"
|
||||
<< " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n"
|
||||
@ -150,9 +150,9 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
library::GemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
|
||||
this->mode = library::GemmUniversalMode::kGemm;
|
||||
|
||||
|
||||
if (!arg_as_int(this->m, "m", problem_space, problem)) {
|
||||
// default value
|
||||
this->m = 1024;
|
||||
@ -162,17 +162,17 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
// default value
|
||||
this->n = 1024;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_int(this->k, "k", problem_space, problem)) {
|
||||
// default value
|
||||
this->k = 1024;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
|
||||
// default value
|
||||
this->split_k_mode = library::SplitKMode::kSerial;
|
||||
}
|
||||
|
||||
|
||||
this->mode = library::GemmUniversalMode::kGemm;
|
||||
if (this->split_k_mode == library::SplitKMode::kParallel) {
|
||||
this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
|
||||
@ -182,7 +182,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
// default value
|
||||
this->split_k_slices = 1;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
|
||||
// default value
|
||||
this->batch_count = 1;
|
||||
@ -194,7 +194,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
// default value
|
||||
this->raster_order = library::RasterOrder::kHeuristic;
|
||||
}
|
||||
|
||||
|
||||
if (this->split_k_slices > 1 && this->batch_count > 1) {
|
||||
// At least one of these must be one
|
||||
return Status::kErrorInvalidProblem;
|
||||
@ -217,24 +217,24 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
}
|
||||
|
||||
if (!arg_as_scalar(
|
||||
this->alpha,
|
||||
operation_desc.element_epilogue,
|
||||
"alpha",
|
||||
problem_space,
|
||||
this->alpha,
|
||||
operation_desc.element_epilogue,
|
||||
"alpha",
|
||||
problem_space,
|
||||
problem)) {
|
||||
|
||||
if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_scalar(
|
||||
this->beta,
|
||||
operation_desc.element_epilogue,
|
||||
"beta",
|
||||
problem_space,
|
||||
this->beta,
|
||||
operation_desc.element_epilogue,
|
||||
"beta",
|
||||
problem_space,
|
||||
problem)) {
|
||||
|
||||
|
||||
if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
@ -327,7 +327,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||
set_argument(result, "split_k_slices", problem_space, split_k_slices);
|
||||
set_argument(result, "batch_count", problem_space, batch_count);
|
||||
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
||||
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
||||
set_argument(result, "alpha", problem_space,
|
||||
library::lexical_cast(alpha, operation_desc.element_epilogue));
|
||||
|
||||
@ -339,14 +339,14 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
|
||||
/// Extracts the problem dimensions
|
||||
Status GemmOperationProfiler::initialize_configuration(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
library::GemmDescription const &operation_desc =
|
||||
library::GemmDescription const &operation_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
if (operation_desc.gemm_kind != library::GemmKind::kUniversal) {
|
||||
@ -383,7 +383,6 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
gemm_workspace_.arguments.beta = problem_.beta.data();
|
||||
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
gemm_workspace_.arguments.raster_order = problem_.raster_order;
|
||||
|
||||
// initialize reduction operation for parallel splitKMode
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
if (!initialize_reduction_configuration_(operation, problem)) {
|
||||
@ -392,14 +391,14 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
}
|
||||
|
||||
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
||||
|
||||
|
||||
return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
|
||||
}
|
||||
|
||||
/// Initializes the performance result
|
||||
void GemmOperationProfiler::initialize_result_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
library::GemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space) {
|
||||
|
||||
@ -451,7 +450,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_(
|
||||
);
|
||||
|
||||
auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key);
|
||||
|
||||
|
||||
if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) {
|
||||
return false;
|
||||
}
|
||||
@ -465,7 +464,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_(
|
||||
|
||||
/// Initializes workspace
|
||||
Status GemmOperationProfiler::initialize_workspace(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
@ -480,14 +479,14 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
}
|
||||
}
|
||||
|
||||
library::GemmDescription const &operation_desc =
|
||||
library::GemmDescription const &operation_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
// Compute the number of copies of the problem to avoid L2 camping.
|
||||
if (!options.profiling.workspace_count) {
|
||||
int64_t bytes = problem_.bytes(operation_desc);
|
||||
if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) {
|
||||
gemm_workspace_.problem_count =
|
||||
gemm_workspace_.problem_count =
|
||||
1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes);
|
||||
}
|
||||
else {
|
||||
@ -629,7 +628,7 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
bool GemmOperationProfiler::verify_cutlass(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
@ -685,7 +684,7 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
}
|
||||
|
||||
results_.back().status = underlying_operation->run(
|
||||
&gemm_workspace_.arguments,
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
gemm_workspace_.device_workspace.data());
|
||||
|
||||
@ -748,8 +747,8 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
#endif // #if CUTLASS_ENABLE_CUBLAS
|
||||
|
||||
bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem);
|
||||
|
||||
// Update disposition to worst case verification outcome among all
|
||||
|
||||
// Update disposition to worst case verification outcome among all
|
||||
// verification providers which are supported
|
||||
bool is_any_verification_run_passed = false;
|
||||
for (auto &m : results_.back().verification_map) {
|
||||
@ -788,7 +787,7 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
bool GemmOperationProfiler::verify_with_cublas_(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
@ -798,13 +797,13 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
|
||||
#if CUTLASS_ENABLE_CUBLAS
|
||||
|
||||
library::GemmDescription const &gemm_desc =
|
||||
library::GemmDescription const &gemm_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
//
|
||||
// Construct cuBLAS operators
|
||||
//
|
||||
|
||||
|
||||
CublasCreate handle;
|
||||
cublasStatus_t status = handle.get_cublas_create_status();
|
||||
|
||||
@ -817,8 +816,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
std::vector<cublasGemmAlgo_t> algorithms;
|
||||
|
||||
detail::select_cublas_algorithms(
|
||||
algorithms,
|
||||
options,
|
||||
algorithms,
|
||||
options,
|
||||
gemm_desc);
|
||||
|
||||
if (algorithms.empty()) {
|
||||
@ -849,8 +848,8 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
gemm_workspace_.arguments.beta = problem_.beta.data();
|
||||
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
|
||||
detail::cublasGemmExDispatcher gemm_op(
|
||||
gemm_desc,
|
||||
detail::cublasGemmExDispatcher gemm_op(
|
||||
gemm_desc,
|
||||
gemm_workspace_.configuration,
|
||||
gemm_workspace_.arguments,
|
||||
algorithms.front()
|
||||
@ -884,7 +883,7 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
);
|
||||
|
||||
// Save workspace if incorrect
|
||||
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
|
||||
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
|
||||
results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) {
|
||||
|
||||
save_workspace(
|
||||
@ -909,14 +908,14 @@ bool GemmOperationProfiler::verify_with_cublas_(
|
||||
|
||||
/// Verifies CUTLASS against host and device references
|
||||
bool GemmOperationProfiler::verify_with_reference_(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
library::GemmDescription const &gemm_desc =
|
||||
library::GemmDescription const &gemm_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
//
|
||||
@ -1016,7 +1015,7 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
results_.back().status = status;
|
||||
|
||||
if (provider == library::Provider::kReferenceHost) {
|
||||
gemm_workspace_.Reference->copy_from_host(ptr_D);
|
||||
gemm_workspace_.Reference->copy_from_host(ptr_D);
|
||||
}
|
||||
|
||||
//
|
||||
@ -1031,7 +1030,7 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
);
|
||||
|
||||
// Save workspace if incorrect
|
||||
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
|
||||
if (options.verification.save_workspace == SaveWorkspace::kIncorrect &&
|
||||
results_.back().verification_map[provider] == Disposition::kIncorrect) {
|
||||
|
||||
save_workspace(
|
||||
@ -1050,7 +1049,7 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
|
||||
/// Measures performance results
|
||||
bool GemmOperationProfiler::profile(
|
||||
Options const &options,
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
@ -1131,7 +1130,7 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
Status status;
|
||||
|
||||
for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) {
|
||||
|
||||
|
||||
int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count;
|
||||
|
||||
gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx);
|
||||
@ -1184,7 +1183,7 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
|
||||
int iteration = 0;
|
||||
for (; iteration < Iterations; ++iteration) {
|
||||
|
||||
|
||||
// Iterate over copies of the problem in memory
|
||||
int workspace_idx = options.profiling.warmup_iterations + iteration;
|
||||
int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count;
|
||||
|
||||
Reference in New Issue
Block a user