################################################################################################# # # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# """ Base & visitor classes of DAGIR Nodes """ import ctypes from re import sub from cutlass_library import LayoutType from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple from cutlass_cppgen.backend.evt.ir.tensor import Tensor class TupleEmitter: """ Emit the cute tuple to C++ code """ def __init__(self, stride_dtype): self.stride_dtype = stride_dtype def emit(self, py_tuple): if isinstance(py_tuple, int): if py_tuple in [0, 1]: return f"cute::Int<{py_tuple}>" else: return f"{self.stride_dtype}" elif isinstance(py_tuple, tuple): decl = "cute::Stride<" for item in py_tuple: decl += self.emit(item) + ", " return decl[:-2] + ">" else: raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}") class ImplBase: """ Base class for Node Implementation """ def __init__(self, node) -> None: self.node = node self.name = node.name self.tensor = node.tensor self._type_decl = None self.tuple_emitter = TupleEmitter("int64_t") @property def stride_dtype(self): return self.tuple_emitter.stride_dtype @stride_dtype.setter def stride_dtype(self, stride_dtype): self.tuple_emitter.stride_dtype = stride_dtype @staticmethod def match(node, problem_size: tuple): """ Match function used in get_underlying_impl """ raise NotImplementedError(f"The `match` function is not defined.") @property def argument_type(self): """ Default class for Argument Type """ class _Argument(ctypes.Structure): _fields_ = [] def __init__(self, *args, **kwargs) -> None: pass return _Argument @property def name_camel(self) -> str: """ Return the CamelCase name. """ return sub(r"(_|-)+", " ", self.name).title().replace(" ", "") @property def stride_mnl(self): """ Typename StrideMNL """ stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) return self.tuple_emitter.emit(stride) def get_non_constant_stride(self, py_tuple): if isinstance(py_tuple, int): if py_tuple not in [0, 1]: return py_tuple else: return None non_constant_stride = [] for item in py_tuple: item_out = self.get_non_constant_stride(item) if item_out: non_constant_stride.append(item_out) return tuple(non_constant_stride) def get_stride_mnl(self): """ Get the non-zero stride mnl. This is used in argument construction """ stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) return stride def get_smem_size(self, *args, **kwargs): """ Get the shared memory size and alignment of current node """ return (0, 1) class NoOpImpl(ImplBase): """ The NoOpImpl does nothing but forward its input to users """ def __init__(self, node) -> None: super().__init__(node) @staticmethod def match(node, problem_size: tuple): if node.op == "store": # Store that is not output is a No OP return not node.is_output class NodeBase: """ Base class of DAG Node """ def __init__(self, name: str) -> None: self.name = name self.underlying_impl = None self._tensor = None # Whether the node is disabled for emit self.disabled = False @property def name_camel(self) -> str: """ Return the CamelCase name. """ return self.underlying_impl.name_camel @property def tensor(self) -> Tensor: """ Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) """ return self._tensor @tensor.setter def tensor(self, kwargs): """ Setting the tensor """ self._tensor = Tensor(**kwargs) # # Helper functions for type/shape propagation # def shape_propagation(self, input_node_metas): """ Infer shape from input nodes General Broadcasting Rules from NumPy When operating on two arrays, we compare their shapes element-wise. It starts with the trailing (i.e. rightmost) dimension and works its way left. Two dimensions are compatible when 1. they are equal 2. one of them is 1 """ if self._tensor is not None: return shape = None for src in input_node_metas: src_shape = src.tensor.shape if shape is None: shape = src_shape else: len_difference = len(shape) - len(src_shape) if len_difference > 0: for _ in range(len_difference): src_shape = [1, ] + list(src_shape) elif len_difference < 0: for _ in range(-len_difference): shape = [1, ] + list(shape) broadcasted_shape = [] # Infer broadcast shape for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)): if shape_dim == 1: broadcasted_shape = [src_dim, ] + list(broadcasted_shape) elif src_dim == 1: broadcasted_shape = [shape_dim, ] + list(broadcasted_shape) elif shape_dim == src_dim: broadcasted_shape = [shape_dim, ] + list(broadcasted_shape) else: error_msg = "Dimension mismatch between " for src_ in input_node_metas: error_msg += f"{src_.name}{src_.tensor.shape}, " error_msg = error_msg[:-2] + "." raise RuntimeError(error_msg) shape = tuple(broadcasted_shape) self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor) def type_propagation(self, *args, **kwargs): """ Each node is associated with two data types: `element` and `element_output`. The `element_output` is the type of return array of the node. The `element` has specific meaning for different node types. * Load Node: data type of tensor in gmem * Compute Node: element compute * Store Node: data type of tensor in gmem This function must be overloaded in the derived classes """ raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}") def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): """ Propagate the broadcast in the reversed topological order. For example: C[l, m, n] = A[m, 1] + B[l, m, n] After the broadcast propagation, it will be come C[l, m, n] = A[l, m, n] + B[l, m, n] and each tensor will have a proper stride accessing the underlying tensor """ if self.tensor is None: raise RuntimeError(f"The tensor of node {self.name} is unknown.") for child in input_node_metas: child.tensor.broadcast(self.tensor.shape) def get_underlying_impl(self, problem_size: tuple): """ Get the underlying implementation of the current node. """ if self.tensor is None: raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.") for impl in self.possible_impls: if impl.match(self, problem_size): self.underlying_impl = impl(self) break if self.underlying_impl is None: raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.") # # Visitor Nodes & Impls # class TopoVisitorImpl(ImplBase): """ Impl for topological visitor """ def __init__(self, node) -> None: super().__init__(node.output_node) self.name = node.name self.element_output = node.output_node.element_output class TopoVisitorNode(NodeBase): def __init__(self, name: str, subgraph, output_node) -> None: super().__init__(name) self.subgraph = subgraph self.output_node = output_node self.op = "dag" self.underlying_impl = TopoVisitorImpl(self)