Rename python/cutlass to python/cutlass_cppgen (#2652)

This commit is contained in:
Jack Kosaian
2025-09-18 13:26:57 -05:00
committed by GitHub
parent 74825181f2
commit b234a8c024
71 changed files with 1 additions and 1 deletions

View File

@ -0,0 +1,53 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
from cutlass_cppgen.backend.evt.ir.load_nodes import (
LoadNode,
AccumulatorImpl,
LoadSrcImpl,
AuxLoadImpl,
RowBroadcastImpl,
ColumnBroadcastImpl,
ScalarBroadcastImpl
)
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass_cppgen.backend.evt.ir.store_nodes import (
StoreNode,
StoreDImpl,
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl
)

View File

@ -0,0 +1,91 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python registration for compute nodes in EVT
"""
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
from cutlass_cppgen.backend.library import FloatRoundStyle
class ComputeImplBase(ImplBase):
"""
Base class for compute implementation
"""
def __init__(self, node) -> None:
super().__init__(node)
class ComputeImpl(ComputeImplBase):
"""
Implementation for Compute Node
"""
def __init__(self, node) -> None:
super().__init__(node)
self.fn = node.fn
self.element_output = node.element_output
self.element_compute = node.element_compute
self.round_style = node.round_style
@staticmethod
def match(node, problem_size: tuple):
return True
class ComputeNode(NodeBase):
"""
Compute Node in DAG IR
"""
possible_impls = [
ComputeImpl
]
def __init__(
self, name: str, fn, element_output,
element_compute,
round_style=FloatRoundStyle.ToNearest) -> None:
super().__init__(name)
self.op = "compute"
self.fn = fn
self.element_compute = element_compute
self.round_style = round_style
def type_propagation(self, *args, **kwargs):
"""
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
"""
self.element = self.element_compute
# In general, the compute nodes have element_output = element_compute
# In certain cases like producer of D it is overwritten by other passes
if not hasattr(self, "element_output"):
self.element_output = self.element

View File

@ -0,0 +1,254 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
DAG IR used by Python EVT
"""
import networkx as nx
from cutlass_library import DataType
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.library import ActivationOp
from cutlass_cppgen.backend.utils import device_cc
class DAGIR:
"""
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
It consists of a series of ``Node`` s, each representing epilogue visitor nodes.
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
"""
def __init__(self, cc, element_compute=DataType.f32) -> None:
# The EVT DAGIR is managed through the nextworkX Digraph class
self._graph = nx.DiGraph()
self.element_compute = element_compute
self.reduction_names = []
self.cc = cc
self.identity_counter = 0
#
# IR manipulator
#
def add_node(self, meta: NodeBase):
"""
Add a node to dag ir
"""
if self.has_node(meta.name):
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
self._graph.add_node(meta.name, meta=meta)
def add_edge(self, src: str, dst: str, weight: int=0):
"""
Add an edge src -> dst to dag ir with weight
"""
if not self.has_node(src):
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
if self._graph.has_edge(src, dst):
# The DiGraph doesn't support multiple edges between two nodes
# We insert an identity node in such case as a workaround
identity_name = f"autogen_identity_{self.identity_counter}"
self.identity_counter += 1
compute_node = ComputeNode(
name=identity_name, fn=ActivationOp.Identity,
element_output=self.element_compute,
element_compute=self.element_compute)
self.add_node(compute_node)
self.add_edge(src, identity_name, 0)
self.add_edge(identity_name, dst, weight)
else:
self._graph.add_edge(src, dst, weight=weight)
def remove_node(self, node: str):
"""
Remove node from dag ir
"""
self._graph.remove_node(node)
def remove_edge(self, src: str, dst: str):
"""
Remove edge src -> dst
"""
self._graph.remove_edge(src, dst)
#
# Helper functions for getting attrs
#
def has_node(self, node: str) -> bool:
"""
Check if the node is in the graph
"""
return self._graph.has_node(node)
def in_degree(self, node: str):
"""
Get the input degree of node
"""
return self._graph.in_degree(node)
def in_edges(self, node: str):
"""
Get the input edges of node
"""
return [edge for edge in self._graph.in_edges(node)]
def out_degree(self, node: str):
"""
Get the output degree of node
"""
return self._graph.out_degree(node)
def out_edges(self, node: str):
"""
Get the output edges of node
"""
return [edge for edge in self._graph.out_edges(node)]
def get_node_meta(self, node: str):
"""
Get the meta data of the node
"""
return self._graph.nodes[node]["meta"]
def get_edge_weight(self, src, dst):
"""
Get the edge weight of edge src->dst
"""
return self._graph.get_edge_data(src, dst)["weight"]
#
# High-level helper functions
#
def all_reachable_nodes(self, node: str):
"""
Get all the nodes reachable from the current node (exclude)
"""
return list(nx.dfs_preorder_nodes(self._graph, source=node))
def get_users(self, node: str):
"""
Get all users of the current node
"""
return [edge[1] for edge in self.out_edges(node)]
def get_all_inputs(self, node: str):
"""
Get all the input nodes sorted by edge weight
"""
in_edges = self.in_edges(node)
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
def get_all_inputs_meta(self, node: str):
"""
Get all the input node metas sorted by edge weight
"""
return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]
def replace_all_uses_with(self, node1, node2):
"""
Replace all uses of node1 with node2
"""
for edge in self.out_edges(node1):
weight = self.get_edge_weight(*edge)
user = edge[1]
self.add_edge(node2, user, weight)
self.remove_edge(node1, user)
self.remove_node(node1)
#
# Node accessor
#
def nodes_topological_order(self):
"""
Get the nodes in the unique lexicographical topological order
It generates a unique ordering of nodes by first sorting topologically
and then additionally by sorting lexicographically.
Although topological_sort alone also works, this generates a unique key
for each epilogue visitor pattern and ensures the compilation cache can be reused.
:return: list[str]
"""
return list(nx.lexicographical_topological_sort(self._graph))
def node_metas_topological_order(self):
"""
Get the node metas in topological order
:return: list[NodeBase]
"""
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
@property
def nodes(self):
"""
Get all nodes
:return: list[str]
"""
return list(self._graph.nodes)
@property
def nodes_meta(self):
"""
Get all node metas
:return: list[NodeBase]
"""
return [data[1]['meta'] for data in self._graph.nodes.data()]
@property
def edges(self):
"""
Get all edges
:return: list[(str, str)]
"""
return list(self._graph.edges)
#
# Path
#
def has_path(self, src: str, target: str) -> bool:
"""
Return True is a path exists from src to target
"""
return nx.has_path(self._graph, src, target)

View File

@ -0,0 +1,324 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout algebras
"""
from pycute import Layout, composition, make_layout, flatten, product
def _infer_split(old_shape, new_shape):
old_shape = _tuple_to_list(old_shape)
new_shape = _tuple_to_list(new_shape)
if len(old_shape) == 0 and len(new_shape) == 0:
return []
if len(old_shape) == 0:
if product(tuple(new_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return new_shape
if len(new_shape) == 0:
if product(tuple(old_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return old_shape
# This is done recursively by only process the last dimension at each time
old_dim = old_shape[-1]
new_dim = new_shape[-1]
# Exact match
if old_dim == new_dim:
return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
def _infer_merge(flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = 0
merged_shape = []
for dim in shape:
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat += 1
# Need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while(residual > 1):
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat += 1
merged_shape.append(group)
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape
def _list_to_tuple(nested_list):
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
return tuple(_list_to_tuple(item) for item in nested_list)
return nested_list
def _tuple_to_list(nested_tuple):
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
return list(_tuple_to_list(item) for item in nested_tuple)
return nested_tuple
def _reverse_tuple(nested_tuple: tuple):
if isinstance(nested_tuple, tuple):
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
return nested_tuple
def _get_first_lhs_nonzero_stride(stride_list, idx):
for i in reversed(range(idx)):
if stride_list[i] != 0:
return i
else:
return None
def _get_first_rhs_nonzero_stride(stride_list, idx):
for i in range(idx+1, len(stride_list)):
if stride_list[i] != 0:
return i
else:
return None
def reshape(layout, new_shape):
"""
General reshape of input layout.
It takes two steps:
1. split the dimensions of the old layout
2. merge the splitted dimensions according to the new shape
"""
#
# Step 1: Split the dimensions of the old layout
#
# 1.1 Flat old and new shape
old_flatten_shape = list(flatten(layout.shape))
new_flatten_shape = list(flatten(new_shape))
# 1.2 Infer the flatten splitted shape
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
# 1.3 Unflat the splitted shape based on the old shape
splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
# 1.4 Infer the type of each split
# If the split type is in row-major (R), the dimension list is reversed because
# the cute::composition only support column-major split
split_type = [] # the type of each split (ColumnMajor or RowMajor)
permuted_splitted_shape = []
old_flatten_stride = list(flatten(layout.stride))
for idx, dim in enumerate(splited_shape):
if not isinstance(dim, list):
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
# Special case for single tuple
# Use column-major by default
if lhs_stride is None and rhs_stride is None:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
if lhs_stride is not None and rhs_stride is not None:
# We consider shape[idx]:stride[idx]
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
raise NotImplementedError()
elif lhs_stride is None:
# Case 1: dim's stride < dim+1's stride, expand in column major
if old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
# Case 1: dim's stride > dim-1's stride
if old_flatten_stride[idx] < lhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
# 1.4 Generate the splitted layout
permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
# 1.5 Reverse the permutation in 1.4 before merge
splitted_shape = []
splitted_stride = []
for shape_dim, stride_dim, type in zip(
permuted_splitted_layout.shape,
permuted_splitted_layout.stride,
split_type):
if type == "C":
splitted_shape.append(shape_dim)
splitted_stride.append(stride_dim)
else:
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
#
# Step 2: Merge the splitted dimensions according to the new shape
#
# 2.1 Merge layout
merged_layout = composition(splitted_layout, Layout(new_shape))
# 2.2 Cleaning up
output_layout = composition(merged_layout, Layout(new_shape))
return output_layout
def permutation(layout, permutation):
"""
Permute the layout
"""
new_shape = tuple([layout.shape[idx] for idx in permutation])
new_stride = tuple([layout.stride[idx] for idx in permutation])
return Layout(new_shape, new_stride)
def _broadcast(layout, new_shape):
if len(layout) == 1 and isinstance(new_shape, int):
old_dim = layout.shape
old_stride = layout.stride
new_dim = new_shape
if old_dim == new_dim:
return Layout(old_dim, old_stride)
elif old_dim == 1:
return Layout(new_dim, 0)
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
# Align the dimensions
old_shape = layout.shape
if isinstance(old_shape, int):
old_shape = (old_shape,)
sub_layouts = [layout,]
else:
sub_layouts = [sub_layout for sub_layout in layout]
rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
# Get the broadcasted layout
broadcast_layouts = []
try:
layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
broadcast_layouts = []
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
except NotImplementedError:
layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
return make_layout(*broadcast_layouts)
def broadcast(layout, new_shape):
"""
Broadcast the new layout based on the input shape
The broadcasted shape equals to the new shape
The stride of broadcasted dimensions are 0
"""
return _broadcast(layout, new_shape)
def debroadcast(layout, dims):
"""
Squeeze the 0-stride
"""
for dim in dims:
if layout.stride[dim] != 0:
raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
return Layout(new_shape, new_stride)
def canonicalization_(shapes, strides):
if isinstance(shapes, tuple):
c_shapes = []
c_strides = []
for shape, stride in zip(shapes, strides):
c_shape, c_stride = canonicalization_(shape, stride)
c_shapes.append(c_shape)
c_strides.append(c_stride)
return tuple(c_shapes), tuple(c_strides)
else:
if shapes == 1:
return 1, 0
else:
return shapes, strides
def canonicalization(layout):
"""
Canonicalize the input layout
1. set the stride of shape "1" to 0
"""
new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
return Layout(new_shape, new_stride)

View File

@ -0,0 +1,336 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout manipulation nodes and implementations
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
"""
from copy import deepcopy
from cutlass_library import LayoutType
from pycute import product, flatten
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class PermutationImpl:
"""
Detailed implementation and helper functions for permutation
"""
def __init__(self, node) -> None:
assert "indices" in node.kwargs.keys()
self.indices = list(node.kwargs["indices"])
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.indices = self.inverse_indices
inverse_impl.inverse_indices = self.indices
return inverse_impl
def update(self, shape):
num_dim = len(shape)
indices = self.indices
num_old_dim = len(indices)
# Add offset
for i, idx in enumerate(indices):
indices[i] = idx + num_dim - num_old_dim
# Add broadcast dims
for i in range(num_dim - num_old_dim):
indices = [i,] + indices
self.indices = indices
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_indices(self, indices):
"""
Get the indices for inverse permutation
"""
num_dim = len(indices)
inverse_indices = [0] * num_dim
for i in range(num_dim):
inverse_indices[indices[i]] = i
return inverse_indices
def shape_propagation(self, input_node_meta):
input_shape = input_node_meta.tensor.shape
output_shape = tuple([input_shape[idx] for idx in self.indices])
return output_shape
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape
"""
self.update(shape)
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
node_meta.tensor.broadcast(inverse_shape)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the permutation to the users of the current nodes
"""
usr_meta.tensor.permute(self.inverse_indices)
if hasattr(usr_meta, "store_tensor"):
if usr_meta.store_tensor is not None:
usr_meta.store_tensor.permute(self.inverse_indices)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the permutation to inputs of the current nodes
"""
input_meta.tensor.permute(self.indices)
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.permute(self.indices)
class ReshapeImpl:
"""
Detailed implementation and helper functions for reshape
"""
def __init__(self, node) -> None:
self.node = node
assert "new_shape" in node.kwargs.keys()
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.output_shape = self.input_shape
inverse_impl.input_shape = self.output_shape
return inverse_impl
def shape_propagation(self, input_node_meta):
self.input_shape = input_node_meta.tensor.shape
return _list_to_tuple(self.output_shape)
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape.
"""
# Step 1: infer split
flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
# broadcast shape -> split_output_shape -> flatten_split_shape
if len(shape) - len(split_output_shape) > 0:
for _ in range(len(shape) - len(split_output_shape)):
split_output_shape = [1,] + split_output_shape
flatten_split_shape = [1,] + flatten_split_shape
split_input_shape = [1,] + split_input_shape
broadcast_factor = []
for dim, old_dim in zip(shape, split_output_shape):
if not isinstance(dim, list):
dim = [dim,]
if not isinstance(old_dim, list):
old_dim = [old_dim,]
if product(tuple(dim)) == product(tuple(old_dim)):
broadcast_factor += [1] * len(old_dim)
elif product(tuple(old_dim)) == 1:
assert len(dim) == 1
broadcast_factor.append(dim[0])
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
# flatten_split_shape -> split_input_shape
factor_idx = 0
broadcast_split_input_shape = []
for dim in split_input_shape:
if isinstance(dim, list):
new_dim = []
for d in dim:
new_dim.append(d * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape.append(new_dim)
else:
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
node_meta.tensor.broadcast(broadcast_split_input_shape)
# Last reshape op to clean up
broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
node_meta.tensor.reshape(broadcast_input_shape)
# Update the input shape and output shape
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
self.output_shape = _list_to_tuple(shape)
def apply_to_user(self, user_meta: NodeBase):
"""
Propagate the reshape to user nodes
"""
user_meta.tensor.reshape(tuple(self.input_shape))
if hasattr(user_meta, "store_tensor"):
if user_meta.store_tensor is not None:
user_meta.store_tensor.reshape(tuple(self.input_shape))
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the reshape to input nodes
"""
input_meta.tensor.reshape(tuple(self.output_shape))
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.reshape(tuple(self.output_shape))
#
# Helper functions
#
def infer_split(self, input_shape, output_shape):
"""
Infer the flatten splitted shape that can be merged to both input_shape and output_shape
"""
input_shape = _tuple_to_list(input_shape)
output_shape = _tuple_to_list(output_shape)
if len(input_shape) == 0 and len(output_shape) == 0:
return []
if len(input_shape) == 0:
if product(tuple(output_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return output_shape
if len(output_shape) == 0:
if product(tuple(input_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return input_shape
# This is done recursively by only process the last dimension at each time
old_dim = input_shape[-1]
new_dim = output_shape[-1]
# Exact match
if old_dim == new_dim:
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
def infer_merge(self, flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = len(flatten_shape) - 1
merged_shape = []
for dim in reversed(shape):
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat -= 1
# need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while(residual > 1):
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat -= 1
merged_shape.append(group[::-1])
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape[::-1]
class LayoutNode(NodeBase):
"""
Layout manipulation nodes
"""
fn_to_impl = {
"permute": PermutationImpl,
"reshape": ReshapeImpl
}
def __init__(self, name: str, fn, kwargs: dict) -> None:
super().__init__(name)
self.op = "layout"
self.fn = fn
self.kwargs = kwargs
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
def get_inverse_node(self):
inverse_node = deepcopy(self)
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
return inverse_node
def shape_propagation(self, input_node_metas):
if self._tensor is not None:
return
assert len(input_node_metas) == 1, "Layout node can only have one input node"
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
self._tensor = Tensor(
element=self.element_output,
shape=output_shape, layout_tag=LayoutType.RowMajor
)
return super().shape_propagation(input_node_metas)
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
The store nodes has element_output = element_input
"""
assert len(input_node_metas) == 1, "Layout node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Propagate the broadcast in the reversed topological order
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
shape = self.tensor.shape
for child in input_node_metas:
self.underlying_impl.broadcast(shape, child)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the permutation to user nodes
"""
self.underlying_impl.apply_to_user(usr_meta)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the permutation to input nodes
"""
self.underlying_impl.apply_to_input(input_meta)

View File

@ -0,0 +1,294 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Load nodes and implementations
"""
import ctypes
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
class LoadImplBase(ImplBase):
"""
Base class for load node implementations
"""
reserved_names = ["accum", "C"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.tensor.stride
class AccumulatorImpl(LoadImplBase):
"""
Accumulator node implementation
"""
@staticmethod
def match(node, problem_size: tuple):
return node.name == "accum" and node.tensor.shape == problem_size
class LoadSrcImpl(LoadImplBase):
"""
Load C implementation
"""
@property
def name_camel(self) -> str:
return "TensorC"
@property
def argument_type_c(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_C", ctypes.c_void_p),
("stride_C", tuple_type)
]
def __init__(self, ptr) -> None:
self.ptr_C = ptr
self.stride_C = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
return node.name == "C" and node.tensor.shape == problem_size
class AuxLoadImpl(LoadImplBase):
"""
Load arbitrary tensor
"""
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dAux", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.null_default = to_ctype_value(0, element_type)
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if (strideMN[0] == 1 and strideMN[1] != 0 or
strideMN[0] != 0 and strideMN[1] == 1 ):
return True
else:
return False
class RowBroadcastImpl(LoadImplBase):
"""
Broadcast a row vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_row", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dRow", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_row = ptr
self.null_default = to_ctype_value(0, element_type)
self.dRow = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ColumnBroadcastImpl(LoadImplBase):
"""
Broadcast a column vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_col", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dCol", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_col = int(ptr)
self.null_default = to_ctype_value(0, element_type)
self.dCol = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class ScalarBroadcastImpl(LoadImplBase):
"""
Broadcast a scalar
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
if self.tensor.is_constant:
value = self.tensor.value
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type)
]
def __init__(self, kwargs) -> None:
self.scalars = to_ctype_value(value, element_type)
self.scalar_ptrs = 0
self.dScalar = tuple_type(stride_mnl)
else:
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type)
]
def __init__(self, kwargs) -> None:
scalar_or_ptr = kwargs[name]
if isinstance(scalar_or_ptr, float):
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
self.scalar_ptrs = 0
else:
self.scalar_ptrs = int(scalar_or_ptr)
self.dScalar = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class LoadNode(NodeBase):
"""
Load Node
"""
cnt = 0
possible_impls = [
AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
RowBroadcastImpl, ColumnBroadcastImpl,
ScalarBroadcastImpl
]
def __init__(self, name: str) -> None:
if name is None:
name = f"load{LoadNode.cnt}"
LoadNode.cnt += 1
super().__init__(name)
self.op = "load"
def type_propagation(self, *args, **kwargs):
"""
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
self.element = self.tensor.element
self.element_output = self.tensor.element

View File

@ -0,0 +1,306 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base & visitor classes of DAGIR Nodes
"""
import ctypes
from re import sub
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class TupleEmitter:
"""
Emit the cute tuple to C++ code
"""
def __init__(self, stride_dtype):
self.stride_dtype = stride_dtype
def emit(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple in [0, 1]:
return f"cute::Int<{py_tuple}>"
else:
return f"{self.stride_dtype}"
elif isinstance(py_tuple, tuple):
decl = "cute::Stride<"
for item in py_tuple:
decl += self.emit(item) + ", "
return decl[:-2] + ">"
else:
raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
class ImplBase:
"""
Base class for Node Implementation
"""
def __init__(self, node) -> None:
self.node = node
self.name = node.name
self.tensor = node.tensor
self._type_decl = None
self.tuple_emitter = TupleEmitter("int64_t")
@property
def stride_dtype(self):
return self.tuple_emitter.stride_dtype
@stride_dtype.setter
def stride_dtype(self, stride_dtype):
self.tuple_emitter.stride_dtype = stride_dtype
@staticmethod
def match(node, problem_size: tuple):
"""
Match function used in get_underlying_impl
"""
raise NotImplementedError(f"The `match` function is not defined.")
@property
def argument_type(self):
"""
Default class for Argument Type
"""
class _Argument(ctypes.Structure):
_fields_ = []
def __init__(self, *args, **kwargs) -> None:
pass
return _Argument
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
@property
def stride_mnl(self):
"""
Typename StrideMNL
"""
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
return self.tuple_emitter.emit(stride)
def get_non_constant_stride(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple not in [0, 1]:
return py_tuple
else:
return None
non_constant_stride = []
for item in py_tuple:
item_out = self.get_non_constant_stride(item)
if item_out:
non_constant_stride.append(item_out)
return tuple(non_constant_stride)
def get_stride_mnl(self):
"""
Get the non-zero stride mnl. This is used in argument construction
"""
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
return stride
def get_smem_size(self, *args, **kwargs):
"""
Get the shared memory size and alignment of current node
"""
return (0, 1)
class NoOpImpl(ImplBase):
"""
The NoOpImpl does nothing but forward its input to users
"""
def __init__(self, node) -> None:
super().__init__(node)
@staticmethod
def match(node, problem_size: tuple):
if node.op == "store":
# Store that is not output is a No OP
return not node.is_output
class NodeBase:
"""
Base class of DAG Node
"""
def __init__(self, name: str) -> None:
self.name = name
self.underlying_impl = None
self._tensor = None
# Whether the node is disabled for emit
self.disabled = False
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return self.underlying_impl.name_camel
@property
def tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
"""
return self._tensor
@tensor.setter
def tensor(self, kwargs):
"""
Setting the tensor
"""
self._tensor = Tensor(**kwargs)
#
# Helper functions for type/shape propagation
#
def shape_propagation(self, input_node_metas):
"""
Infer shape from input nodes
General Broadcasting Rules from NumPy
When operating on two arrays, we compare their shapes element-wise.
It starts with the trailing (i.e. rightmost) dimension and works its
way left. Two dimensions are compatible when
1. they are equal
2. one of them is 1
"""
if self._tensor is not None:
return
shape = None
for src in input_node_metas:
src_shape = src.tensor.shape
if shape is None:
shape = src_shape
else:
len_difference = len(shape) - len(src_shape)
if len_difference > 0:
for _ in range(len_difference):
src_shape = [1, ] + list(src_shape)
elif len_difference < 0:
for _ in range(-len_difference):
shape = [1, ] + list(shape)
broadcasted_shape = []
# Infer broadcast shape
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
if shape_dim == 1:
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
elif src_dim == 1:
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
elif shape_dim == src_dim:
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
else:
error_msg = "Dimension mismatch between "
for src_ in input_node_metas:
error_msg += f"{src_.name}{src_.tensor.shape}, "
error_msg = error_msg[:-2] + "."
raise RuntimeError(error_msg)
shape = tuple(broadcasted_shape)
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
def type_propagation(self, *args, **kwargs):
"""
Each node is associated with two data types: `element` and `element_output`.
The `element_output` is the type of return array of the node. The `element`
has specific meaning for different node types.
* Load Node: data type of tensor in gmem
* Compute Node: element compute
* Store Node: data type of tensor in gmem
This function must be overloaded in the derived classes
"""
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
Propagate the broadcast in the reversed topological order.
For example:
C[l, m, n] = A[m, 1] + B[l, m, n]
After the broadcast propagation, it will be come
C[l, m, n] = A[l, m, n] + B[l, m, n]
and each tensor will have a proper stride accessing the underlying tensor
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
for child in input_node_metas:
child.tensor.broadcast(self.tensor.shape)
def get_underlying_impl(self, problem_size: tuple):
"""
Get the underlying implementation of the current node.
"""
if self.tensor is None:
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
for impl in self.possible_impls:
if impl.match(self, problem_size):
self.underlying_impl = impl(self)
break
if self.underlying_impl is None:
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
#
# Visitor Nodes & Impls
#
class TopoVisitorImpl(ImplBase):
"""
Impl for topological visitor
"""
def __init__(self, node) -> None:
super().__init__(node.output_node)
self.name = node.name
self.element_output = node.output_node.element_output
class TopoVisitorNode(NodeBase):
def __init__(self, name: str, subgraph, output_node) -> None:
super().__init__(name)
self.subgraph = subgraph
self.output_node = output_node
self.op = "dag"
self.underlying_impl = TopoVisitorImpl(self)

View File

@ -0,0 +1,277 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Store node and implementations
"""
import ctypes
from cutlass_library import DataType
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
class StoreImplBase(ImplBase):
"""
Base class for store node implementation
"""
reserved_names = ["D"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.store_tensor.stride
class StoreDImpl(StoreImplBase):
"""
Store D implementation
"""
@property
def argument_type_d(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_D", ctypes.c_void_p),
("stride_D", tuple_type)
]
def __init__(self, ptr: int) -> None:
self.ptr_D = ptr
self.stride_D = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name == "D" and node.store_tensor.shape == problem_size:
return True
return False
class AuxStoreImpl(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.round_style = FloatRoundStyle.ToNearest
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("dAux", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if (strideMN[0] == 1 and strideMN[1] != 0 or
strideMN[0] != 0 and strideMN[1] == 1 ):
return True
else:
return False
class ReductionImplBase(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.store_tensor.element
self.element_compute = node.element_compute
self.reg_reduce_fn = self.node.reg_reduce_fn
self.gmem_reduce_fn = self.node.gmem_reduce_fn
self.round_style = node.round_style
self.stride_dtype = "int"
def get_reduce_identity(self):
"""
Return the reduction identity of the current reduce_fn
"""
maxes = {
DataType.f32: (2 ** 31) - 1,
DataType.f16: (2 ** 15),
DataType.s32: (2 ** 31) - 1,
DataType.s8: (2 ** 7) - 1
}
mins = {
DataType.f32: -maxes[DataType.f32],
DataType.f16: -maxes[DataType.f16],
DataType.s32: -maxes[DataType.s32],
DataType.s8: -maxes[DataType.s8]
}
if self.reg_reduce_fn == FunctionalOp.Maximum:
if self.element_compute not in mins:
raise Exception(f"No min entry for data type {self.element_compute}")
return to_ctype_value(mins[self.element_compute], self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
return to_ctype_value(1., self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Minimum:
if self.element_compute not in maxes:
raise Exception(f"No max entry for data type {self.element_compute}")
return to_ctype_value(maxes[self.element_compute], self.element_compute)
else:
return to_ctype_value(0., self.element_compute)
@property
def argument_type(self):
self.get_reduce_identity()
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_compute = self.element_compute
reduce_identity = self.get_reduce_identity()
class _Argument(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("reduce_identity", dtype2ctype[element_compute]),
("dMNL", tuple_type)
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr = ptr
self.reduce_identity = reduce_identity
self.dMNL = tuple_type(stride_mnl)
return _Argument
class ColumnReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class RowReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ScalarReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class StoreNode(NodeBase):
"""
Store node
"""
possible_impls = [
AuxStoreImpl, RowReductionImpl,
ColumnReductionImpl, ScalarReductionImpl,
NoOpImpl, StoreDImpl
]
def __init__(self, name: str) -> None:
super().__init__(name)
self.op = "store"
self.is_output = False
self._store_tensor = None
@property
def store_tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
"""
return self._store_tensor
@store_tensor.setter
def store_tensor(self, kwargs):
"""
Setting the tensor
"""
self._store_tensor = Tensor(**kwargs)
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
"""
The store nodes has element_output = element_input
"""
if self.is_output:
if self.store_tensor is None:
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
self.element = self.store_tensor.element
assert len(input_node_metas) == 1, "Store node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
super().broadcast_propagation(input_node_metas)
if self.is_output:
self._store_tensor.broadcast(self.tensor.shape)

View File

@ -0,0 +1,137 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level class for tensor
"""
from cutlass_library import LayoutType
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
permutation,
reshape,
_reverse_tuple
)
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
class Tensor:
"""
The tensor abstracts the data type
"""
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
if element is not None and tensor is not None:
raise Exception(f"Must not specify both element and tensor")
elif shape is not None and tensor is not None:
raise Exception(f"Must not specify both shape and tensor")
elif layout_tag is not None and tensor is not None:
raise Exception(f"Must not specify both layout_tag and tensor")
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
elif stride is not None and tensor is not None:
raise Exception(f"Must not specify both stride and tensor")
elif stride is not None and layout_tag is not None:
raise Exception(f"Must not specify layout_tag when stride is provided")
if isinstance(tensor, Tensor):
# Directly copy all the attributes
self.__dict__.update(vars(tensor))
else:
if tensor is None:
self.element = library_type(element)
else:
self.element, layout_tag = get_datatype_and_layout(tensor)
shape = get_tensor_shape(tensor)
if stride is not None:
self.layout = Layout(shape[::-1], stride[::-1])
else:
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
self.layout = canonicalization(self.layout)
self.is_constant = is_constant
# Save the tensor value if it is constant
if is_constant and tensor is not None:
self.value = tensor
@property
def shape(self):
"""
Returns the RowMajor layout shape
"""
return _reverse_tuple(self.layout.shape)
@property
def stride(self):
"""
Returns the RowMajor layout stride
"""
return _reverse_tuple(self.layout.stride)
@property
def rank(self):
"""
Returns the rank of the tensor
"""
return len(self.shape)
#
# Layout Algorithms
#
def broadcast(self, shape):
"""
Broadcast self.layout to shape
"""
assert isinstance(shape, tuple)
self.layout = broadcast(self.layout, _reverse_tuple(shape))
def reshape(self, shape):
"""
Reshape self.layout to shape
"""
assert isinstance(shape, tuple)
reverse_shape = _reverse_tuple(shape)
self.layout = reshape(self.layout, reverse_shape)
def permute(self, indices):
"""
Permute self.layout according to indices
"""
length = len(indices)
indices = [length - idx - 1 for idx in indices]
self.layout = permutation(self.layout, indices[::-1])