v4.1 release
This commit is contained in:
@ -73,6 +73,7 @@ class EVTFrontendBase:
|
||||
self.dag_ir = DAGIR(self.cc, self.element_compute)
|
||||
self.compute_cnt = 0
|
||||
self.layout_cnt = 0
|
||||
self.imm_cnt = 0
|
||||
|
||||
self.pass_manager = EVTPassManager(
|
||||
self.dag_ir,
|
||||
@ -107,6 +108,13 @@ class EVTFrontendBase:
|
||||
# Parse the input
|
||||
self.parse(*args, **kwargs)
|
||||
|
||||
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
|
||||
if (self.cc >= 90):
|
||||
if (self.dag_ir.out_degree("D") != 0):
|
||||
raise RuntimeError(
|
||||
f"On SM90 or higher, D is expected to be a output node with 0 users to "
|
||||
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
|
||||
|
||||
# Run the passes
|
||||
self.pass_manager()
|
||||
# Set the epilogue type
|
||||
@ -187,7 +195,8 @@ class EVTFrontendBase:
|
||||
except:
|
||||
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
|
||||
|
||||
name = f"imm_{value}".replace('.', '_')
|
||||
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
|
||||
self.imm_cnt += 1
|
||||
load_node = LoadNode(name)
|
||||
load_node.tensor = {"tensor": value, "is_constant": True}
|
||||
self.add_node(load_node)
|
||||
|
||||
@ -42,7 +42,7 @@ from cutlass_library import DataType
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
|
||||
from cutlass.backend.epilogue import relu
|
||||
from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
||||
from cutlass.backend.library import FunctionalOp
|
||||
|
||||
|
||||
@ -72,10 +72,17 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
||||
ast.Div: FunctionalOp.Divides,
|
||||
"maximum": FunctionalOp.Maximum,
|
||||
"minimum": FunctionalOp.Minimum,
|
||||
"identity": identity.binding_type,
|
||||
"relu": relu.binding_type,
|
||||
"tanh": tanh.binding_type,
|
||||
"sigmoid": sigmoid.binding_type,
|
||||
"silu": silu.binding_type,
|
||||
"hardswish": hardswish.binding_type,
|
||||
"gelu": gelu.binding_type,
|
||||
"multiply_add": FunctionalOp.MultiplyAdd,
|
||||
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
||||
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum)
|
||||
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
|
||||
"exp": FunctionalOp.Exp
|
||||
}
|
||||
return mapping[op]
|
||||
|
||||
|
||||
@ -38,7 +38,9 @@ import networkx as nx
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass.backend.evt.ir.compute_nodes import ComputeNode
|
||||
from cutlass.backend.evt.ir.node import NodeBase
|
||||
from cutlass.backend.library import ActivationOp
|
||||
from cutlass.backend.utils import device_cc
|
||||
|
||||
|
||||
@ -59,6 +61,8 @@ class DAGIR:
|
||||
|
||||
self.cc = cc
|
||||
|
||||
self.identity_counter = 0
|
||||
|
||||
#
|
||||
# IR manipulator
|
||||
#
|
||||
@ -79,7 +83,21 @@ class DAGIR:
|
||||
raise SyntaxError(f"Variable '{src}' is undefined.")
|
||||
if not self.has_node(dst):
|
||||
raise SyntaxError(f"Variable '{dst}' is undefined.")
|
||||
self._graph.add_edge(src, dst, weight=weight)
|
||||
|
||||
if self._graph.has_edge(src, dst):
|
||||
# The DiGraph doesn't support multiple edges between two nodes
|
||||
# We insert an identity node in such case as a workaround
|
||||
identity_name = f"autogen_identity_{self.identity_counter}"
|
||||
self.identity_counter += 1
|
||||
compute_node = ComputeNode(
|
||||
name=identity_name, fn=ActivationOp.Identity,
|
||||
element_output=self.element_compute,
|
||||
element_compute=self.element_compute)
|
||||
self.add_node(compute_node)
|
||||
self.add_edge(src, identity_name, 0)
|
||||
self.add_edge(identity_name, dst, weight)
|
||||
else:
|
||||
self._graph.add_edge(src, dst, weight=weight)
|
||||
|
||||
def remove_node(self, node: str):
|
||||
"""
|
||||
|
||||
@ -51,15 +51,19 @@ class Tensor:
|
||||
"""
|
||||
The tensor abstracts the data type
|
||||
"""
|
||||
def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None:
|
||||
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
|
||||
if element is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both element and tensor")
|
||||
elif shape is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both shape and tensor")
|
||||
elif layout_tag is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both layout_tag and tensor")
|
||||
elif (element is None or layout_tag is None or shape is None) and (tensor is None) :
|
||||
raise Exception(f"Must specify one of (element, shape, layout) or (tensor)")
|
||||
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
|
||||
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
|
||||
elif stride is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both stride and tensor")
|
||||
elif stride is not None and layout_tag is not None:
|
||||
raise Exception(f"Must not specify layout_tag when stride is provided")
|
||||
|
||||
if isinstance(tensor, Tensor):
|
||||
# Directly copy all the attributes
|
||||
@ -70,10 +74,13 @@ class Tensor:
|
||||
else:
|
||||
self.element, layout_tag = get_datatype_and_layout(tensor)
|
||||
shape = get_tensor_shape(tensor)
|
||||
if layout_tag == LayoutType.RowMajor:
|
||||
self.layout = Layout(shape[::-1])
|
||||
elif layout_tag == LayoutType.ColumnMajor:
|
||||
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
|
||||
if stride is not None:
|
||||
self.layout = Layout(shape[::-1], stride[::-1])
|
||||
else:
|
||||
if layout_tag == LayoutType.RowMajor:
|
||||
self.layout = Layout(shape[::-1])
|
||||
elif layout_tag == LayoutType.ColumnMajor:
|
||||
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
|
||||
self.layout = canonicalization(self.layout)
|
||||
|
||||
self.is_constant = is_constant
|
||||
|
||||
@ -77,11 +77,12 @@ class PassDAG2Tree(EVTPassBase):
|
||||
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
|
||||
# get the common reachable objects
|
||||
common_items = set.intersection(*reachable_nodes)
|
||||
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
|
||||
|
||||
lca = None
|
||||
# If common ancestor exists, find the lowest one
|
||||
if len(common_items) > 0:
|
||||
topo_order = self.dag_ir.nodes_topological_order()
|
||||
lca = None
|
||||
topo_idx = -1
|
||||
for item in common_items:
|
||||
if lca is None:
|
||||
@ -91,53 +92,74 @@ class PassDAG2Tree(EVTPassBase):
|
||||
if topo_idx > topo_order.index(item):
|
||||
lca = item
|
||||
topo_idx = topo_order.index(item)
|
||||
# The lca is the output node of the DAG node
|
||||
# Get the nodes to be fused
|
||||
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
|
||||
node_to_fuse.add(lca)
|
||||
# Get all the input nodes
|
||||
all_input_nodes = []
|
||||
all_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
|
||||
all_output_nodes.append(set(self.dag_ir.get_users(node)))
|
||||
all_input_nodes = set.union(*all_input_nodes)
|
||||
all_output_nodes = set.union(*all_output_nodes)
|
||||
|
||||
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
|
||||
|
||||
# Create the subgraph
|
||||
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
||||
subgraph = DAGIR(self.dag_ir.cc)
|
||||
for node in subgraph_.nodes:
|
||||
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
||||
if node not in node_to_fuse:
|
||||
meta.disabled = True
|
||||
subgraph.add_node(meta)
|
||||
for edge in subgraph_.edges:
|
||||
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
|
||||
|
||||
|
||||
# Create the fused node
|
||||
dag_node = TopoVisitorNode(
|
||||
name=f"dag_{lca}", subgraph=subgraph,
|
||||
output_node=self.dag_ir.get_node_meta(lca))
|
||||
self.dag_ir.add_node(dag_node)
|
||||
|
||||
# Add input edges
|
||||
for idx, node in enumerate(all_input_nodes):
|
||||
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
|
||||
|
||||
# Replace all uses with DAG node (only 1 output node)
|
||||
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
|
||||
|
||||
# Remove all fused nodes
|
||||
node_to_fuse.remove(lca)
|
||||
for node in node_to_fuse:
|
||||
self.dag_ir.remove_node(node)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.")
|
||||
# there is no common ancestor for all the parents, we pack all the reachable
|
||||
# nodes into a single DAG node as a fallback. The lca should be the input node of
|
||||
# one of the output nodes with out_degree = 0
|
||||
potential_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
if self.dag_ir.out_degree(node) == 0:
|
||||
potential_output_nodes.append(node)
|
||||
if len(potential_output_nodes) == 0:
|
||||
raise RuntimeError(f"No output node with out degree = 0 found.")
|
||||
|
||||
output_node = None
|
||||
if (self.dag_ir.cc >= 90):
|
||||
# For SM90, the lca should be the input node of D
|
||||
if (not self.dag_ir.has_node("D")):
|
||||
raise RuntimeError(f"D is not a node in the DAG IR.")
|
||||
output_node = "D"
|
||||
else:
|
||||
output_node = potential_output_nodes[0]
|
||||
|
||||
if (output_node is None):
|
||||
raise RuntimeError(f"No output node found.")
|
||||
lca = self.dag_ir.get_all_inputs(output_node)[0]
|
||||
node_to_fuse.remove(output_node)
|
||||
|
||||
# The lca is the output node of the DAG node
|
||||
# Get the nodes to be fused
|
||||
node_to_fuse.add(lca)
|
||||
# Get all the input nodes
|
||||
all_input_nodes = []
|
||||
all_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
|
||||
all_output_nodes.append(set(self.dag_ir.get_users(node)))
|
||||
all_input_nodes = set.union(*all_input_nodes)
|
||||
all_output_nodes = set.union(*all_output_nodes)
|
||||
|
||||
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
|
||||
|
||||
# Create the subgraph
|
||||
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
||||
subgraph = DAGIR(self.dag_ir.cc)
|
||||
for node in subgraph_.nodes:
|
||||
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
||||
if node not in node_to_fuse:
|
||||
meta.disabled = True
|
||||
subgraph.add_node(meta)
|
||||
for edge in subgraph_.edges:
|
||||
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
|
||||
|
||||
|
||||
# Create the fused node
|
||||
dag_node = TopoVisitorNode(
|
||||
name=f"dag_{lca}", subgraph=subgraph,
|
||||
output_node=self.dag_ir.get_node_meta(lca))
|
||||
self.dag_ir.add_node(dag_node)
|
||||
|
||||
# Add input edges
|
||||
for idx, node in enumerate(all_input_nodes):
|
||||
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
|
||||
|
||||
# Replace all uses with DAG node (only 1 output node)
|
||||
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
|
||||
|
||||
# Remove all fused nodes
|
||||
node_to_fuse.remove(lca)
|
||||
for node in node_to_fuse:
|
||||
self.dag_ir.remove_node(node)
|
||||
|
||||
def ensures(self) -> None:
|
||||
# Ensure that after the pass, the resulting DAG becomes a tree
|
||||
|
||||
@ -118,6 +118,7 @@ class FunctionalOp(enum.Enum):
|
||||
Multiplies = enum_auto()
|
||||
MultiplyAdd = enum_auto()
|
||||
Plus = enum_auto()
|
||||
Exp = enum_auto()
|
||||
|
||||
|
||||
FunctionalOpTag = {
|
||||
@ -130,6 +131,7 @@ FunctionalOpTag = {
|
||||
FunctionalOp.Multiplies: "cutlass::multiplies",
|
||||
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
|
||||
FunctionalOp.Plus: "cutlass::plus",
|
||||
FunctionalOp.Exp: "cutlass::fast_exp_op",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -52,4 +52,5 @@ from cutlass.epilogue.evt_ops import (
|
||||
reshape,
|
||||
maximum,
|
||||
minimum,
|
||||
exp
|
||||
)
|
||||
|
||||
@ -73,6 +73,12 @@ def minimum(x, y):
|
||||
elif is_torch_tensor(x):
|
||||
return torch.minimum(x, torch.tensor(y))
|
||||
|
||||
def exp(x):
|
||||
if is_numpy_tensor(x):
|
||||
return np.exp(x)
|
||||
elif is_torch_tensor(x):
|
||||
return torch.exp(x)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Layout manipulate nodes
|
||||
|
||||
Reference in New Issue
Block a user