#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
|
|
Layout algebras
|
|
"""
|
|
|
|
from pycute import Layout, composition, make_layout, flatten, product
|
|
|
|
|
|


def _infer_split(old_shape, new_shape):
    """
    Infer the flat common refinement of `old_shape` and `new_shape`:
    a flat shape whose modes can be regrouped to recover either input.
    """
    old_shape = _tuple_to_list(old_shape)
    new_shape = _tuple_to_list(new_shape)
    if len(old_shape) == 0 and len(new_shape) == 0:
        return []
    if len(old_shape) == 0:
        if product(tuple(new_shape)) != 1:
            raise ValueError("Invalid reshape size")
        else:
            return new_shape
    if len(new_shape) == 0:
        if product(tuple(old_shape)) != 1:
            raise ValueError("Invalid reshape size")
        else:
            return old_shape
    # The split is inferred recursively, processing only the last dimension
    # of each shape at each step
    old_dim = old_shape[-1]
    new_dim = new_shape[-1]
    # Exact match
    if old_dim == new_dim:
        return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
    # The old dimension needs a split
    if old_dim > new_dim and old_dim % new_dim == 0:
        residual = old_dim // new_dim
        return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
    # The new dimension needs a merge
    if old_dim < new_dim and new_dim % old_dim == 0:
        residual = new_dim // old_dim
        return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]

    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
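
# Example (illustrative, traced by hand through the recursion above): the
# common refinement of (6, 4) and (3, 8) is [3, 2, 4], since 6 = 3*2 and
# 8 = 2*4:
#   _infer_split([6, 4], [3, 8])  ->  [3, 2, 4]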


def _infer_merge(flatten_shape, shape):
    """
    Regroup the modes of the flat `flatten_shape` so that each (possibly
    nested) entry of the result multiplies out to the matching dim of `shape`.
    """
    flatten_shape = _tuple_to_list(flatten_shape)
    shape = _tuple_to_list(shape)
    idx_flat = 0
    merged_shape = []
    for dim in shape:
        # Exact match
        if dim == flatten_shape[idx_flat]:
            merged_shape.append(dim)
            idx_flat += 1
        # Several flat modes need to be grouped into this dim
        elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
            residual = dim
            group = []
            while residual > 1:
                group.append(flatten_shape[idx_flat])
                residual = residual // flatten_shape[idx_flat]
                idx_flat += 1
            merged_shape.append(group)
        else:
            raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

    return merged_shape
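
# Example (illustrative): regrouping the refinement from the example above
# according to the old shape,
#   _infer_merge([3, 2, 4], [6, 4])  ->  [[3, 2], 4]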


def _list_to_tuple(nested_list):
    """Recursively convert nested lists/tuples to nested tuples."""
    if isinstance(nested_list, (list, tuple)):
        return tuple(_list_to_tuple(item) for item in nested_list)
    return nested_list


def _tuple_to_list(nested_tuple):
    """Recursively convert nested tuples/lists to nested lists."""
    if isinstance(nested_tuple, (list, tuple)):
        return list(_tuple_to_list(item) for item in nested_tuple)
    return nested_tuple


def _reverse_tuple(nested_tuple: tuple):
    """Recursively reverse a nested tuple at every nesting level."""
    if isinstance(nested_tuple, tuple):
        return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
    return nested_tuple
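
# Examples (illustrative) for the conversion helpers above:
#   _list_to_tuple([[3, 2], 4])   ->  ((3, 2), 4)
#   _tuple_to_list(((3, 2), 4))   ->  [[3, 2], 4]
#   _reverse_tuple((1, (2, 3)))   ->  ((3, 2), 1)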


def _get_first_lhs_nonzero_stride(stride_list, idx):
    """
    Return the first nonzero stride to the left of `idx` (scanning right to
    left), or None if every stride on the left is zero.
    """
    for i in reversed(range(idx)):
        if stride_list[i] != 0:
            # Return the stride value itself, not the index: the callers in
            # `reshape` compare the result against other stride values.
            return stride_list[i]
    return None


def _get_first_rhs_nonzero_stride(stride_list, idx):
    """
    Return the first nonzero stride to the right of `idx`, or None if every
    stride on the right is zero.
    """
    for i in range(idx + 1, len(stride_list)):
        if stride_list[i] != 0:
            return stride_list[i]
    return None
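
# Example (illustrative): with strides [4, 0, 1],
#   _get_first_lhs_nonzero_stride([4, 0, 1], 2)  ->  4   (skips the 0 at index 1)
#   _get_first_rhs_nonzero_stride([4, 0, 1], 0)  ->  1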


def reshape(layout, new_shape):
    """
    General reshape of the input layout.

    It takes two steps:
    1. split the dimensions of the old layout
    2. merge the split dimensions according to the new shape
    """
    #
    # Step 1: Split the dimensions of the old layout
    #

    # 1.1 Flatten the old and new shapes
    old_flatten_shape = list(flatten(layout.shape))
    new_flatten_shape = list(flatten(new_shape))

    # 1.2 Infer the flat split shape (the common refinement of both shapes)
    splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)

    # 1.3 Unflatten the split shape by grouping it according to the old shape
    grouped_splitted_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)

    # 1.4 Infer the type of each split.
    # If a split is row-major (R), its dimension list is reversed, because
    # cute composition only supports column-major splits.
    split_type = []  # the type of each split (C: column-major, R: row-major)
    permuted_splitted_shape = []
    old_flatten_stride = list(flatten(layout.stride))
    for idx, dim in enumerate(grouped_splitted_shape):
        if not isinstance(dim, list):
            permuted_splitted_shape.append(dim)
            split_type.append("C")
        else:
            lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
            rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
            # Special case: no nonzero stride on either side.
            # Use column-major by default.
            if lhs_stride is None and rhs_stride is None:
                permuted_splitted_shape.append(dim)
                split_type.append("C")
            else:
                if lhs_stride is not None and rhs_stride is not None:
                    # We consider shape[idx]:stride[idx]
                    # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column-major
                    if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                    # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row-major
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: local maximum;
                    # the neighbor with the larger stride decides the order
                    elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append([d for d in reversed(dim)])
                            split_type.append("R")
                    # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: local minimum;
                    # the neighbor with the larger stride decides the order
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append([d for d in reversed(dim)])
                            split_type.append("R")
                    else:
                        raise NotImplementedError()
                elif lhs_stride is None:
                    # Only a right neighbor exists: if the current stride is larger,
                    # the mode is row-major; otherwise expand in column-major order
                    if old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                else:
                    # Only a left neighbor exists: if the current stride is smaller,
                    # the mode is row-major; otherwise expand in column-major order
                    if old_flatten_stride[idx] < lhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")

    # 1.5 Generate the split layout
    permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))

    # 1.6 Reverse the permutation applied in 1.4 before merging
    splitted_shape = []
    splitted_stride = []
    for shape_dim, stride_dim, mode_type in zip(
            permuted_splitted_layout.shape,
            permuted_splitted_layout.stride,
            split_type):
        if mode_type == "C":
            splitted_shape.append(shape_dim)
            splitted_stride.append(stride_dim)
        else:
            splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
            splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
    splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))

    #
    # Step 2: Merge the split dimensions according to the new shape
    #

    # 2.1 Merge layout
    merged_layout = composition(splitted_layout, Layout(new_shape))

    # 2.2 Cleaning up
    output_layout = composition(merged_layout, Layout(new_shape))
    return output_layout
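
# Illustrative sketch (an assumption, traced by hand rather than taken from
# the original sources): for a compact column-major layout, and assuming
# pycute's default stride is the compact column-major prefix product,
#   reshape(Layout((6, 4)), (3, 8))
# should yield a layout equivalent to (3,8):(1,3); the 6-mode is split as
# (3, 2) and the trailing (2, 4) modes are merged into the new 8-mode.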


def permutation(layout, permutation):
    """
    Permute the modes of the layout according to `permutation`
    """
    new_shape = tuple([layout.shape[idx] for idx in permutation])
    new_stride = tuple([layout.stride[idx] for idx in permutation])
    return Layout(new_shape, new_stride)
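
# Example (illustrative): transposing a column-major 6x4 layout swaps both
# the shape and the stride modes,
#   permutation(Layout((6, 4), (1, 6)), [1, 0])  ->  (4,6):(6,1)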


def _broadcast(layout, new_shape):
    if len(layout) == 1 and isinstance(new_shape, int):
        old_dim = layout.shape
        old_stride = layout.stride
        new_dim = new_shape
        if old_dim == new_dim:
            return Layout(old_dim, old_stride)
        elif old_dim == 1:
            return Layout(new_dim, 0)
        else:
            raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")

    # Align the dimensions
    old_shape = layout.shape
    if isinstance(old_shape, int):
        old_shape = (old_shape,)
        sub_layouts = [layout,]
    else:
        sub_layouts = [sub_layout for sub_layout in layout]
    rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
    # Get the broadcast layout: first try appending the broadcast modes on the
    # right; if that fails, fall back to prepending them on the left
    broadcast_layouts = []
    try:
        layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    except NotImplementedError:
        layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
        # Discard any partial result accumulated before the failure above
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    return make_layout(*broadcast_layouts)


def broadcast(layout, new_shape):
    """
    Broadcast the input layout to the new shape.
    The broadcast shape equals the new shape;
    the stride of each broadcast dimension is 0.
    """
    return _broadcast(layout, new_shape)
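
# Example (illustrative): broadcasting the unit mode of a 6x1 layout,
#   broadcast(Layout((6, 1), (1, 6)), (6, 4))  ->  (6,4):(1,0)
# The unit dimension is expanded to 4 and its stride becomes 0.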


def debroadcast(layout, dims):
    """
    Squeeze out the given dims, which must all have stride 0
    """
    for dim in dims:
        if layout.stride[dim] != 0:
            raise ValueError(f"Dim {dim} cannot be debroadcast as it has stride {layout.stride[dim]}")
    new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
    new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
    return Layout(new_shape, new_stride)
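
# Example (illustrative): removing the 0-stride mode added by broadcast,
#   debroadcast(Layout((6, 4), (1, 0)), [1])  ->  rank-1 layout (6,):(1,)
# Requesting a nonzero-stride dim raises a ValueError instead.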


def canonicalization_(shapes, strides):
    """Recursive helper of `canonicalization`."""
    if isinstance(shapes, tuple):
        c_shapes = []
        c_strides = []
        for shape, stride in zip(shapes, strides):
            c_shape, c_stride = canonicalization_(shape, stride)
            c_shapes.append(c_shape)
            c_strides.append(c_stride)
        return tuple(c_shapes), tuple(c_strides)
    else:
        if shapes == 1:
            return 1, 0
        else:
            return shapes, strides


def canonicalization(layout):
    """
    Canonicalize the input layout:
    1. set the stride of each extent-1 mode to 0
    """
    new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
    return Layout(new_shape, new_stride)
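
# Example (illustrative): a unit mode's stride is irrelevant, so it is
# normalized to 0,
#   canonicalization(Layout((6, 1), (1, 1)))  ->  (6,1):(1,0)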