Make cc a positional argument (#2249)
This commit is contained in:
@ -67,10 +67,10 @@ class EVTFrontendBase:
|
||||
"reshape": reshape
|
||||
}
|
||||
|
||||
def __init__(self, element_compute=DataType.f32, cc=None, additional_passes=[], **kwargs) -> None:
|
||||
self.cc = cc if cc else device_cc()
|
||||
def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None:
|
||||
self.cc = cc
|
||||
self.element_compute = library_type(element_compute)
|
||||
self.dag_ir = DAGIR(self.element_compute, self.cc)
|
||||
self.dag_ir = DAGIR(self.cc, self.element_compute)
|
||||
self.compute_cnt = 0
|
||||
self.layout_cnt = 0
|
||||
|
||||
|
||||
@ -47,8 +47,8 @@ from cutlass.backend.library import FunctionalOp
|
||||
|
||||
|
||||
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
||||
def __init__(self, element_compute=DataType.f32, **kwargs):
|
||||
super().__init__(element_compute, **kwargs)
|
||||
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
|
||||
super().__init__(cc, element_compute, **kwargs)
|
||||
# Flags
|
||||
# If this state is True, visit_Constant returns values without creating imm node
|
||||
self.no_imm = False
|
||||
|
||||
@ -49,7 +49,7 @@ class DAGIR:
|
||||
|
||||
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
|
||||
"""
|
||||
def __init__(self, element_compute=DataType.f32, cc: int=None) -> None:
|
||||
def __init__(self, cc, element_compute=DataType.f32) -> None:
|
||||
# The EVT DAGIR is managed through the nextworkX Digraph class
|
||||
self._graph = nx.DiGraph()
|
||||
|
||||
@ -57,7 +57,7 @@ class DAGIR:
|
||||
|
||||
self.reduction_names = []
|
||||
|
||||
self.cc = cc if cc else device_cc()
|
||||
self.cc = cc
|
||||
|
||||
#
|
||||
# IR manipulator
|
||||
|
||||
@ -108,7 +108,7 @@ class PassDAG2Tree(EVTPassBase):
|
||||
|
||||
# Create the subgraph
|
||||
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
||||
subgraph = DAGIR()
|
||||
subgraph = DAGIR(self.dag_ir.cc)
|
||||
for node in subgraph_.nodes:
|
||||
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
||||
if node not in node_to_fuse:
|
||||
|
||||
@ -43,7 +43,7 @@ code like the following for GEMM:
|
||||
plan.activation = cutlass.epilogue.relu
|
||||
"""
|
||||
|
||||
from cutlass.backend import epilogue
|
||||
from cutlass.backend import epilogue, device_cc
|
||||
|
||||
|
||||
gelu = epilogue.gelu
|
||||
@ -146,8 +146,10 @@ def trace(fn, example_tensors, **kwargs):
|
||||
"""
|
||||
if callable(fn):
|
||||
class EpilogueFunctor(PythonASTFrontend):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, cc=None, **kwargs):
|
||||
if not cc:
|
||||
cc = device_cc()
|
||||
super().__init__(cc, **kwargs)
|
||||
pass
|
||||
setattr(EpilogueFunctor, "__call__", staticmethod(fn))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user