Rename python/cutlass to python/cutlass_cppgen (#2652)
This commit is contained in:
committed by
Haicheng Wu
parent
4260d4aef9
commit
177a82e251
306
python/cutlass_cppgen/backend/evt/ir/node.py
Normal file
306
python/cutlass_cppgen/backend/evt/ir/node.py
Normal file
@ -0,0 +1,306 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base & visitor classes of DAGIR Nodes
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from re import sub
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class TupleEmitter:
|
||||
"""
|
||||
Emit the cute tuple to C++ code
|
||||
"""
|
||||
def __init__(self, stride_dtype):
|
||||
self.stride_dtype = stride_dtype
|
||||
|
||||
def emit(self, py_tuple):
|
||||
if isinstance(py_tuple, int):
|
||||
if py_tuple in [0, 1]:
|
||||
return f"cute::Int<{py_tuple}>"
|
||||
else:
|
||||
return f"{self.stride_dtype}"
|
||||
elif isinstance(py_tuple, tuple):
|
||||
decl = "cute::Stride<"
|
||||
for item in py_tuple:
|
||||
decl += self.emit(item) + ", "
|
||||
return decl[:-2] + ">"
|
||||
else:
|
||||
raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
|
||||
|
||||
|
||||
class ImplBase:
|
||||
"""
|
||||
Base class for Node Implementation
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
self.node = node
|
||||
self.name = node.name
|
||||
self.tensor = node.tensor
|
||||
self._type_decl = None
|
||||
self.tuple_emitter = TupleEmitter("int64_t")
|
||||
|
||||
@property
|
||||
def stride_dtype(self):
|
||||
return self.tuple_emitter.stride_dtype
|
||||
|
||||
@stride_dtype.setter
|
||||
def stride_dtype(self, stride_dtype):
|
||||
self.tuple_emitter.stride_dtype = stride_dtype
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
"""
|
||||
Match function used in get_underlying_impl
|
||||
"""
|
||||
raise NotImplementedError(f"The `match` function is not defined.")
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
"""
|
||||
Default class for Argument Type
|
||||
"""
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = []
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
return _Argument
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
"""
|
||||
Return the CamelCase name.
|
||||
"""
|
||||
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
|
||||
|
||||
@property
|
||||
def stride_mnl(self):
|
||||
"""
|
||||
Typename StrideMNL
|
||||
"""
|
||||
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
||||
return self.tuple_emitter.emit(stride)
|
||||
|
||||
def get_non_constant_stride(self, py_tuple):
|
||||
if isinstance(py_tuple, int):
|
||||
if py_tuple not in [0, 1]:
|
||||
return py_tuple
|
||||
else:
|
||||
return None
|
||||
non_constant_stride = []
|
||||
for item in py_tuple:
|
||||
item_out = self.get_non_constant_stride(item)
|
||||
if item_out:
|
||||
non_constant_stride.append(item_out)
|
||||
return tuple(non_constant_stride)
|
||||
|
||||
def get_stride_mnl(self):
|
||||
"""
|
||||
Get the non-zero stride mnl. This is used in argument construction
|
||||
"""
|
||||
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
||||
return stride
|
||||
|
||||
def get_smem_size(self, *args, **kwargs):
|
||||
"""
|
||||
Get the shared memory size and alignment of current node
|
||||
"""
|
||||
return (0, 1)
|
||||
|
||||
|
||||
class NoOpImpl(ImplBase):
|
||||
"""
|
||||
The NoOpImpl does nothing but forward its input to users
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.op == "store":
|
||||
# Store that is not output is a No OP
|
||||
return not node.is_output
|
||||
|
||||
|
||||
class NodeBase:
|
||||
"""
|
||||
Base class of DAG Node
|
||||
"""
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
self.underlying_impl = None
|
||||
|
||||
self._tensor = None
|
||||
|
||||
# Whether the node is disabled for emit
|
||||
self.disabled = False
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
"""
|
||||
Return the CamelCase name.
|
||||
"""
|
||||
return self.underlying_impl.name_camel
|
||||
|
||||
@property
|
||||
def tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
||||
"""
|
||||
return self._tensor
|
||||
|
||||
@tensor.setter
|
||||
def tensor(self, kwargs):
|
||||
"""
|
||||
Setting the tensor
|
||||
"""
|
||||
self._tensor = Tensor(**kwargs)
|
||||
|
||||
#
|
||||
# Helper functions for type/shape propagation
|
||||
#
|
||||
|
||||
def shape_propagation(self, input_node_metas):
|
||||
"""
|
||||
Infer shape from input nodes
|
||||
General Broadcasting Rules from NumPy
|
||||
When operating on two arrays, we compare their shapes element-wise.
|
||||
It starts with the trailing (i.e. rightmost) dimension and works its
|
||||
way left. Two dimensions are compatible when
|
||||
1. they are equal
|
||||
2. one of them is 1
|
||||
"""
|
||||
if self._tensor is not None:
|
||||
return
|
||||
|
||||
shape = None
|
||||
for src in input_node_metas:
|
||||
src_shape = src.tensor.shape
|
||||
if shape is None:
|
||||
shape = src_shape
|
||||
else:
|
||||
len_difference = len(shape) - len(src_shape)
|
||||
if len_difference > 0:
|
||||
for _ in range(len_difference):
|
||||
src_shape = [1, ] + list(src_shape)
|
||||
elif len_difference < 0:
|
||||
for _ in range(-len_difference):
|
||||
shape = [1, ] + list(shape)
|
||||
broadcasted_shape = []
|
||||
# Infer broadcast shape
|
||||
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
|
||||
if shape_dim == 1:
|
||||
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
|
||||
elif src_dim == 1:
|
||||
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
||||
elif shape_dim == src_dim:
|
||||
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
||||
else:
|
||||
error_msg = "Dimension mismatch between "
|
||||
for src_ in input_node_metas:
|
||||
error_msg += f"{src_.name}{src_.tensor.shape}, "
|
||||
error_msg = error_msg[:-2] + "."
|
||||
raise RuntimeError(error_msg)
|
||||
shape = tuple(broadcasted_shape)
|
||||
|
||||
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Each node is associated with two data types: `element` and `element_output`.
|
||||
The `element_output` is the type of return array of the node. The `element`
|
||||
has specific meaning for different node types.
|
||||
* Load Node: data type of tensor in gmem
|
||||
* Compute Node: element compute
|
||||
* Store Node: data type of tensor in gmem
|
||||
This function must be overloaded in the derived classes
|
||||
"""
|
||||
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
|
||||
|
||||
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
||||
"""
|
||||
Propagate the broadcast in the reversed topological order.
|
||||
For example:
|
||||
C[l, m, n] = A[m, 1] + B[l, m, n]
|
||||
After the broadcast propagation, it will be come
|
||||
C[l, m, n] = A[l, m, n] + B[l, m, n]
|
||||
and each tensor will have a proper stride accessing the underlying tensor
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
||||
for child in input_node_metas:
|
||||
child.tensor.broadcast(self.tensor.shape)
|
||||
|
||||
def get_underlying_impl(self, problem_size: tuple):
|
||||
"""
|
||||
Get the underlying implementation of the current node.
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
|
||||
|
||||
for impl in self.possible_impls:
|
||||
if impl.match(self, problem_size):
|
||||
self.underlying_impl = impl(self)
|
||||
break
|
||||
|
||||
if self.underlying_impl is None:
|
||||
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
|
||||
|
||||
#
|
||||
# Visitor Nodes & Impls
|
||||
#
|
||||
|
||||
class TopoVisitorImpl(ImplBase):
|
||||
"""
|
||||
Impl for topological visitor
|
||||
"""
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node.output_node)
|
||||
self.name = node.name
|
||||
self.element_output = node.output_node.element_output
|
||||
|
||||
class TopoVisitorNode(NodeBase):
|
||||
def __init__(self, name: str, subgraph, output_node) -> None:
|
||||
super().__init__(name)
|
||||
self.subgraph = subgraph
|
||||
self.output_node = output_node
|
||||
self.op = "dag"
|
||||
self.underlying_impl = TopoVisitorImpl(self)
|
||||
Reference in New Issue
Block a user