273 lines
9.2 KiB
Python
273 lines
9.2 KiB
Python
#################################################################################################
|
|
#
|
|
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
#################################################################################################
|
|
|
|
"""
|
|
Base class for Python EVT Frontend
|
|
"""
|
|
|
|
from typing import Union
|
|
|
|
from cutlass_library import DataType
|
|
from cutlass_cppgen.backend.evt.ir import (
|
|
ComputeNode,
|
|
DAGIR,
|
|
LayoutNode,
|
|
LoadNode,
|
|
StoreNode,
|
|
)
|
|
from cutlass_cppgen.backend.evt.passes import (
|
|
EVTGraphDrawer,
|
|
EVTPassManager,
|
|
GetSmemSize,
|
|
PassDAG2Tree,
|
|
PassGetArgumentType,
|
|
PassGetImpl,
|
|
PassFixElementD,
|
|
PassLayoutManipulateElimination,
|
|
PassPreprocessRed,
|
|
PassShapeTypePropagation,
|
|
)
|
|
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
|
from cutlass_cppgen.backend.utils import device_cc
|
|
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
|
|
from cutlass_cppgen.utils.datatypes import library_type
|
|
|
|
|
|
class EVTFrontendBase:
|
|
layout_fns = {
|
|
"permute": permute,
|
|
"reshape": reshape
|
|
}
|
|
|
|
def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None:
|
|
self.cc = cc
|
|
self.element_compute = library_type(element_compute)
|
|
self.dag_ir = DAGIR(self.cc, self.element_compute)
|
|
self.compute_cnt = 0
|
|
self.layout_cnt = 0
|
|
self.imm_cnt = 0
|
|
|
|
self.pass_manager = EVTPassManager(
|
|
self.dag_ir,
|
|
[
|
|
PassPreprocessRed,
|
|
PassGetArgumentType,
|
|
PassShapeTypePropagation,
|
|
PassLayoutManipulateElimination,
|
|
PassGetImpl,
|
|
PassDAG2Tree,
|
|
PassFixElementD
|
|
] + additional_passes)
|
|
|
|
if self.cc == 80:
|
|
self._epilogue_stages = 1
|
|
else:
|
|
self._epilogue_stages = None
|
|
|
|
@property
|
|
def epilogue_stages(self):
|
|
return self._epilogue_stages
|
|
|
|
@epilogue_stages.setter
|
|
def epilogue_stages(self, stages):
|
|
self._epilogue_stages = stages
|
|
|
|
|
|
def parse(self, *args, **kwargs):
|
|
raise NotImplementedError(f"The 'parse' function must be overloaded in frontend class")
|
|
|
|
def trace(self, *args, **kwargs):
|
|
# Parse the input
|
|
self.parse(*args, **kwargs)
|
|
|
|
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
|
|
if (self.cc >= 90):
|
|
if (self.dag_ir.out_degree("D") != 0):
|
|
raise RuntimeError(
|
|
f"On SM90 or higher, D is expected to be a output node with 0 users to "
|
|
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
|
|
|
|
# Run the passes
|
|
self.pass_manager()
|
|
# Set the epilogue type
|
|
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
|
|
if cc_map[self.cc] in [90, 100]:
|
|
self.arg_c_type = self.dag_ir.arg_c_type
|
|
self.arg_d_type = self.dag_ir.arg_d_type
|
|
self.reduction_names = self.dag_ir.reduction_names
|
|
|
|
#
|
|
# Helper functions for DAG IR manipulation
|
|
#
|
|
|
|
def add_node(self, node):
|
|
self.dag_ir.add_node(node)
|
|
|
|
def add_edge(self, src, tgt, weight=0):
|
|
self.dag_ir.add_edge(src, tgt, weight=weight)
|
|
|
|
def set_tensor(self, node_name, example):
|
|
"""
|
|
Add an example tensor to node {node_name} in the DAG IR
|
|
"""
|
|
meta = self.dag_ir.get_node_meta(node_name)
|
|
meta.tensor = {"tensor": example}
|
|
|
|
def set_store_tensor(self, node_name, example):
|
|
"""
|
|
Add an example tensor to node {node_name} in the DAG IR
|
|
"""
|
|
meta = self.dag_ir.get_node_meta(node_name)
|
|
meta.store_tensor = {"tensor": example}
|
|
|
|
def mark_output(self, node_name):
|
|
"""
|
|
Mark a store node as output
|
|
"""
|
|
meta = self.dag_ir.get_node_meta(node_name)
|
|
if not isinstance(meta, StoreNode):
|
|
raise ValueError(
|
|
f"Only StoreNodes can be marked as output. "
|
|
f"Got {type(meta).__name__}: {node_name}")
|
|
meta.is_output = True
|
|
|
|
# Add node with specific type
|
|
|
|
def add_load_node(self, name, example):
|
|
"""
|
|
Add a Load node to DAG IR
|
|
:param name: name of the loaded variable
|
|
:type name: str
|
|
:param example: example input
|
|
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
|
|
"""
|
|
if name is None:
|
|
raise ValueError(f"Name is not provided.")
|
|
if example is None:
|
|
raise ValueError(f"Example input for {name} is not provided.")
|
|
load_node = LoadNode(name)
|
|
load_node.tensor = {"tensor": example}
|
|
# Special logics for accumulator
|
|
if name == "accum":
|
|
if load_node.tensor.rank == 2:
|
|
new_shape = tuple([1, ] + list(load_node.tensor.shape))
|
|
load_node.tensor.broadcast(new_shape)
|
|
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
|
|
raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
|
|
self.add_node(load_node)
|
|
|
|
def add_imm(self, value: Union[float,int]):
|
|
"""
|
|
Add an immediate scalar value to DAG IR
|
|
:param value: the value of the immediate scalar
|
|
:type value: float
|
|
"""
|
|
try:
|
|
value = float(value)
|
|
except:
|
|
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
|
|
|
|
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
|
|
self.imm_cnt += 1
|
|
load_node = LoadNode(name)
|
|
load_node.tensor = {"tensor": value, "is_constant": True}
|
|
self.add_node(load_node)
|
|
return name
|
|
|
|
def add_compute_node(self, op, name=None):
|
|
"""
|
|
Add a compute node.
|
|
:param op: the computation op
|
|
:param name: the node name (optional)
|
|
:type name: str
|
|
:return: the name of the compute node
|
|
"""
|
|
if name is None:
|
|
name = f"compute_{self.compute_cnt}"
|
|
self.compute_cnt += 1
|
|
compute_node = ComputeNode(
|
|
name=name, fn=op,
|
|
element_output=self.element_compute,
|
|
element_compute=self.element_compute)
|
|
self.add_node(compute_node)
|
|
return compute_node.name
|
|
|
|
def add_layout_node(self, op, kwargs, name=None):
|
|
"""
|
|
Add a layout node.
|
|
:param op: the layout op
|
|
:type op: evt_ops
|
|
:param name: the node name (optional)
|
|
:type name: str
|
|
:return: the name of the layout node
|
|
"""
|
|
if name is None:
|
|
name = f"layout_{self.layout_cnt}"
|
|
self.layout_cnt += 1
|
|
layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
|
|
self.add_node(layout_node)
|
|
return layout_node.name
|
|
|
|
def add_store_node(self, name):
|
|
store_node = StoreNode(name)
|
|
self.add_node(store_node)
|
|
|
|
#
|
|
# Visualization The DAG IR
|
|
#
|
|
|
|
def visualize(self, name="dag_ir"):
|
|
"""
|
|
Visualize the dag ir with svg file
|
|
:param name: the name of the graph
|
|
"""
|
|
drawer = EVTGraphDrawer(self.dag_ir, name)
|
|
try:
|
|
for name, graph in drawer.get_dot_graph():
|
|
graph.write_svg(f"./{name}.svg")
|
|
except:
|
|
raise RuntimeError(
|
|
"'dot' is not found in path. GraphDrawer is disabled. "
|
|
"Please install it with 'sudo apt-get install graphviz'."
|
|
)
|
|
|
|
#
|
|
# Get shared memory size
|
|
#
|
|
|
|
def get_smem_size(self, tile_description):
|
|
"""
|
|
Get the shared memory size of the epilogue
|
|
"""
|
|
smem_size = GetSmemSize(self.dag_ir)(tile_description)
|
|
return smem_size
|