#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Layout manipulation nodes and implementations.

The layout nodes change the layout of intermediate nodes in the epilogue visitor graph.
"""

from copy import deepcopy

from cutlass_library import LayoutType
from pycute import product, flatten

import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor


class PermutationImpl:
    """
    Detailed implementation and helper functions for permutation
    """
    def __init__(self, node) -> None:
        assert "indices" in node.kwargs.keys()
        self.indices = list(node.kwargs["indices"])
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_impl(self):
        inverse_impl = deepcopy(self)
        inverse_impl.indices = self.inverse_indices
        inverse_impl.inverse_indices = self.indices
        return inverse_impl

    def update(self, shape):
        num_dim = len(shape)
        indices = self.indices
        num_old_dim = len(indices)
        # Add offset
        for i, idx in enumerate(indices):
            indices[i] = idx + num_dim - num_old_dim
        # Add broadcast dims
        for i in range(num_dim - num_old_dim):
            indices = [i,] + indices
        self.indices = indices
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_indices(self, indices):
        """
        Get the indices for inverse permutation
        """
        num_dim = len(indices)
        inverse_indices = [0] * num_dim
        for i in range(num_dim):
            inverse_indices[indices[i]] = i
        return inverse_indices

    def shape_propagation(self, input_node_meta):
        input_shape = input_node_meta.tensor.shape
        output_shape = tuple([input_shape[idx] for idx in self.indices])
        return output_shape
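    # Worked example (illustrative only, not from the library): for a node created with
    # indices = [1, 0] whose user has the rank-3 shape (2, 4, 8), update() first offsets the
    # permutation to [2, 1] and then prepends the leading broadcast dim, yielding [0, 2, 1].
    # That permutation happens to be its own inverse, so broadcast() below resolves the
    # input-side shape to (shape[0], shape[2], shape[1]) = (2, 8, 4).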
    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on current shape
        """
        self.update(shape)
        inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the permutation to the users of the current nodes
        """
        usr_meta.tensor.permute(self.inverse_indices)
        if hasattr(usr_meta, "store_tensor"):
            if usr_meta.store_tensor is not None:
                usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the permutation to inputs of the current nodes
        """
        input_meta.tensor.permute(self.indices)
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.permute(self.indices)


class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape
    """
    def __init__(self, node) -> None:
        self.node = node
        assert "new_shape" in node.kwargs.keys()
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        inverse_impl = deepcopy(self)
        inverse_impl.output_shape = self.input_shape
        inverse_impl.input_shape = self.output_shape
        return inverse_impl

    def shape_propagation(self, input_node_meta):
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on current shape.
        """
        # Step 1: infer split
        flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
        split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
        # broadcast shape -> split_output_shape -> flatten_split_shape
        if len(shape) - len(split_output_shape) > 0:
            for _ in range(len(shape) - len(split_output_shape)):
                split_output_shape = [1,] + split_output_shape
                flatten_split_shape = [1,] + flatten_split_shape
                split_input_shape = [1,] + split_input_shape
        broadcast_factor = []
        for dim, old_dim in zip(shape, split_output_shape):
            if not isinstance(dim, list):
                dim = [dim,]
            if not isinstance(old_dim, list):
                old_dim = [old_dim,]
            if product(tuple(dim)) == product(tuple(old_dim)):
                broadcast_factor += [1] * len(old_dim)
            elif product(tuple(old_dim)) == 1:
                assert len(dim) == 1
                broadcast_factor.append(dim[0])
            else:
                raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
        # flatten_split_shape -> split_input_shape
        factor_idx = 0
        broadcast_split_input_shape = []
        for dim in split_input_shape:
            if isinstance(dim, list):
                new_dim = []
                for d in dim:
                    new_dim.append(d * broadcast_factor[factor_idx])
                    factor_idx += 1
                broadcast_split_input_shape.append(new_dim)
            else:
                broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
                factor_idx += 1
        broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
        node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
        node_meta.tensor.broadcast(broadcast_split_input_shape)
        # Last reshape op to clean up
        broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
        node_meta.tensor.reshape(broadcast_input_shape)
        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)
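    # Worked example (illustrative only, not from the library): reshaping (2, 12) -> (2, 3, 4)
    # under a target shape of (4, 2, 3, 4). infer_split yields the flat shape [2, 3, 4], which
    # infer_merge regroups as [2, [3, 4]] on the input side. The broadcast factors come out as
    # [4, 1, 1, 1], so the input tensor is reshaped to (1, 2, (3, 4)), broadcast to
    # (4, 2, (3, 4)), and finally flattened back to (4, 2, 12).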
    def apply_to_user(self, user_meta: NodeBase):
        """
        Propagate the reshape to user nodes
        """
        user_meta.tensor.reshape(tuple(self.input_shape))
        if hasattr(user_meta, "store_tensor"):
            if user_meta.store_tensor is not None:
                user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the reshape to input nodes
        """
        input_meta.tensor.reshape(tuple(self.output_shape))
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flattened split shape that can be merged to both input_shape and output_shape
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return input_shape
        # This is done recursively by only processing the last dimension at each step
        old_dim = input_shape[-1]
        new_dim = output_shape[-1]
        # Exact match
        if old_dim == new_dim:
            return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
        # Needs split
        if old_dim > new_dim and old_dim % new_dim == 0:
            residual = old_dim // new_dim
            return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
        # Needs merge
        if old_dim < new_dim and new_dim % old_dim == 0:
            residual = new_dim // old_dim
            return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        idx_flat = len(flatten_shape) - 1
        merged_shape = []
        for dim in reversed(shape):
            # Exact match
            if dim == flatten_shape[idx_flat]:
                merged_shape.append(dim)
                idx_flat -= 1
            # Needs group
            elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
                residual = dim
                group = []
                while residual > 1:
                    group.append(flatten_shape[idx_flat])
                    residual = residual // flatten_shape[idx_flat]
                    idx_flat -= 1
                merged_shape.append(group[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
        return merged_shape[::-1]

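# A minimal sketch of the split/merge helpers above (illustrative only; `_Stub` is a
# hypothetical stand-in supplying the `kwargs` attribute ReshapeImpl reads, while real
# nodes are constructed by the EVT frontend):
#
#     class _Stub:
#         kwargs = {"new_shape": (2, 3, 4)}
#
#     impl = ReshapeImpl(_Stub())
#     impl.infer_split([2, 12], [2, 3, 4])    # -> [2, 3, 4]
#     impl.infer_merge([2, 3, 4], (2, 12))    # -> [2, [3, 4]]
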
class LayoutNode(NodeBase):
    """
    Layout manipulation nodes
    """
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }

    def __init__(self, name: str, fn, kwargs: dict) -> None:
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        self.kwargs = kwargs
        self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)

    def get_inverse_node(self):
        inverse_node = deepcopy(self)
        inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse_node

    def shape_propagation(self, input_node_metas):
        if self._tensor is not None:
            return
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
        self._tensor = Tensor(
            element=self.element_output,
            shape=output_shape,
            layout_tag=LayoutType.RowMajor
        )
        return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The layout node has element_output = element_input
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        shape = self.tensor.shape
        for child in input_node_metas:
            self.underlying_impl.broadcast(shape, child)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the layout manipulation to user nodes
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the layout manipulation to input nodes
        """
        self.underlying_impl.apply_to_input(input_meta)
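
# A minimal usage sketch (illustrative only; in practice LayoutNode instances are created by
# the EVT frontend while tracing the epilogue, and `permute` below is just a hypothetical
# placeholder whose __name__ selects the underlying implementation):
#
#     def permute(x, indices):
#         ...  # stand-in; only the function's name is consulted by fn_to_impl
#
#     node = LayoutNode("permute_0", permute, {"indices": (0, 2, 1)})
#     inverse = node.get_inverse_node()  # swaps indices and inverse_indices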