diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass/backend/evt/frontend/frontend_base.py index 4cc1edf0..442a708d 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass/backend/evt/frontend/frontend_base.py @@ -67,10 +67,10 @@ class EVTFrontendBase: "reshape": reshape } - def __init__(self, element_compute=DataType.f32, cc=None, additional_passes=[], **kwargs) -> None: - self.cc = cc if cc else device_cc() + def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None: + self.cc = cc self.element_compute = library_type(element_compute) - self.dag_ir = DAGIR(self.element_compute, self.cc) + self.dag_ir = DAGIR(self.cc, self.element_compute) self.compute_cnt = 0 self.layout_cnt = 0 diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass/backend/evt/frontend/python_ast.py index 0af934a6..14827812 100644 --- a/python/cutlass/backend/evt/frontend/python_ast.py +++ b/python/cutlass/backend/evt/frontend/python_ast.py @@ -47,8 +47,8 @@ from cutlass.backend.library import FunctionalOp class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor): - def __init__(self, element_compute=DataType.f32, **kwargs): - super().__init__(element_compute, **kwargs) + def __init__(self, cc, element_compute=DataType.f32, **kwargs): + super().__init__(cc, element_compute, **kwargs) # Flags # If this state is True, visit_Constant returns values without creating imm node self.no_imm = False diff --git a/python/cutlass/backend/evt/ir/dag_ir.py b/python/cutlass/backend/evt/ir/dag_ir.py index ce8c3d64..16281d34 100644 --- a/python/cutlass/backend/evt/ir/dag_ir.py +++ b/python/cutlass/backend/evt/ir/dag_ir.py @@ -49,7 +49,7 @@ class DAGIR: In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node """ - def __init__(self, element_compute=DataType.f32, cc: int=None) -> None: + def __init__(self, cc, element_compute=DataType.f32) -> None: # The EVT DAGIR is managed through the nextworkX Digraph class self._graph = nx.DiGraph() @@ -57,7 +57,7 @@ class DAGIR: self.reduction_names = [] - self.cc = cc if cc else device_cc() + self.cc = cc # # IR manipulator diff --git a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py index 5783e9b0..91eb2054 100644 --- a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +++ b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py @@ -108,7 +108,7 @@ class PassDAG2Tree(EVTPassBase): # Create the subgraph subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes) - subgraph = DAGIR() + subgraph = DAGIR(self.dag_ir.cc) for node in subgraph_.nodes: meta = deepcopy(self.dag_ir.get_node_meta(node)) if node not in node_to_fuse: diff --git a/python/cutlass/epilogue/epilogue.py b/python/cutlass/epilogue/epilogue.py index b1dcfa4f..76c75e20 100644 --- a/python/cutlass/epilogue/epilogue.py +++ b/python/cutlass/epilogue/epilogue.py @@ -43,7 +43,7 @@ code like the following for GEMM: plan.activation = cutlass.epilogue.relu """ -from cutlass.backend import epilogue +from cutlass.backend import epilogue, device_cc gelu = epilogue.gelu @@ -146,8 +146,10 @@ def trace(fn, example_tensors, **kwargs): """ if callable(fn): class EpilogueFunctor(PythonASTFrontend): - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, cc=None, **kwargs): + if not cc: + cc = device_cc() + super().__init__(cc, **kwargs) pass setattr(EpilogueFunctor, "__call__", staticmethod(fn))