#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
High-level class for tensor
"""

from cutlass_library import LayoutType

from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
    Layout,
    broadcast,
    canonicalization,
    permutation,
    reshape,
    _reverse_tuple
)
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type

class Tensor:
    """
    High-level abstraction of a tensor: an element data type plus a
    :class:`Layout` describing its shape and stride.

    A ``Tensor`` can be constructed in two mutually exclusive ways:

    * from an existing tensor-like object via ``tensor=...`` (its dtype,
      layout tag, and shape are extracted with ``get_datatype_and_layout``
      and ``get_tensor_shape``), or
    * explicitly, via ``element``, ``shape``, and one of ``layout_tag`` /
      ``stride``.

    The layout is stored internally in reversed (ColumnMajor-style) order;
    the public ``shape`` / ``stride`` properties and the layout-algorithm
    methods (``broadcast``, ``reshape``, ``permute``) all present and accept
    the RowMajor view, reversing at the boundary.
    """

    def __init__(self, tensor=None, element=None, shape=None, stride=None, layout_tag=None, is_constant=False) -> None:
        """
        Create a Tensor.

        :param tensor: an existing tensor-like object (or another ``Tensor``
            to copy); mutually exclusive with the explicit arguments below
        :param element: element data type (converted via ``library_type``)
        :param shape: RowMajor shape tuple
        :param stride: RowMajor stride tuple; mutually exclusive with
            ``layout_tag``
        :param layout_tag: ``LayoutType.RowMajor`` or ``LayoutType.ColumnMajor``
        :param is_constant: if True and ``tensor`` is given, the tensor value
            is retained on ``self.value``
        :raises Exception: if the argument combination is inconsistent
        """
        # --- Validate the argument combination: `tensor` and the explicit
        # --- (element, shape, layout_tag/stride) description are exclusive.
        if element is not None and tensor is not None:
            raise Exception(f"Must not specify both element and tensor")
        elif shape is not None and tensor is not None:
            raise Exception(f"Must not specify both shape and tensor")
        elif layout_tag is not None and tensor is not None:
            raise Exception(f"Must not specify both layout_tag and tensor")
        elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None):
            raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
        elif stride is not None and tensor is not None:
            raise Exception(f"Must not specify both stride and tensor")
        elif stride is not None and layout_tag is not None:
            raise Exception(f"Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Copy constructor: directly clone all attributes
            self.__dict__.update(vars(tensor))
        else:
            if tensor is None:
                # Explicit construction: normalize the element type
                self.element = library_type(element)
            else:
                # Infer dtype/layout/shape from the tensor-like object
                self.element, layout_tag = get_datatype_and_layout(tensor)
                shape = get_tensor_shape(tensor)
            if stride is not None:
                # Store shape/stride reversed so the internal layout is
                # ColumnMajor-style (leftmost mode fastest varying)
                self.layout = Layout(shape[::-1], stride[::-1])
            else:
                if layout_tag == LayoutType.RowMajor:
                    # Reversed shape with default (compact) strides
                    self.layout = Layout(shape[::-1])
                elif layout_tag == LayoutType.ColumnMajor:
                    # Keep the shape order but permute modes into reversed order
                    self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
                else:
                    # Fail fast rather than leave `self.layout` unset, which
                    # previously surfaced as an opaque AttributeError in
                    # canonicalization() below.
                    raise Exception(f"Unsupported layout_tag: {layout_tag}")
            # Flatten/normalize the layout into its canonical form
            self.layout = canonicalization(self.layout)

        self.is_constant = is_constant
        # Save the tensor value if it is constant
        if is_constant and tensor is not None:
            self.value = tensor

    @property
    def shape(self):
        """
        Returns the RowMajor layout shape
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Returns the RowMajor layout stride
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Returns the rank (number of modes) of the tensor
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to `shape` (given in RowMajor order)
        """
        assert isinstance(shape, tuple)
        self.layout = broadcast(self.layout, _reverse_tuple(shape))

    def reshape(self, shape):
        """
        Reshape self.layout to `shape` (given in RowMajor order)
        """
        assert isinstance(shape, tuple)
        reverse_shape = _reverse_tuple(shape)
        self.layout = reshape(self.layout, reverse_shape)

    def permute(self, indices):
        """
        Permute self.layout according to `indices` (given in RowMajor order).

        The indices are mirrored and reversed to account for the internally
        reversed mode order before being applied.
        """
        length = len(indices)
        indices = [length - idx - 1 for idx in indices]
        self.layout = permutation(self.layout, indices[::-1])