v4.2 tag release. (#2638)
@@ -207,7 +207,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
     if dst_width == src_width:
         return a
-    elif src_signed and not dst_signed:
+    elif src_signed != False and not dst_signed:
         # Signed -> Unsigned
         if dst_width > src_width:
             return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
@@ -216,7 +216,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
     elif src_signed == dst_signed:
         # Same signedness
         if dst_width > src_width:
-            if src_signed and src_width > 1:
+            if src_signed != False and src_width > 1:
                 return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
             else:
                 return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
@@ -479,7 +479,7 @@ class ArithValue(ir.Value):
         if self.is_float:
             q = arith.divf(self, other, loc=loc, ip=ip)
             return math.floor(q, loc=loc, ip=ip)
-        elif self.signed:
+        elif self.signed != False:
             return arith.floordivsi(self, other, loc=loc, ip=ip)
         else:
             return arith.divui(self, other, loc=loc, ip=ip)
@@ -489,7 +489,7 @@ class ArithValue(ir.Value):
     def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
         if self.is_float:
             return arith.remf(self, other, loc=loc, ip=ip)
-        elif self.signed:
+        elif self.signed != False:
             return arith.remsi(self, other, loc=loc, ip=ip)
         else:
             return arith.remui(self, other, loc=loc, ip=ip)
@@ -524,7 +524,7 @@ class ArithValue(ir.Value):
     def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
         if self.is_float:
             return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
-        elif self.signed:
+        elif self.signed != False:
             return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
         else:
             return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)
@@ -534,7 +534,7 @@ class ArithValue(ir.Value):
     def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
         if self.is_float:
             return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
-        elif self.signed:
+        elif self.signed != False:
             return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
         else:
             return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)
@@ -561,7 +561,7 @@ class ArithValue(ir.Value):
     def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
         if self.is_float:
             return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
-        elif self.signed:
+        elif self.signed != False:
             return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
         else:
             return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)
@@ -571,7 +571,7 @@ class ArithValue(ir.Value):
     def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
         if self.is_float:
             return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
-        elif self.signed:
+        elif self.signed != False:
             return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
         else:
             return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)
@@ -599,7 +599,7 @@ class ArithValue(ir.Value):
     @_dispatch_to_rhs_r_op
     @_binary_op
     def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
-        if self.signed:
+        if self.signed != False:
             return arith.shrsi(self, other, loc=loc, ip=ip)
         else:
             return arith.shrui(self, other, loc=loc, ip=ip)
@@ -633,7 +633,7 @@ class ArithValue(ir.Value):
         return super().__hash__()

     def __str__(self):
-        return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)
+        return "?"

     def __repr__(self):
         return self.__str__()
@@ -657,7 +657,7 @@ def _min(lhs, rhs, *, loc=None, ip=None):
         rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

     if arith._is_integer_like_type(lhs.type):
-        if lhs.signed:
+        if lhs.signed != False:
             return arith.minsi(lhs, rhs, loc=loc, ip=ip)
         else:
             return arith.minui(lhs, rhs, loc=loc, ip=ip)
@@ -683,7 +683,7 @@ def _max(lhs, rhs, *, loc=None, ip=None):
         rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

     if arith._is_integer_like_type(lhs.type):
-        if lhs.signed:
+        if lhs.signed != False:
             return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
         else:
             return arith.maxui(lhs, rhs, loc=loc, ip=ip)
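Every truthy `signed` check above becomes `signed != False`. A minimal sketch of the behavioral difference, assuming the signedness flag can be `True`, `False`, or `None` for signless types — an assumption, since the hunks don't show the flag's definition:

    # Hypothetical three-state flag: True (signed), False (unsigned), None (signless).
    def pick_old(signed):
        if signed:  # old check: None falls into the unsigned branch
            return "slt"
        return "ult"

    def pick_new(signed):
        if signed != False:  # new check: only an explicit False means unsigned
            return "slt"
        return "ult"

    assert pick_old(None) == "ult"  # signless used to compare as unsigned
    assert pick_new(None) == "slt"  # signless now compares as signed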
@@ -17,12 +17,16 @@ The preprocessor read through python's ast and changes the input code.
 from typing import Callable, Iterator, Optional, overload
 from typing_extensions import deprecated
 import warnings
+import inspect
+from types import BuiltinFunctionType
+from functools import lru_cache

 from .utils.logger import log
 from .common import *

+from ._mlir_helpers.arith import ArithValue


 class Executor:
     """
     The Executor class handles dynamic and compile-time (constexpr) execution
@@ -45,9 +49,11 @@ class Executor:
         self._compare_executor = None
         self._any_executor = None
         self._all_executor = None
+        self._builtin_redirector = None

     def set_functions(
         self,
         *,
         is_dynamic_expression: Callable,
         loop_execute_range_dynamic: Callable,
         if_dynamic: Callable,
@@ -55,6 +61,7 @@
         compare_executor: Callable,
         any_executor: Callable = None,
         all_executor: Callable = None,
+        builtin_redirector: Callable = None,
     ):
         self._is_dynamic_expression = is_dynamic_expression
         self._loop_execute_range_dynamic = loop_execute_range_dynamic
@@ -63,6 +70,7 @@
         self._compare_executor = compare_executor
         self._any_executor = any_executor
         self._all_executor = all_executor
+        self._builtin_redirector = builtin_redirector

     @staticmethod
     def convert_to_list(x):
@@ -90,42 +98,18 @@ class Executor:
             return res[0]
         return res

-    @staticmethod
-    def for_constexpr(
-        func: Callable,
-        start: int,
-        stop: int,
-        step: int,
-        used_args: list,
-        iter_args: list,
-    ):
-        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
-        loop_results = iter_args
-        log().debug("iter_args [%s]", iter_args)
-        for i in range(start, stop, step):
-            log().debug("i [%s] iter_args [%s]", i, iter_args)
-            loop_results = func(i, *used_args, *loop_results)
-            log().debug("loop_results [%s]", loop_results)
-            if loop_results is None:
-                loop_results = []
-            if not isinstance(loop_results, list):
-                loop_results = [loop_results]
-
-        log().debug("done loop_results [%s]", loop_results)
-        return Executor.converge_ret_val(loop_results)
-
     def for_execute(
         self,
         func,
         start,
         stop,
         step,
-        used_args=[],
-        iter_args=[],
-        iter_arg_names=[],
+        write_args=[],
+        full_write_args_count=0,
+        write_args_names=[],
         unroll=-1,
         unroll_full=False,
-        pipelining=None,
+        prefetch_stages=None,
     ):
         assert (
             self._loop_execute_range_dynamic
@@ -137,12 +121,12 @@ class Executor:
             start,
             stop,
             step,
-            used_args,
-            iter_args,
-            iter_arg_names,
+            write_args,
+            full_write_args_count,
+            write_args_names,
             unroll,
             unroll_full,
-            pipelining,
+            prefetch_stages,
         )

     def if_execute(
@@ -150,15 +134,20 @@
         pred,
         then_block: Callable,
         else_block: Optional[Callable] = None,
-        used_args=[],
-        yield_args=[],
-        yield_arg_names=[],
+        write_args=[],
+        full_write_args_count=0,
+        write_args_names=[],
     ):
         assert self._if_dynamic, "Functions must be set before execution."

         # MLIR generation
         return self._if_dynamic(
-            pred, then_block, else_block, used_args, yield_args, yield_arg_names
+            pred,
+            then_block,
+            else_block,
+            write_args,
+            full_write_args_count,
+            write_args_names,
         )

     def while_execute(
@@ -166,9 +155,9 @@
         pred,
         while_before_block: Callable,
         while_after_block: Callable,
-        used_args=[],
-        yield_args=[],
-        yield_arg_names=[],
+        write_args=[],
+        full_write_args_count=0,
+        write_args_names=[],
     ):
         assert self._while_dynamic, "Functions must be set before execution."

@@ -176,9 +165,9 @@
         return self._while_dynamic(
             while_before_block,
             while_after_block,
-            used_args,
-            yield_args,
-            yield_arg_names,
+            write_args,
+            full_write_args_count,
+            write_args_names,
         )

@@ -194,23 +183,24 @@ def loop_selector(
    stop,
    step,
    *,
-    used_args=[],
-    iter_args=[],
-    iter_arg_names=[],
+    write_args=[],
+    full_write_args_count=0,
+    write_args_names=[],
    unroll=-1,
    unroll_full=False,
-    pipelining=None,
+    prefetch_stages=None,
):
    log().debug(
-        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
+        "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
        start,
        stop,
        step,
-        used_args,
-        iter_args,
+        write_args,
+        full_write_args_count,
+        write_args_names,
        unroll,
        unroll_full,
-        pipelining,
+        prefetch_stages,
    )
    from .typing import Integer, Numeric

@@ -230,19 +220,19 @@ def loop_selector(
        start,
        stop,
        step,
-        used_args,
-        iter_args,
-        iter_arg_names,
+        write_args,
+        full_write_args_count,
+        write_args_names,
        unroll,
        unroll_full,
-        pipelining,
+        prefetch_stages,
    )

    return ir_loop


-def if_selector(pred, used_args=[], yield_args=[]):
-    log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
+def if_selector(pred, write_args=[]):
+    log().debug("pred [%s] write_args [%s]", pred, write_args)
    # Handle Numeric types here?

    from .typing import Numeric
@@ -251,14 +241,14 @@ def if_selector(pred, used_args=[], yield_args=[]):
        pred = pred.value

    def ir_loop(func):
-        return func(pred, *used_args, *yield_args)
+        return func(pred, *write_args)

    return ir_loop


-def while_selector(pred, used_args=[], yield_args=[]):
+def while_selector(pred, write_args=[]):
    def ir_while_loop(func):
-        return func(pred, *used_args, *yield_args)
+        return func(pred, *write_args)

    return ir_while_loop

@@ -267,17 +257,17 @@ def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
-    used_args=[],
-    yield_args=[],
-    yield_arg_names=[],
+    write_args=[],
+    full_write_args_count=0,
+    write_args_names=[],
):
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
-        used_args,
-        yield_args,
-        yield_arg_names,
+        write_args,
+        full_write_args_count,
+        write_args_names,
    )


@@ -285,12 +275,17 @@ def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
-    used_args=[],
-    yield_args=[],
-    yield_arg_names=[],
+    write_args=[],
+    full_write_args_count=0,
+    write_args_names=[],
):
    return executor.if_execute(
-        pred, then_block, else_block, used_args, yield_args, yield_arg_names
+        pred,
+        then_block,
+        else_block,
+        write_args,
+        full_write_args_count,
+        write_args_names,
    )


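These executor hunks all make the same substitution: the `used_args`/`yield_args`/`yield_arg_names` triple becomes `write_args`/`full_write_args_count`/`write_args_names`. A self-contained sketch of the new call shape with a stub executor — the stub and values are illustrative, not the DSL's real dispatch:

    # Stub mirroring the new executor-facing parameter shape (illustrative only).
    def if_executor(pred, then_block, else_block=None,
                    write_args=[], full_write_args_count=0, write_args_names=[]):
        block = then_block if pred else else_block
        return block(*write_args)

    a, b = 1, 2
    result = if_executor(
        True,
        lambda x, y: [x + 1, y],      # "then" branch rewrites its first slot
        None,
        write_args=[a, b],            # values the branch may write
        full_write_args_count=2,      # total number of writable slots
        write_args_names=["a", "b"],  # names kept for bookkeeping/diagnostics
    )
    assert result == [2, 2]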
@@ -313,14 +308,17 @@ class range:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - pipelining: Compiler generated pipeline configuration
+    - prefetch_stages: Number of prefetch stages to generate
    """

    @overload
-    def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
+    def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None):
        pass

    @overload
-    def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
+    def __new__(
+        cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
+    ):
        pass

    def __new__(cls, *args, **kwargs):
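In the overloads above, `pipelining` gives way to `prefetch_stages`. A sketch of the two documented call shapes, using a stand-in constructor since the real `range.__new__` dispatch isn't shown in the hunk:

    # Stand-in dispatch mirroring the two overloads: range(stop, ...) and
    # range(start, stop, step, ...); the real class builds an IR loop instead.
    def make_range(*args, unroll=0, unroll_full=False, prefetch_stages=None):
        start, stop, step = (0, args[0], 1) if len(args) == 1 else args
        return {"start": start, "stop": stop, "step": step, "unroll": unroll,
                "unroll_full": unroll_full, "prefetch_stages": prefetch_stages}

    assert make_range(8)["stop"] == 8
    assert make_range(0, 64, 4, prefetch_stages=2)["prefetch_stages"] == 2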
@@ -340,6 +338,7 @@ def range_dynamic(*args, **kwargs):
def range_constexpr(*args):
    raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")


# =============================================================================
# If expressions
# =============================================================================
@@ -405,7 +404,7 @@ def assert_executor(test, msg=None):
    else:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
-            suggestion = "Please replace with runtime assert."
+            suggestion="Please replace with runtime assert.",
        )

@@ -413,10 +412,11 @@ def bool_cast(value):
    if executor._is_dynamic_expression(value):
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
-            suggestion = "Please explicitly convert to boolean with expressions like comparision."
+            suggestion="Please explicitly convert to boolean with expressions like comparision.",
        )
    return bool(value)


def compare_executor(left, comparators, ops):
    """
    Executes comparison operations with a left operand and a list of comparators.
@@ -470,6 +470,19 @@ def all_executor(iterable):
# =============================================================================
# Control flow checks
# =============================================================================
+class DSLOptimizationWarning(Warning):
+    """
+    This warning is used to warn the user about the optimization related issues in DSL.
+    """
+
+    def __init__(self, message):
+        self.message = message
+        super().__init__()
+
+    def __str__(self):
+        return self.message
+
+
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
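Since `DSLOptimizationWarning` subclasses `Warning`, it composes with the standard `warnings` filters; later in this commit it is ignored by default and opted into via an environment variable. A runnable sketch of that interaction:

    import warnings

    class DSLOptimizationWarning(Warning):
        def __init__(self, message):
            self.message = message
            super().__init__()

        def __str__(self):
            return self.message

    # Default-off, as BaseDSL.__init__ arranges when the env var is unset:
    warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # opt-in: surface everything again
        warnings.warn("slow static loop", category=DSLOptimizationWarning, stacklevel=2)
    assert len(caught) == 1 and str(caught[0].message) == "slow static loop"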
@@ -495,7 +508,7 @@ def range_value_check(*args):
    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
-            category=UserWarning,
+            category=DSLOptimizationWarning,
            stacklevel=2,
        )

@@ -519,7 +532,50 @@ def range_perf_warning(filename, lineno, *args):
            "This loop is no longer unrolled and may cause performance regression. "
            "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
        ),
-        category=UserWarning,
+        category=DSLOptimizationWarning,
        filename=filename,
        lineno=lineno,
    )
+
+
+@lru_cache(maxsize=1)
+def _get_self_module():
+    """
+    This function is used to get the owning module of this function.
+    """
+    return inspect.getmodule(_get_self_module)
+
+
+def cf_symbol_check(symbol):
+    """
+    Check if the symbol is control flow symbol from current module.
+    """
+
+    failed = False
+    name = symbol.__name__
+    self_module = _get_self_module()
+    if inspect.ismodule(symbol):
+        name = "range"
+        if not self_module.__name__.startswith(symbol.__name__):
+            failed = True
+    else:
+        owning_module = inspect.getmodule(symbol)
+        if owning_module != self_module:
+            failed = True
+
+    if failed:
+        raise DSLRuntimeError(
+            f"Incorrect {symbol.__name__} is used.",
+            suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
+        )
+
+
+def redirect_builtin_function(fcn):
+    """
+    This function is used to redirect built-in function call
+    to the function defined in DSL package.
+    """
+    # Only redirect if it's a built-in
+    if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
+        return executor._builtin_redirector(fcn)
+    return fcn
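A sketch of the redirection contract of `redirect_builtin_function`: only `BuiltinFunctionType` objects are candidates, and the installed `_builtin_redirector` picks the replacement. The redirector and `dsl_max` below are stand-ins, not the package's real implementations:

    from types import BuiltinFunctionType, SimpleNamespace

    def dsl_max(*args):  # stand-in for a DSL-aware max
        return ("dsl_max", max(*args))

    # Minimal executor stub carrying the new _builtin_redirector hook.
    executor = SimpleNamespace(
        _builtin_redirector=lambda fcn: dsl_max if fcn is max else fcn
    )

    def redirect_builtin_function(fcn):
        if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
            return executor._builtin_redirector(fcn)
        return fcn

    assert redirect_builtin_function(max) is dsl_max  # builtin -> DSL version

    def user_fn():
        pass

    assert redirect_builtin_function(user_fn) is user_fn  # non-builtins untouched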
File diff suppressed because it is too large.
@@ -139,8 +139,7 @@ def dump_cache_to_path(
    dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
):
    log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
-    if not os.path.exists(path):
-        os.makedirs(path)
+    os.makedirs(path, exist_ok=True)
    original_path = os.getcwd()
    try:
        os.chdir(path)

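Besides being shorter, `os.makedirs(path, exist_ok=True)` closes the race in the replaced two-line form: another process creating `path` between the `exists()` check and `makedirs()` would previously raise `FileExistsError`. A minimal sketch (the path is hypothetical):

    import os

    path = "/tmp/jit_cache_demo"  # hypothetical cache directory

    # Racy: FileExistsError if another process creates `path` between the
    # exists() check and makedirs().
    if not os.path.exists(path):
        os.makedirs(path)

    # Safe and idempotent: succeeds whether or not the directory exists.
    os.makedirs(path, exist_ok=True)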
@@ -205,6 +205,8 @@ class CompileOptions:
        self._parser.add_argument(
            "--enable-device-assertions", action="store_true", default=False
        )
+        self._parser.add_argument("--link-libraries", type=str, default="")
+
        try:
            self._options = self._parser.parse_args(options.split())
        except SystemExit as e:

@@ -32,13 +32,14 @@ import hashlib
from functools import lru_cache, wraps
from collections import namedtuple
from abc import ABC, abstractmethod
-from typing import Any, Union, Tuple, get_origin, get_args
-from types import FunctionType
+from typing import Any, Union, Tuple, get_origin, get_args, List
+from types import FunctionType, SimpleNamespace
import warnings

from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions
+from .ast_helpers import DSLOptimizationWarning

# =============================================================================
# CUDA Python
@@ -56,7 +57,7 @@ from .utils.timer import timer
from .utils.logger import setup_log, log
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
-from .runtime.tensor_descriptor import TensorDescriptor

from .ast_preprocessor import DSLPreprocessor
from .common import *
from .typing import (
@@ -73,12 +74,6 @@ from .._mlir import runtime as rt
from .._mlir.extras import types as T
from .._mlir.dialects import arith, math, func

-# =============================================================================
-# cutlass.dlpack_runtime
-# =============================================================================
-
-from .runtime.dlpack_runtime import dlpack_to_tensor_desc, mark_layout_dynamic
-
# =============================================================================
# Global Variables
# =============================================================================
@@ -177,6 +172,7 @@ def is_dynamic_expression(value):
        return True
    return False


def extract_mlir_values(obj):
    """
    Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values
@@ -186,6 +182,10 @@ def extract_mlir_values(obj):
        res = obj.__extract_mlir_values__()
    elif isinstance(obj, (tuple, list)):
        res = sum((extract_mlir_values(x) for x in obj), [])
+    elif isinstance(obj, SimpleNamespace):
+        res = []
+        for k, v in obj.__dict__.items():
+            res.extend(extract_mlir_values(v))
    # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values
    elif isinstance(obj, set):
        raise DSLRuntimeError(
@@ -215,6 +215,13 @@ def new_from_mlir_values(obj, values):
            values = values[n_items:]
        obj_ty = type(obj)
        return obj_ty(res)
+    elif isinstance(obj, SimpleNamespace):
+        res = SimpleNamespace()
+        for k, v in obj.__dict__.items():
+            n_items = len(get_mlir_types(v))
+            res.__dict__[k] = new_from_mlir_values(v, values[:n_items])
+            values = values[n_items:]
+        return res
    elif isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in new_from_mlir_values to ensure order preservation",
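The two `SimpleNamespace` branches are mirror images: extraction flattens attribute values in `__dict__` insertion order, and reconstruction consumes the flat value list in that same order. A self-contained sketch of the round trip, with plain ints standing in for MLIR values and simplified helpers in place of `get_mlir_types`:

    from types import SimpleNamespace

    def extract(obj):  # simplified extract_mlir_values
        if isinstance(obj, (tuple, list)):
            return sum((extract(x) for x in obj), [])
        if isinstance(obj, SimpleNamespace):
            res = []
            for _, v in obj.__dict__.items():
                res.extend(extract(v))
            return res
        return [obj]  # leaf stands in for one MLIR value

    def rebuild(obj, values):  # simplified new_from_mlir_values
        if isinstance(obj, list):
            out, rest = [], values
            for item in obj:
                n = len(extract(item))
                out.append(rebuild(item, rest[:n]))
                rest = rest[n:]
            return out
        if isinstance(obj, SimpleNamespace):
            res, rest = SimpleNamespace(), values
            for k, v in obj.__dict__.items():
                n = len(extract(v))
                res.__dict__[k] = rebuild(v, rest[:n])
                rest = rest[n:]
            return res
        return values[0]

    ns = SimpleNamespace(a=1, b=[2, 3])
    assert extract(ns) == [1, 2, 3]
    assert rebuild(ns, [10, 20, 30]).b == [20, 30]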
@@ -249,8 +256,6 @@ class DSLCallable:

    Methods:
        __call__(*args, **kwargs): Calls the wrapped function and clears it.
-        get_arg_spec(): Returns the argument specification of the function.
-        get_signature(): Returns the signature of the function.
    """

    def __init__(self, func):
@@ -266,23 +271,23 @@ class DSLCallable:
        assert self.func is not None, "DSLCallable is already called"
        return self.func

+    @property
+    def __signature__(self):
+        return inspect.signature(self.__func__)
+
+    @property
+    def __name__(self):
+        return self.__func__.__name__

-    def get_arg_spec(self):
-        return inspect.getfullargspec(self.__func__)
-
-    def get_signature(self):
-        return inspect.signature(self.__func__)


class BaseDSL:
    gpu_module = None

    def __init__(
        self,
        *,
        name: str,
+        dsl_package_name: List[str],
        compiler_provider: Any,
        pass_sm_arch_name: str,
        device_compilation_only=False,
@@ -293,6 +298,7 @@ class BaseDSL:

        Parameters:
        - name (str): Name of DSL, used for environment variables and logging.
+        - package_name (str): Name of the package, used for the preprocessor.
        - compiler_provider (MLIR dialect): Provider for compiler.
        - pass_sm_arch_name (str): The keyword name of the SM.
        - device_compilation_only (bool) : Only device code, and call it via cuda driver
@@ -330,6 +336,9 @@
        self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"

        # set warning
+        if not self.envar.enable_optimization_warnings:
+            # By default, optimization warnings are disabled
+            warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
        if self.envar.warnings_as_errors:
            warnings.filterwarnings("error")
        if self.envar.warnings_ignore:
@@ -355,7 +364,7 @@ class BaseDSL:
        self.compile_options = CompileOptions()

        if preprocess:
-            self.preprocessor = DSLPreprocessor()
+            self.preprocessor = DSLPreprocessor(dsl_package_name)
            log().info(f"Initializing {name} DSL")
            log().debug(f"Logger initialized for {self.name}")

@@ -656,7 +665,7 @@ class BaseDSL:
        return ir_args, ir_kwargs

    @abstractmethod
-    def _generate_mlir_type_for_tensor_descriptor(self, tensor: TensorDescriptor):
+    def _generate_mlir_type_for_tensor_descriptor(self, tensor):
        """
        Generate MLIR type for the tensor descriptor.
        """
@@ -671,13 +680,6 @@
        """
        pass

-    @abstractmethod
-    def _get_module_globals(self):
-        """
-        Get the module's globals.
-        """
-        pass
-
    def _get_globals(self):
        """
        Combines global and local variables from the current context and the
@@ -690,43 +692,21 @@ class BaseDSL:
        AST preprocessor generates a new python code, so the resulting globals
        dictionary is used to execute the python code.
        """
-        all_globals = self._get_module_globals().copy()
+        all_globals = {}
        if self.frame:
            all_globals.update(self.frame.f_globals)
            all_globals.update(self.frame.f_locals)
        return all_globals

+    @abstractmethod
    def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
-        return isinstance(
-            maybe_tensor_descriptor, TensorDescriptor
-        ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
+        pass

+    @abstractmethod
    def _handle_tensor_descriptor(
        self, maybe_tensor, arg_name: str, need_gpu_memory: bool
-    ) -> TensorDescriptor:
-        if self._is_tensor_descriptor(maybe_tensor):
-            tensor = (
-                maybe_tensor
-                if isinstance(maybe_tensor, TensorDescriptor)
-                else TensorDescriptor(maybe_tensor)
-            )
-            if need_gpu_memory and not tensor.is_in_device:
-                log().info(
-                    "FAIL name=[%s] tensor=[%s] in_gpu=[%s]",
-                    arg_name,
-                    tensor,
-                    tensor.is_in_device,
-                )
-                raise DSLRuntimeError(
-                    f'Tensor "{arg_name}" is tensor "{tensor}" '
-                    "is not in the GPU memory. "
-                )
-
-            return tensor
-
-        raise DSLRuntimeError(
-            f"Argument {arg_name} could not be transformed into a TensorDescriptor."
-        )
+    ) -> Any:
+        pass

    def _validate_arg(self, arg, arg_index, arg_name, arg_spec):
        """
@@ -882,10 +862,11 @@ class BaseDSL:
        cluster: list = None
        grid: list = field(default_factory=lambda: [1, 1, 1])
        block: list = field(default_factory=lambda: [1, 1, 1])
-        smem: int = 0
+        smem: int = None
        async_deps: list = field(default_factory=list)
        has_cluster: bool = False
        min_blocks_per_mp: int = 0
+        auto_smem: bool = False

        def __post_init__(self):
            if len(self.grid) != 3:
@@ -893,6 +874,10 @@ class BaseDSL:
            if len(self.block) != 3:
                raise DSLRuntimeError(f"Expect 3d block!")

+            if self.smem is None:
+                self.smem = 0
+                self.auto_smem = True
+
            self.has_cluster = self.cluster is not None
            if self.cluster is None:
                self.cluster = [None, None, None]
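Changing the default from `0` to `None` lets `__post_init__` distinguish "caller never set `smem`" from "caller explicitly asked for 0 bytes", recording the former in the new `auto_smem` flag. A reduced sketch of the pattern (the dataclass name here is hypothetical; the hunk doesn't show it):

    from dataclasses import dataclass

    @dataclass
    class Cfg:  # reduced stand-in for the launch-config dataclass above
        smem: int = None      # None = "not specified", distinct from explicit 0
        auto_smem: bool = False

        def __post_init__(self):
            if self.smem is None:
                self.smem = 0
                self.auto_smem = True  # remember the value was defaulted

    assert Cfg().auto_smem is True         # unset: defaulted and flagged
    assert Cfg(smem=0).auto_smem is False  # explicit 0: honored as-is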
@@ -1116,8 +1101,6 @@ class BaseDSL:
            try:
                result = funcBody(*ir_args, **ir_kwargs)
-                func.ReturnOp([])
            except DSLAstPreprocessorError as pp_error:
                raise pp_error
            except NameError as name_error:
                raise DSLRuntimeError(
                    f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥",
@@ -1127,11 +1110,6 @@ class BaseDSL:
            except DSLRuntimeError as dsl_error:
                # Throw it's already a DSL error
                raise dsl_error
-            except Exception as general_e:
-                # Transform internal error to a DSL error
-                raise DSLRuntimeError(
-                    f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥"
-                ) from general_e
            return module, result

        # Build IR module
@@ -1328,10 +1306,8 @@ class BaseDSL:
            raise DSLRuntimeError("Function body is not set.")

        # Pass the actual function object to inspect.signature to get the signature.
-        if isinstance(self.funcBody, DSLCallable):
-            sig = self.funcBody.get_signature()
-        else:
-            sig = inspect.signature(self.funcBody)
+        sig = inspect.signature(self.funcBody)

        function_name = self.funcBody.__name__

        bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
@@ -1382,10 +1358,7 @@ class BaseDSL:
        # Check the number of arguments
        sig = self._check_arg_count(*args, **kwargs)

-        if isinstance(funcBody, DSLCallable):
-            args_spec = funcBody.get_arg_spec()
-        else:
-            args_spec = inspect.getfullargspec(funcBody)
+        args_spec = inspect.getfullargspec(funcBody)

        # Canonicalize the input arguments
        canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
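Both `isinstance(..., DSLCallable)` branches collapse because of the `__signature__` and `__name__` properties added to `DSLCallable` earlier in this diff: `inspect.signature()` (and, in current CPython, `inspect.getfullargspec()`) honor a `__signature__` attribute on the callable. A reduced sketch with a stand-in wrapper:

    import inspect

    class Wrapper:  # reduced stand-in for DSLCallable
        def __init__(self, func):
            self.func = func

        def __call__(self, *args, **kwargs):
            return self.func(*args, **kwargs)

        @property
        def __signature__(self):
            return inspect.signature(self.func)

        @property
        def __name__(self):
            return self.func.__name__

    def target(x, y=1):
        return x + y

    w = Wrapper(target)
    # No isinstance(funcBody, DSLCallable) branching needed at call sites:
    assert str(inspect.signature(w)) == "(x, y=1)"
    assert w.__name__ == "target"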
@@ -1447,7 +1420,7 @@ class BaseDSL:
        return cuda_helpers.stream_create()

    def _execute_cuda(
-        self, fname_cubin, kernel_name, grid_size, block_size, stream=None
+        self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None
    ):
        """
        Executes a specified CUDA kernel from a cubin file, handling module loading,
@@ -1471,7 +1444,7 @@ class BaseDSL:
            grid_size,
            block_size,
            stream,
-            smem_size=16000,
+            smem_size=smem_size,
            kernel_args=self.exe_args,
        )

@@ -1480,7 +1453,13 @@ class BaseDSL:
        cuda_helpers.stream_sync(stream)

    def _execute_by_cuda_driver(
-        self, kernel_generator, generate_cubin, grid_size, block_size, stream=None
+        self,
+        kernel_generator,
+        generate_cubin,
+        grid_size,
+        block_size,
+        smem_size,
+        stream=None,
    ):
        """
        This function builds IR and execute the module using cuda driver.
@@ -1511,10 +1490,9 @@ class BaseDSL:
        fname_cubin = generate_cubin(module, kernel_name)

        # Execute a cuda kernel from cubin
-        if block_size is None:
-            # The TileIR driver should set this automatically.
-            block_size = self.block_size
-        self._execute_cuda(fname_cubin, kernel_name, grid_size, block_size, stream)
+        self._execute_cuda(
+            fname_cubin, kernel_name, grid_size, block_size, smem_size, stream
+        )

        return ret

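With these hunks the dynamic shared-memory size is threaded caller → `_execute_by_cuda_driver` → `_execute_cuda` → `launch_kernel` instead of being pinned at 16000 bytes at the call site. A sketch of the plumbing shape with stub functions (not the real CUDA launch path):

    def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size,
                      kernel_args=None):
        # Stand-in launcher: smem_size is now a required positional parameter.
        return {"kernel": kernel, "smem": smem_size}

    def _execute_cuda(fname_cubin, kernel_name, grid_size, block_size,
                      smem_size, stream=None):
        # was: smem_size=16000 hard-coded at the launch_kernel call site
        return launch_kernel(kernel_name, grid_size, block_size, stream,
                             smem_size=smem_size)

    out = _execute_cuda("k.cubin", "gemm", (1, 1, 1), (128, 1, 1), 48 * 1024)
    assert out["smem"] == 49152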
@@ -1587,10 +1565,7 @@ class BaseDSL:
        kernelGenHelper = dkwargs.get("kernelGenHelper", None)

        kernel_name = funcBody.__name__
-        if isinstance(funcBody, DSLCallable):
-            args_spec = funcBody.get_arg_spec()
-        else:
-            args_spec = inspect.getfullargspec(funcBody)
+        args_spec = inspect.getfullargspec(funcBody)
        self.funcBody = funcBody

        # Give each kernel a unique name. (The same kernel may be

@@ -58,6 +58,11 @@ def get_int_env_var(var_name, default_value=0):
    return int(value) if value and value.isdigit() else default_value


+@lru_cache(maxsize=None)
+def has_env_var(var_name):
+    return os.getenv(var_name) is not None
+
+
def detect_gpu_arch(prefix):
    """
    Attempts to detect the machine's GPU architecture.
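One property of the memoized helper worth keeping in mind: `lru_cache(maxsize=None)` pins the first answer per variable name, so `os.environ` changes made later in the same process are not observed — presumably acceptable for startup-time configuration. A runnable sketch:

    import os
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def has_env_var(var_name):
        return os.getenv(var_name) is not None

    os.environ.pop("DEMO_FLAG", None)
    assert has_env_var("DEMO_FLAG") is False
    os.environ["DEMO_FLAG"] = "1"
    assert has_env_var("DEMO_FLAG") is False  # cached first result still returned
    assert has_env_var.cache_info().hits == 1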
@@ -256,6 +261,7 @@ class EnvironmentVarManager:
    - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
    - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
    - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
+    - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
    - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
    - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
    - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
@@ -267,7 +273,6 @@
        self.prefix = prefix  # change if needed

        # Printing options
-        self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
        self.print_after_preprocessor = get_bool_env_var(
            f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
        )
@@ -275,15 +280,29 @@
        self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
        # File options
        self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
+        # Logging options
+        self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
+        self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
-        # Other options
+        if (
+            has_env_var(f"{prefix}_LOG_LEVEL")
+            and not self.log_to_console
+            and not self.log_to_file
+        ):
+            log().warning(
+                f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!"
+            )
        self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)

+        # Other options
        self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
        self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
        self.warnings_as_errors = get_bool_env_var(
            f"{prefix}_WARNINGS_AS_ERRORS", False
        )
        self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
+        self.enable_optimization_warnings = get_bool_env_var(
+            f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False
+        )
        self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
        self.disable_file_caching = get_bool_env_var(
            f"{prefix}_DISABLE_FILE_CACHING", False

@@ -14,16 +14,12 @@ This module provides a runtime utility functions that are needed for
the DSL.
"""

from . import device_tensor
from . import dlpack_types
from . import cuda
from . import tensor_descriptor
from . import jit_arg_adapters

__all__ = [
    "device_tensor",
    "dlpack_types",
    "cuda",
    "tensor_descriptor",
    "jit_arg_adapters",
]

@@ -309,7 +309,7 @@ def get_kernel_function(module, kernel_name):
    return kernel


-def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
+def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
    """
    Launches the CUDA kernel.
    """

@@ -183,6 +183,13 @@ class TensorDescriptor:
        """
        return self.device_type == _dpack.DLDeviceType.kDLGPU

+    @staticmethod
+    def is_compatible(maybe_tensor_descriptor) -> bool:
+        """Check if the object is a TensorDescriptor or can be converted to one."""
+        return isinstance(
+            maybe_tensor_descriptor, TensorDescriptor
+        ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
+

def from_tensor(tensor) -> TensorDescriptor:
    """Create a TensorDescriptor from a tensor object."""
@@ -192,10 +199,3 @@ def from_tensor(tensor) -> TensorDescriptor:
def to_tensor(tensor_descriptor: TensorDescriptor):
    """Return tensor object from tensor descriptor."""
    return tensor_descriptor.tensor
-
-
-def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
-    """Check if the object is a TensorDescriptor."""
-    return isinstance(
-        maybe_tensor_descriptor, TensorDescriptor
-    ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
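Net effect of the last two hunks: the module-level `is_tensor_descriptor()` helper moves onto the class as `TensorDescriptor.is_compatible()` with the same body. A usage sketch with a reduced stand-in class; the DLPack check here is an assumption, since `can_transformed_to_dlpack` isn't shown:

    class TensorDescriptor:  # reduced stand-in mirroring the new static method
        @staticmethod
        def can_transformed_to_dlpack(obj):
            return hasattr(obj, "__dlpack__")  # assumption: DLPack protocol check

        @staticmethod
        def is_compatible(maybe_tensor_descriptor) -> bool:
            return isinstance(
                maybe_tensor_descriptor, TensorDescriptor
            ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)

    class FakeDLPackTensor:
        def __dlpack__(self):
            raise NotImplementedError

    assert TensorDescriptor.is_compatible(TensorDescriptor())
    assert TensorDescriptor.is_compatible(FakeDLPackTensor())
    assert not TensorDescriptor.is_compatible(object())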