add maximum support (#1833)
This commit is contained in:
@ -70,6 +70,8 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
||||
ast.Sub: FunctionalOp.Minus,
|
||||
ast.Mult: FunctionalOp.Multiplies,
|
||||
ast.Div: FunctionalOp.Divides,
|
||||
"maximum": FunctionalOp.Maximum,
|
||||
"minimum": FunctionalOp.Minimum,
|
||||
"relu": relu.binding_type,
|
||||
"multiply_add": FunctionalOp.MultiplyAdd,
|
||||
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
||||
|
||||
@ -49,5 +49,7 @@ from cutlass.epilogue.evt_ops import (
|
||||
multiply_add,
|
||||
sum,
|
||||
permute,
|
||||
reshape
|
||||
reshape,
|
||||
maximum,
|
||||
minimum,
|
||||
)
|
||||
|
||||
@ -59,6 +59,17 @@ def max(x, dim):
|
||||
elif is_torch_tensor(x):
|
||||
return torch.amax(x, dim)
|
||||
|
||||
def maximum(x, y):
|
||||
if is_numpy_tensor(x):
|
||||
return np.maximum(x, y)
|
||||
elif is_torch_tensor(x):
|
||||
return torch.maximum(x, torch.tensor(y))
|
||||
|
||||
def minimum(x, y):
|
||||
if is_numpy_tensor(x):
|
||||
return np.minimum(x, y)
|
||||
elif is_torch_tensor(x):
|
||||
return torch.minimum(x, torch.tensor(y))
|
||||
|
||||
##############################################################################
|
||||
# Layout manipulate nodes
|
||||
|
||||
Reference in New Issue
Block a user