Fix split_k_mode and add a reduction kernel for f16 input/accumulator/output (#896)
This commit is contained in:
@ -42,6 +42,7 @@ namespace library {
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// CUTLASS Reduction Instances //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Forward declarations of the reduction initializers invoked by
// initialize_all_reduction_op(). Each takes the Manifest by reference; the
// three-part suffix names the element types
// ([ElementWorkspace]_[ElementAccumulator]_[ElementOutput]).
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest);
|
||||
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
|
||||
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
|
||||
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
|
||||
@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
||||
//
|
||||
void initialize_all_reduction_op(Manifest &manifest) {
|
||||
|
||||
initialize_reduce_add_linear_combination_f16_f16_f16(manifest);
|
||||
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
|
||||
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
|
||||
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);
|
||||
|
||||
@ -43,6 +43,40 @@ namespace library {
|
||||
|
||||
// Naming convention: initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]
|
||||
|
||||
/// Appends to `manifest` a ReductionOperation named
/// "reduce_add_linear_combination_f16_f16_f16".
///
/// The operation is a device-level ReduceSplitK built from a ReduceAdd
/// reduction and a LinearCombination epilogue in which the workspace,
/// accumulator, output and compute element types are all cutlass::half_t.
///
/// @param manifest  Manifest that receives the newly allocated operation.
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) {

  // Element types: every stage of this instance uses half precision.
  using ElementWorkspace   = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementOutput      = cutlass::half_t;
  using ElementCompute     = cutlass::half_t;

  // Number of ElementWorkspace values that fit in 128 bits; used as the
  // element count of the epilogue functor.
  constexpr int kElementsPer128Bits =
      128 / cutlass::sizeof_bits<ElementWorkspace>::value;

  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      kElementsPer128Bits,
      ElementAccumulator,
      ElementCompute>;

  using ReductionOp = cutlass::reduction::thread::ReduceAdd<
      ElementAccumulator,
      typename EpilogueOutputOp::ElementAccumulator,
      EpilogueOutputOp::kCount>;

  // Kernel-level reduction with a 4 x (32 * kCount) shape.
  using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
      cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
      EpilogueOutputOp,
      ReductionOp>;

  // Device-level wrapper around the kernel above.
  using DeviceReduction = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;

  manifest.append(
      new ReductionOperation<DeviceReduction>(
          "reduce_add_linear_combination_f16_f16_f16"));
}
|
||||
|
||||
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
|
||||
|
||||
using ElementWorkspace = float;
|
||||
|
||||
Reference in New Issue
Block a user