Files
cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py
2025-09-18 14:26:57 -04:00

138 lines
5.2 KiB
Python

#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level class for tensor
"""
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
permutation,
reshape,
_reverse_tuple
)
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
class Tensor:
    """
    The tensor abstracts the data type and the layout of an operand.

    A Tensor is built in one of three ways:
      * from explicit metadata: ``element`` + ``shape`` + one of
        ``layout_tag`` / ``stride``;
      * from a concrete ``tensor`` object, whose element type and shape are
        read with ``get_datatype_and_layout`` / ``get_tensor_shape``;
      * from another ``Tensor``, whose attributes are copied as-is.

    Attributes:
      element: result of ``library_type(element)``, or the data type reported
        by ``get_datatype_and_layout(tensor)``.
      layout: ``Layout`` after ``canonicalization``. Its modes are stored in
        the reverse of the order used by the ``shape`` / ``stride`` arguments;
        the ``shape`` and ``stride`` properties reverse them back.
      is_constant: the ``is_constant`` constructor argument.
      value: the ``tensor`` argument; only set when ``is_constant`` is true
        and ``tensor`` is not None.
    """

    def __init__(self, tensor=None, element=None, shape=None, stride=None,
                 layout_tag=None, is_constant=False) -> None:
        """
        Validate the argument combination and build ``element`` / ``layout``.

        :param tensor: a concrete tensor object or another ``Tensor``;
            mutually exclusive with ``element``, ``shape``, ``stride`` and
            ``layout_tag``
        :param element: element type, converted with ``library_type``
        :param shape: tensor shape
        :param stride: explicit strides, same order as ``shape``; mutually
            exclusive with ``layout_tag``
        :param layout_tag: ``LayoutType.RowMajor`` or ``LayoutType.ColumnMajor``
        :param is_constant: whether the tensor is a constant; when true and
            ``tensor`` is given, ``tensor`` is kept in ``self.value``

        :raises Exception: on an invalid argument combination, or when no
            stride is given and the layout tag is neither RowMajor nor
            ColumnMajor
        """
        # Argument validation. Only the first failing rule is reported.
        if element is not None and tensor is not None:
            raise Exception("Must not specify both element and tensor")
        elif shape is not None and tensor is not None:
            raise Exception("Must not specify both shape and tensor")
        elif layout_tag is not None and tensor is not None:
            raise Exception("Must not specify both layout_tag and tensor")
        elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None):
            raise Exception("Must specify one of (element, shape, layout/stride) or (tensor)")
        elif stride is not None and tensor is not None:
            raise Exception("Must not specify both stride and tensor")
        elif stride is not None and layout_tag is not None:
            raise Exception("Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Directly copy all the attributes
            self.__dict__.update(vars(tensor))
        else:
            if tensor is None:
                self.element = library_type(element)
            else:
                # Element type, layout tag and shape all come from the tensor
                self.element, layout_tag = get_datatype_and_layout(tensor)
                shape = get_tensor_shape(tensor)

            if stride is not None:
                # The layout keeps its modes in reversed order
                self.layout = Layout(shape[::-1], stride[::-1])
            elif layout_tag == LayoutType.RowMajor:
                self.layout = Layout(shape[::-1])
            elif layout_tag == LayoutType.ColumnMajor:
                self.layout = permutation(Layout(shape), list(reversed(range(len(shape)))))
            else:
                # Without this branch self.layout would be left unset and the
                # canonicalization below would fail with an AttributeError.
                raise Exception(f"Unsupported layout tag: {layout_tag}")
            self.layout = canonicalization(self.layout)

            self.is_constant = is_constant
            # Save the tensor value if it is constant
            if is_constant and tensor is not None:
                self.value = tensor

    @property
    def shape(self):
        """
        Returns the RowMajor layout shape (``self.layout.shape`` reversed)
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Returns the RowMajor layout stride (``self.layout.stride`` reversed)
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Returns the rank of the tensor, i.e. the length of ``self.shape``
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to shape, in place

        :param shape: target shape as a tuple, in the same order as ``self.shape``
        """
        assert isinstance(shape, tuple), "shape must be a tuple"
        # The layout is stored reversed, so reverse the target shape as well
        self.layout = broadcast(self.layout, _reverse_tuple(shape))

    def reshape(self, shape):
        """
        Reshape self.layout to shape, in place

        :param shape: target shape as a tuple, in the same order as ``self.shape``
        """
        assert isinstance(shape, tuple), "shape must be a tuple"
        # The layout is stored reversed, so reverse the target shape as well
        self.layout = reshape(self.layout, _reverse_tuple(shape))

    def permute(self, indices):
        """
        Permute self.layout according to indices, in place

        :param indices: permutation expressed in the same order as ``self.shape``
        """
        num_indices = len(indices)
        # Translate the permutation to the reversed order used by the layout:
        # mirror every index, then reverse the list itself.
        mirrored = [num_indices - idx - 1 for idx in indices]
        self.layout = permutation(self.layout, mirrored[::-1])