fix split_k_mode and add reduction kernel for f16 input/accum/output (#896)

This commit is contained in:
Manish Gupta
2023-03-30 12:31:08 -07:00
committed by GitHub
parent bc36122c3f
commit 660a05f581
4 changed files with 41 additions and 5 deletions

View File

@ -42,6 +42,7 @@ namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////
// CUTLASS Reduction Instances //
///////////////////////////////////////////////////////////////////////////////////////////////
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest);
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
//
void initialize_all_reduction_op(Manifest &manifest) {
initialize_reduce_add_linear_combination_f16_f16_f16(manifest);
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);

View File

@ -43,6 +43,40 @@ namespace library {
// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) {
using ElementWorkspace = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementCompute = cutlass::half_t;
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
ElementAccumulator,
ElementCompute
>;
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
typename EpilogueOutputOp::ElementAccumulator,
EpilogueOutputOp::kCount
>;
using Operation_reduce_add_linear_combination_f16_f16_f16 = cutlass::reduction::device::ReduceSplitK<
cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>
>;
manifest.append(new ReductionOperation<
Operation_reduce_add_linear_combination_f16_f16_f16>(
"reduce_add_linear_combination_f16_f16_f16"
));
}
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
using ElementWorkspace = float;