v4.1 release

This commit is contained in:
Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -73,6 +73,7 @@ class EVTFrontendBase:
self.dag_ir = DAGIR(self.cc, self.element_compute)
self.compute_cnt = 0
self.layout_cnt = 0
self.imm_cnt = 0
self.pass_manager = EVTPassManager(
self.dag_ir,
@ -107,6 +108,13 @@ class EVTFrontendBase:
# Parse the input
self.parse(*args, **kwargs)
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
if (self.cc >= 90):
if (self.dag_ir.out_degree("D") != 0):
raise RuntimeError(
f"On SM90 or higher, D is expected to be a output node with 0 users to "
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
# Run the passes
self.pass_manager()
# Set the epilogue type
@ -187,7 +195,8 @@ class EVTFrontendBase:
except:
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
name = f"imm_{value}".replace('.', '_')
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
self.imm_cnt += 1
load_node = LoadNode(name)
load_node.tensor = {"tensor": value, "is_constant": True}
self.add_node(load_node)

View File

@ -42,7 +42,7 @@ from cutlass_library import DataType
import cutlass
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass.backend.epilogue import relu
from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass.backend.library import FunctionalOp
@ -72,10 +72,17 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
ast.Div: FunctionalOp.Divides,
"maximum": FunctionalOp.Maximum,
"minimum": FunctionalOp.Minimum,
"identity": identity.binding_type,
"relu": relu.binding_type,
"tanh": tanh.binding_type,
"sigmoid": sigmoid.binding_type,
"silu": silu.binding_type,
"hardswish": hardswish.binding_type,
"gelu": gelu.binding_type,
"multiply_add": FunctionalOp.MultiplyAdd,
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum)
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
"exp": FunctionalOp.Exp
}
return mapping[op]

View File

@ -38,7 +38,9 @@ import networkx as nx
from cutlass_library import DataType
from cutlass.backend.evt.ir.compute_nodes import ComputeNode
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.library import ActivationOp
from cutlass.backend.utils import device_cc
@ -59,6 +61,8 @@ class DAGIR:
self.cc = cc
self.identity_counter = 0
#
# IR manipulator
#
@ -79,7 +83,21 @@ class DAGIR:
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
self._graph.add_edge(src, dst, weight=weight)
if self._graph.has_edge(src, dst):
# A DiGraph does not support multiple edges between the same two nodes,
# so we insert an identity node in that case as a workaround.
identity_name = f"autogen_identity_{self.identity_counter}"
self.identity_counter += 1
compute_node = ComputeNode(
name=identity_name, fn=ActivationOp.Identity,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
self.add_edge(src, identity_name, 0)
self.add_edge(identity_name, dst, weight)
else:
self._graph.add_edge(src, dst, weight=weight)
def remove_node(self, node: str):
"""

View File

@ -51,15 +51,19 @@ class Tensor:
"""
The tensor abstracts the data type
"""
def __init__(self, tensor=None, element=None, shape=None, layout_tag=None, is_constant=False) -> None:
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
if element is not None and tensor is not None:
raise Exception(f"Must not specify both element and tensor")
elif shape is not None and tensor is not None:
raise Exception(f"Must not specify both shape and tensor")
elif layout_tag is not None and tensor is not None:
raise Exception(f"Must not specify both layout_tag and tensor")
elif (element is None or layout_tag is None or shape is None) and (tensor is None) :
raise Exception(f"Must specify one of (element, shape, layout) or (tensor)")
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
elif stride is not None and tensor is not None:
raise Exception(f"Must not specify both stride and tensor")
elif stride is not None and layout_tag is not None:
raise Exception(f"Must not specify layout_tag when stride is provided")
if isinstance(tensor, Tensor):
# Directly copy all the attributes
@ -70,10 +74,13 @@ class Tensor:
else:
self.element, layout_tag = get_datatype_and_layout(tensor)
shape = get_tensor_shape(tensor)
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
if stride is not None:
self.layout = Layout(shape[::-1], stride[::-1])
else:
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
self.layout = canonicalization(self.layout)
self.is_constant = is_constant

View File

@ -77,11 +77,12 @@ class PassDAG2Tree(EVTPassBase):
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
# get the common reachable objects
common_items = set.intersection(*reachable_nodes)
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
lca = None
# If common ancestor exists, find the lowest one
if len(common_items) > 0:
topo_order = self.dag_ir.nodes_topological_order()
lca = None
topo_idx = -1
for item in common_items:
if lca is None:
@ -91,53 +92,74 @@ class PassDAG2Tree(EVTPassBase):
if topo_idx > topo_order.index(item):
lca = item
topo_idx = topo_order.index(item)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}", subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca))
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
else:
raise NotImplementedError("No LCA found. Consider SplitTreeVisitor.")
# There is no common ancestor for all the parents, so we pack all the reachable
# nodes into a single DAG node as a fallback. The lca should be the input node of
# one of the output nodes with out_degree = 0.
potential_output_nodes = []
for node in node_to_fuse:
if self.dag_ir.out_degree(node) == 0:
potential_output_nodes.append(node)
if len(potential_output_nodes) == 0:
raise RuntimeError(f"No output node with out degree = 0 found.")
output_node = None
if (self.dag_ir.cc >= 90):
# For SM90 and above (cc >= 90), the lca should be the input node of D
if (not self.dag_ir.has_node("D")):
raise RuntimeError(f"D is not a node in the DAG IR.")
output_node = "D"
else:
output_node = potential_output_nodes[0]
if (output_node is None):
raise RuntimeError(f"No output node found.")
lca = self.dag_ir.get_all_inputs(output_node)[0]
node_to_fuse.remove(output_node)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}", subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca))
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
def ensures(self) -> None:
# Ensure that after the pass, the resulting DAG becomes a tree

View File

@ -118,6 +118,7 @@ class FunctionalOp(enum.Enum):
Multiplies = enum_auto()
MultiplyAdd = enum_auto()
Plus = enum_auto()
Exp = enum_auto()
FunctionalOpTag = {
@ -130,6 +131,7 @@ FunctionalOpTag = {
FunctionalOp.Multiplies: "cutlass::multiplies",
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
FunctionalOp.Plus: "cutlass::plus",
FunctionalOp.Exp: "cutlass::fast_exp_op",
}

View File

@ -52,4 +52,5 @@ from cutlass.epilogue.evt_ops import (
reshape,
maximum,
minimum,
exp
)

View File

@ -73,6 +73,12 @@ def minimum(x, y):
elif is_torch_tensor(x):
return torch.minimum(x, torch.tensor(y))
def exp(x):
    """
    Exponential of a tensor, dispatched on the tensor's framework.

    :param x: the input tensor; a NumPy tensor is handled by ``np.exp`` and a
        torch tensor by ``torch.exp``

    :return: the result of the matching framework's ``exp`` call. When ``x`` is
        recognized by neither ``is_numpy_tensor`` nor ``is_torch_tensor``, the
        function falls through and implicitly returns ``None``.
    """
    # The NumPy check takes precedence over the torch check.
    if is_numpy_tensor(x):
        return np.exp(x)
    if is_torch_tensor(x):
        return torch.exp(x)
##############################################################################
# Layout manipulation nodes