195 lines
7.0 KiB
Python
195 lines
7.0 KiB
Python
#################################################################################################
|
|
#
|
|
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
#################################################################################################
|
|
|
|
"""
|
|
Python AST frontend that parses input into DAG IR
|
|
"""
|
|
|
|
import ast
|
|
import inspect
|
|
import textwrap
|
|
|
|
from cutlass_library import DataType
|
|
|
|
import cutlass_cppgen
|
|
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
|
|
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
|
from cutlass_cppgen.backend.library import FunctionalOp
|
|
|
|
|
|
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
|
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
|
|
super().__init__(cc, element_compute, **kwargs)
|
|
# Flags
|
|
# If this state is True, visit_Constant returns values without creating imm node
|
|
self.no_imm = False
|
|
self.visiting_return = False
|
|
|
|
def parse(self, example_inputs):
|
|
self.example_inputs = example_inputs
|
|
self.source = textwrap.dedent(inspect.getsource(self.__call__))
|
|
self.ast = ast.parse(self.source)
|
|
self.visit(self.ast)
|
|
|
|
#
|
|
# Helper functions
|
|
#
|
|
@staticmethod
|
|
def ast_op_to_bindings(op):
|
|
mapping = {
|
|
ast.Add: FunctionalOp.Plus,
|
|
ast.Sub: FunctionalOp.Minus,
|
|
ast.Mult: FunctionalOp.Multiplies,
|
|
ast.Div: FunctionalOp.Divides,
|
|
"maximum": FunctionalOp.Maximum,
|
|
"minimum": FunctionalOp.Minimum,
|
|
"identity": identity.binding_type,
|
|
"relu": relu.binding_type,
|
|
"tanh": tanh.binding_type,
|
|
"sigmoid": sigmoid.binding_type,
|
|
"silu": silu.binding_type,
|
|
"hardswish": hardswish.binding_type,
|
|
"gelu": gelu.binding_type,
|
|
"multiply_add": FunctionalOp.MultiplyAdd,
|
|
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
|
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
|
|
"exp": FunctionalOp.Exp
|
|
}
|
|
return mapping[op]
|
|
|
|
#
|
|
# Visiting different node types
|
|
#
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef):
|
|
# Visit args and register load nodes
|
|
for arg in node.args.args:
|
|
self.visit(arg)
|
|
for expr in node.body:
|
|
self.visit(expr)
|
|
|
|
def visit_arg(self, node: ast.arg):
|
|
# Name of the argument
|
|
name = node.arg
|
|
try:
|
|
example_tensor = self.example_inputs[name]
|
|
except:
|
|
raise RuntimeError(f"Example input for {name} is not provided.")
|
|
|
|
self.add_load_node(name, example_tensor)
|
|
|
|
def visit_Name(self, node: ast.Name):
|
|
return node.id
|
|
|
|
def visit_Constant(self, node: ast.Constant):
|
|
if self.no_imm:
|
|
return node.value
|
|
else:
|
|
name = self.add_imm(node.value)
|
|
return name
|
|
|
|
def visit_Tuple(self, node: ast.Tuple):
|
|
results = []
|
|
for elt in node.elts:
|
|
results.append(self.visit(elt))
|
|
return tuple(results)
|
|
|
|
def visit_keyword(self, node: ast.keyword):
|
|
return {node.arg: self.visit(node.value)}
|
|
|
|
def visit_BinOp(self, node: ast.BinOp):
|
|
if self.visiting_return:
|
|
raise SyntaxError("Return value cannot be an expression")
|
|
lhs = self.visit(node.left)
|
|
rhs = self.visit(node.right)
|
|
op = self.ast_op_to_bindings(type(node.op))
|
|
name = self.add_compute_node(op)
|
|
|
|
# Add edges
|
|
# The edge weights are used to sort the input args
|
|
self.add_edge(lhs, name, weight=0)
|
|
self.add_edge(rhs, name, weight=1)
|
|
return name
|
|
|
|
def visit_Assign(self, node: ast.BinOp):
|
|
target = self.visit(node.targets[0])
|
|
value = self.visit(node.value)
|
|
# Create the assign node
|
|
self.add_store_node(target)
|
|
|
|
# Add edges
|
|
self.add_edge(value, target)
|
|
return target
|
|
|
|
def visit_Call(self, node: ast.Call):
|
|
if self.visiting_return:
|
|
raise SyntaxError("Return value cannot be an expression")
|
|
func = self.visit(node.func)
|
|
args = [self.visit(arg) for arg in node.args]
|
|
|
|
if func in self.layout_fns.keys():
|
|
# Parse kwargs
|
|
# By default, visiting imm automatically creates a load node
|
|
# However, in function call, keyword args are used to set
|
|
# specific function attributes such as indices for permute
|
|
# So no_imm is set to True temporarily
|
|
self.no_imm = True
|
|
kwargs = {}
|
|
for kw in node.keywords:
|
|
kwargs.update(self.visit(kw))
|
|
self.no_imm = False
|
|
op = self.layout_fns[func]
|
|
name = self.add_layout_node(op, kwargs)
|
|
else:
|
|
op = self.ast_op_to_bindings(func)
|
|
name = self.add_compute_node(op)
|
|
|
|
# Add edges
|
|
for idx, arg in enumerate(args):
|
|
self.add_edge(arg, name, weight=idx)
|
|
return name
|
|
|
|
def visit_Return(self, node: ast.Return):
|
|
self.visiting_return = True
|
|
results = self.visit(node.value)
|
|
self.visiting_return = False
|
|
self.return_names = results
|
|
if not isinstance(results, tuple):
|
|
results = (results,)
|
|
for rst in results:
|
|
try:
|
|
example_tensor = self.example_inputs[rst]
|
|
except:
|
|
raise RuntimeError(f"Example input for {rst} is not provided.")
|
|
self.set_store_tensor(rst, example_tensor)
|
|
self.mark_output(rst)
|