Rename python/cutlass to python/cutlass_cppgen (#2652)
This commit is contained in:
137
python/cutlass_cppgen/backend/evt/ir/tensor.py
Normal file
137
python/cutlass_cppgen/backend/evt/ir/tensor.py
Normal file
@ -0,0 +1,137 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
High-level class for tensor
|
||||
"""
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
|
||||
Layout,
|
||||
broadcast,
|
||||
canonicalization,
|
||||
permutation,
|
||||
reshape,
|
||||
_reverse_tuple
|
||||
)
|
||||
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
||||
|
||||
|
||||
class Tensor:
|
||||
"""
|
||||
The tensor abstracts the data type
|
||||
"""
|
||||
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
|
||||
if element is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both element and tensor")
|
||||
elif shape is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both shape and tensor")
|
||||
elif layout_tag is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both layout_tag and tensor")
|
||||
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
|
||||
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
|
||||
elif stride is not None and tensor is not None:
|
||||
raise Exception(f"Must not specify both stride and tensor")
|
||||
elif stride is not None and layout_tag is not None:
|
||||
raise Exception(f"Must not specify layout_tag when stride is provided")
|
||||
|
||||
if isinstance(tensor, Tensor):
|
||||
# Directly copy all the attributes
|
||||
self.__dict__.update(vars(tensor))
|
||||
else:
|
||||
if tensor is None:
|
||||
self.element = library_type(element)
|
||||
else:
|
||||
self.element, layout_tag = get_datatype_and_layout(tensor)
|
||||
shape = get_tensor_shape(tensor)
|
||||
if stride is not None:
|
||||
self.layout = Layout(shape[::-1], stride[::-1])
|
||||
else:
|
||||
if layout_tag == LayoutType.RowMajor:
|
||||
self.layout = Layout(shape[::-1])
|
||||
elif layout_tag == LayoutType.ColumnMajor:
|
||||
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
|
||||
self.layout = canonicalization(self.layout)
|
||||
|
||||
self.is_constant = is_constant
|
||||
# Save the tensor value if it is constant
|
||||
if is_constant and tensor is not None:
|
||||
self.value = tensor
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""
|
||||
Returns the RowMajor layout shape
|
||||
"""
|
||||
return _reverse_tuple(self.layout.shape)
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
"""
|
||||
Returns the RowMajor layout stride
|
||||
"""
|
||||
return _reverse_tuple(self.layout.stride)
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
"""
|
||||
Returns the rank of the tensor
|
||||
"""
|
||||
return len(self.shape)
|
||||
|
||||
#
|
||||
# Layout Algorithms
|
||||
#
|
||||
|
||||
def broadcast(self, shape):
|
||||
"""
|
||||
Broadcast self.layout to shape
|
||||
"""
|
||||
assert isinstance(shape, tuple)
|
||||
self.layout = broadcast(self.layout, _reverse_tuple(shape))
|
||||
|
||||
def reshape(self, shape):
|
||||
"""
|
||||
Reshape self.layout to shape
|
||||
"""
|
||||
assert isinstance(shape, tuple)
|
||||
reverse_shape = _reverse_tuple(shape)
|
||||
self.layout = reshape(self.layout, reverse_shape)
|
||||
|
||||
def permute(self, indices):
|
||||
"""
|
||||
Permute self.layout according to indices
|
||||
"""
|
||||
length = len(indices)
|
||||
indices = [length - idx - 1 for idx in indices]
|
||||
self.layout = permutation(self.layout, indices[::-1])
|
||||
Reference in New Issue
Block a user