Support ComputeFn where output type differs from input type (#1771)

This is useful for e.g. function taking in 2 float inputs and turn them to complex
This commit is contained in:
Tri Dao
2024-09-05 20:25:03 -07:00
committed by GitHub
parent 82f5075946
commit 323c8170bf

View File

@ -181,14 +181,20 @@ public:
},
[&] (auto&&... cvt_frg_inputs) {
using ComputeOutput = ComputeFn<Array<ElementCompute, FragmentSize>>;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize, RoundStyle>;
ComputeOutput compute_output{};
ConvertOutput convert_output{};
if constexpr (cute::is_same_v<Arguments, EmptyArguments>) {
using ElementComputeOutput =
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs...))>::Element;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
ConvertOutput convert_output{};
return convert_output(compute_output(cvt_frg_inputs...));
}
else {
using ElementComputeOutput =
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs..., params))>::Element;
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
ConvertOutput convert_output{};
return convert_output(compute_output(cvt_frg_inputs..., params));
}
}