################################################################################################# # # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# """ Layout algebras """ from pycute import Layout, composition, make_layout, flatten, product def _infer_split(old_shape, new_shape): old_shape = _tuple_to_list(old_shape) new_shape = _tuple_to_list(new_shape) if len(old_shape) == 0 and len(new_shape) == 0: return [] if len(old_shape) == 0: if product(tuple(new_shape)) != 1: raise ValueError("Invalid reshape size") else: return new_shape if len(new_shape) == 0: if product(tuple(old_shape)) != 1: raise ValueError("Invalid reshape size") else: return old_shape # This is done recursively by only process the last dimension at each time old_dim = old_shape[-1] new_dim = new_shape[-1] # Exact match if old_dim == new_dim: return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,] # Needs split if old_dim > new_dim and old_dim % new_dim == 0: residual = old_dim // new_dim return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,] # Needs merge if old_dim < new_dim and new_dim % old_dim == 0: residual = new_dim // old_dim return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,] raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}") def _infer_merge(flatten_shape, shape): flatten_shape = _tuple_to_list(flatten_shape) shape = _tuple_to_list(shape) idx_flat = 0 merged_shape = [] for dim in shape: # Exact match if dim == flatten_shape[idx_flat]: merged_shape.append(dim) idx_flat += 1 # Need group elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: residual = dim group = [] while(residual > 1): group.append(flatten_shape[idx_flat]) residual = residual // flatten_shape[idx_flat] idx_flat += 1 merged_shape.append(group) else: raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") return merged_shape def _list_to_tuple(nested_list): if isinstance(nested_list, list) or isinstance(nested_list, tuple): return tuple(_list_to_tuple(item) for item in nested_list) return nested_list def _tuple_to_list(nested_tuple): if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple): return list(_tuple_to_list(item) for item in nested_tuple) return nested_tuple def _reverse_tuple(nested_tuple: tuple): if isinstance(nested_tuple, tuple): return tuple([_reverse_tuple(item) for item in nested_tuple][::-1]) return nested_tuple def _get_first_lhs_nonzero_stride(stride_list, idx): for i in reversed(range(idx)): if stride_list[i] != 0: return i else: return None def _get_first_rhs_nonzero_stride(stride_list, idx): for i in range(idx+1, len(stride_list)): if stride_list[i] != 0: return i else: return None def reshape(layout, new_shape): """ General reshape of input layout. It takes two steps: 1. split the dimensions of the old layout 2. merge the splitted dimensions according to the new shape """ # # Step 1: Split the dimensions of the old layout # # 1.1 Flat old and new shape old_flatten_shape = list(flatten(layout.shape)) new_flatten_shape = list(flatten(new_shape)) # 1.2 Infer the flatten splitted shape splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape) # 1.3 Unflat the splitted shape based on the old shape splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape) # 1.4 Infer the type of each split # If the split type is in row-major (R), the dimension list is reversed because # the cute::composition only support column-major split split_type = [] # the type of each split (ColumnMajor or RowMajor) permuted_splitted_shape = [] old_flatten_stride = list(flatten(layout.stride)) for idx, dim in enumerate(splited_shape): if not isinstance(dim, list): permuted_splitted_shape.append(dim) split_type.append("C") else: lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx) rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx) # Special case for single tuple # Use column-major by default if lhs_stride is None and rhs_stride is None: permuted_splitted_shape.append(dim) split_type.append("C") else: if lhs_stride is not None and rhs_stride is not None: # We consider shape[idx]:stride[idx] # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: permuted_splitted_shape.append(dim) split_type.append("C") # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: permuted_splitted_shape.append([d for d in reversed(dim)]) split_type.append("R") # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: if lhs_stride >= rhs_stride: permuted_splitted_shape.append(dim) split_type.append("C") else: permuted_splitted_shape.append([d for d in reversed(dim)]) split_type.append("R") # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: if lhs_stride >= rhs_stride: permuted_splitted_shape.append(dim) split_type.append("C") else: permuted_splitted_shape.append([d for d in reversed(dim)]) split_type.append("R") else: raise NotImplementedError() elif lhs_stride is None: # Case 1: dim's stride < dim+1's stride, expand in column major if old_flatten_stride[idx] > rhs_stride: permuted_splitted_shape.append([d for d in reversed(dim)]) split_type.append("R") else: permuted_splitted_shape.append(dim) split_type.append("C") else: # Case 1: dim's stride > dim-1's stride if old_flatten_stride[idx] < lhs_stride: permuted_splitted_shape.append([d for d in reversed(dim)]) split_type.append("R") else: permuted_splitted_shape.append(dim) split_type.append("C") # 1.4 Generate the splitted layout permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape))) # 1.5 Reverse the permutation in 1.4 before merge splitted_shape = [] splitted_stride = [] for shape_dim, stride_dim, type in zip( permuted_splitted_layout.shape, permuted_splitted_layout.stride, split_type): if type == "C": splitted_shape.append(shape_dim) splitted_stride.append(stride_dim) else: splitted_shape.append(tuple([d for d in reversed(shape_dim)])) splitted_stride.append(tuple([d for d in reversed(stride_dim)])) splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride)) # # Step 2: Merge the splitted dimensions according to the new shape # # 2.1 Merge layout merged_layout = composition(splitted_layout, Layout(new_shape)) # 2.2 Cleaning up output_layout = composition(merged_layout, Layout(new_shape)) return output_layout def permutation(layout, permutation): """ Permute the layout """ new_shape = tuple([layout.shape[idx] for idx in permutation]) new_stride = tuple([layout.stride[idx] for idx in permutation]) return Layout(new_shape, new_stride) def _broadcast(layout, new_shape): if len(layout) == 1 and isinstance(new_shape, int): old_dim = layout.shape old_stride = layout.stride new_dim = new_shape if old_dim == new_dim: return Layout(old_dim, old_stride) elif old_dim == 1: return Layout(new_dim, 0) else: raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}") # Align the dimensions old_shape = layout.shape if isinstance(old_shape, int): old_shape = (old_shape,) sub_layouts = [layout,] else: sub_layouts = [sub_layout for sub_layout in layout] rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape)) # Get the broadcasted layout broadcast_layouts = [] try: layout = make_layout(*sub_layouts, *rhs_broadcast_layouts) broadcast_layouts = [] for idx, sub_layout in enumerate(layout): broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) except NotImplementedError: layout = make_layout(*rhs_broadcast_layouts, *sub_layouts) for idx, sub_layout in enumerate(layout): broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) return make_layout(*broadcast_layouts) def broadcast(layout, new_shape): """ Broadcast the new layout based on the input shape The broadcasted shape equals to the new shape The stride of broadcasted dimensions are 0 """ return _broadcast(layout, new_shape) def debroadcast(layout, dims): """ Squeeze the 0-stride """ for dim in dims: if layout.stride[dim] != 0: raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}") new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims]) new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims]) return Layout(new_shape, new_stride) def canonicalization_(shapes, strides): if isinstance(shapes, tuple): c_shapes = [] c_strides = [] for shape, stride in zip(shapes, strides): c_shape, c_stride = canonicalization_(shape, stride) c_shapes.append(c_shape) c_strides.append(c_stride) return tuple(c_shapes), tuple(c_strides) else: if shapes == 1: return 1, 0 else: return shapes, strides def canonicalization(layout): """ Canonicalize the input layout 1. set the stride of shape "1" to 0 """ new_shape, new_stride = canonicalization_(layout.shape, layout.stride) return Layout(new_shape, new_stride)