Support ComputeFn where output type differs from input type (#1771)
This is useful for e.g. function taking in 2 float inputs and turn them to complex
This commit is contained in:
@ -181,14 +181,20 @@ public:
|
||||
},
|
||||
[&] (auto&&... cvt_frg_inputs) {
|
||||
using ComputeOutput = ComputeFn<Array<ElementCompute, FragmentSize>>;
|
||||
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize, RoundStyle>;
|
||||
ComputeOutput compute_output{};
|
||||
ConvertOutput convert_output{};
|
||||
|
||||
if constexpr (cute::is_same_v<Arguments, EmptyArguments>) {
|
||||
using ElementComputeOutput =
|
||||
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs...))>::Element;
|
||||
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
|
||||
ConvertOutput convert_output{};
|
||||
return convert_output(compute_output(cvt_frg_inputs...));
|
||||
}
|
||||
else {
|
||||
using ElementComputeOutput =
|
||||
typename cute::remove_cvref_t<decltype(compute_output(cvt_frg_inputs..., params))>::Element;
|
||||
using ConvertOutput = NumericArrayConverter<ElementOutput, ElementComputeOutput, FragmentSize, RoundStyle>;
|
||||
ConvertOutput convert_output{};
|
||||
return convert_output(compute_output(cvt_frg_inputs..., params));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user