add maximum support (#1833)
This commit is contained in:
@ -95,6 +95,29 @@ class TestEVTCompute(EVTTestCaseBase):
|
||||
result_keys = ["D"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
def test_func_call2(self):
|
||||
"""
|
||||
Test Function call
|
||||
"""
|
||||
|
||||
def evt_func_call2(accum, C, alpha, beta):
|
||||
D = maximum(alpha * accum + beta * C, 0.0)
|
||||
return D
|
||||
|
||||
for m, n, k, l in self.get_problem_sizes(8):
|
||||
example_inputs = {
|
||||
"accum": self.fake_tensor(self.element, (l, m, n)),
|
||||
"C": self.fake_tensor(self.element, (l, m, n)),
|
||||
"alpha": 1.5,
|
||||
"beta": 0.5,
|
||||
"D": self.fake_tensor(self.element, (l, m, n))
|
||||
}
|
||||
|
||||
launcher = EVTTestBed(self.element, evt_func_call2, example_inputs)
|
||||
input_keys = ["C", "alpha", "beta"]
|
||||
result_keys = ["D"]
|
||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user