Make cc a positional argument (#2249)

This commit is contained in:
Michael Lazos
2025-04-30 13:09:25 -07:00
committed by GitHub
parent fe75ead92e
commit b3ce7e12b7
5 changed files with 13 additions and 11 deletions

View File

@ -67,10 +67,10 @@ class EVTFrontendBase:
"reshape": reshape
}
def __init__(self, element_compute=DataType.f32, cc=None, additional_passes=[], **kwargs) -> None:
self.cc = cc if cc else device_cc()
def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None:
self.cc = cc
self.element_compute = library_type(element_compute)
self.dag_ir = DAGIR(self.element_compute, self.cc)
self.dag_ir = DAGIR(self.cc, self.element_compute)
self.compute_cnt = 0
self.layout_cnt = 0

View File

@ -47,8 +47,8 @@ from cutlass.backend.library import FunctionalOp
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
def __init__(self, element_compute=DataType.f32, **kwargs):
super().__init__(element_compute, **kwargs)
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
super().__init__(cc, element_compute, **kwargs)
# Flags
# If this state is True, visit_Constant returns values without creating imm node
self.no_imm = False

View File

@ -49,7 +49,7 @@ class DAGIR:
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
"""
def __init__(self, element_compute=DataType.f32, cc: int=None) -> None:
def __init__(self, cc, element_compute=DataType.f32) -> None:
# The EVT DAGIR is managed through the nextworkX Digraph class
self._graph = nx.DiGraph()
@ -57,7 +57,7 @@ class DAGIR:
self.reduction_names = []
self.cc = cc if cc else device_cc()
self.cc = cc
#
# IR manipulator

View File

@ -108,7 +108,7 @@ class PassDAG2Tree(EVTPassBase):
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR()
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:

View File

@ -43,7 +43,7 @@ code like the following for GEMM:
plan.activation = cutlass.epilogue.relu
"""
from cutlass.backend import epilogue
from cutlass.backend import epilogue, device_cc
gelu = epilogue.gelu
@ -146,8 +146,10 @@ def trace(fn, example_tensors, **kwargs):
"""
if callable(fn):
class EpilogueFunctor(PythonASTFrontend):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self, cc=None, **kwargs):
if not cc:
cc = device_cc()
super().__init__(cc, **kwargs)
pass
setattr(EpilogueFunctor, "__call__", staticmethod(fn))