v4.2 tag release. (#2638)

Authored by Junkai-Wu on 2025-09-16 00:21:53 +08:00; committed by GitHub.
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions


@@ -207,7 +207,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
if dst_width == src_width:
return a
-elif src_signed and not dst_signed:
+elif src_signed != False and not dst_signed:
# Signed -> Unsigned
if dst_width > src_width:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
@@ -216,7 +216,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
elif src_signed == dst_signed:
# Same signedness
if dst_width > src_width:
-if src_signed and src_width > 1:
+if src_signed != False and src_width > 1:
return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
else:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
@@ -479,7 +479,7 @@ class ArithValue(ir.Value):
if self.is_float:
q = arith.divf(self, other, loc=loc, ip=ip)
return math.floor(q, loc=loc, ip=ip)
-elif self.signed:
+elif self.signed != False:
return arith.floordivsi(self, other, loc=loc, ip=ip)
else:
return arith.divui(self, other, loc=loc, ip=ip)
@@ -489,7 +489,7 @@ class ArithValue(ir.Value):
def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.remf(self, other, loc=loc, ip=ip)
-elif self.signed:
+elif self.signed != False:
return arith.remsi(self, other, loc=loc, ip=ip)
else:
return arith.remui(self, other, loc=loc, ip=ip)
@@ -524,7 +524,7 @@ class ArithValue(ir.Value):
def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
-elif self.signed:
+elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)
@@ -534,7 +534,7 @@ class ArithValue(ir.Value):
def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
-elif self.signed:
+elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)
@@ -561,7 +561,7 @@ class ArithValue(ir.Value):
def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
-elif self.signed:
+elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)
@@ -571,7 +571,7 @@ class ArithValue(ir.Value):
def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
-elif self.signed:
+elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)
@@ -599,7 +599,7 @@ class ArithValue(ir.Value):
@_dispatch_to_rhs_r_op
@_binary_op
def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
-if self.signed:
+if self.signed != False:
return arith.shrsi(self, other, loc=loc, ip=ip)
else:
return arith.shrui(self, other, loc=loc, ip=ip)
@@ -633,7 +633,7 @@ class ArithValue(ir.Value):
return super().__hash__()
def __str__(self):
-return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)
+return "?"
def __repr__(self):
return self.__str__()
@@ -657,7 +657,7 @@ def _min(lhs, rhs, *, loc=None, ip=None):
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
if arith._is_integer_like_type(lhs.type):
-if lhs.signed:
+if lhs.signed != False:
return arith.minsi(lhs, rhs, loc=loc, ip=ip)
else:
return arith.minui(lhs, rhs, loc=loc, ip=ip)
@@ -683,7 +683,7 @@ def _max(lhs, rhs, *, loc=None, ip=None):
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
if arith._is_integer_like_type(lhs.type):
-if lhs.signed:
+if lhs.signed != False:
return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
else:
return arith.maxui(lhs, rhs, loc=loc, ip=ip)
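A recurring change in this file replaces truthiness tests like `if self.signed:` with `if self.signed != False:`. The two predicates differ only when `signed` is not a plain bool. The sketch below is plain Python, no MLIR; the three-state `signed` convention (True = signed, False = unsigned, None = signless/unspecified) is an assumption for illustration, not confirmed by the diff itself:

# Hypothetical three-state signedness flag -- an illustration only.
def pick_op_truthy(signed):
    # `None` is falsy, so a signless value falls into the unsigned branch.
    return "signed_op" if signed else "unsigned_op"

def pick_op_explicit(signed):
    # `None != False` is True, so a signless value takes the signed branch.
    return "signed_op" if signed != False else "unsigned_op"

for flag in (True, False, None):
    print(flag, pick_op_truthy(flag), pick_op_explicit(flag))
# True  signed_op   signed_op
# False unsigned_op unsigned_op
# None  unsigned_op signed_op   <- the only behavioral difference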


@@ -17,12 +17,16 @@ The preprocessor read through python's ast and changes the input code.
from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings
+import inspect
+from types import BuiltinFunctionType
+from functools import lru_cache
from .utils.logger import log
from .common import *
+from ._mlir_helpers.arith import ArithValue
class Executor:
"""
The Executor class handles dynamic and compile-time (constexpr) execution
@@ -45,9 +49,11 @@ class Executor:
self._compare_executor = None
self._any_executor = None
self._all_executor = None
+self._builtin_redirector = None
def set_functions(
self,
+*,
is_dynamic_expression: Callable,
loop_execute_range_dynamic: Callable,
if_dynamic: Callable,
@@ -55,6 +61,7 @@
compare_executor: Callable,
any_executor: Callable = None,
all_executor: Callable = None,
+builtin_redirector: Callable = None,
):
self._is_dynamic_expression = is_dynamic_expression
self._loop_execute_range_dynamic = loop_execute_range_dynamic
@@ -63,6 +70,7 @@
self._compare_executor = compare_executor
self._any_executor = any_executor
self._all_executor = all_executor
+self._builtin_redirector = builtin_redirector
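The bare `*,` added to `set_functions` makes every callback keyword-only, so new parameters such as `builtin_redirector` can be appended without silently shifting positional arguments at call sites. A small sketch of the pattern (names mirror the diff; the stub callables are placeholders):

class Executor:
    def __init__(self):
        self._builtin_redirector = None

    def set_functions(self, *, is_dynamic_expression, builtin_redirector=None):
        # Callers must name each argument; positional calls raise TypeError.
        self._is_dynamic_expression = is_dynamic_expression
        self._builtin_redirector = builtin_redirector

ex = Executor()
ex.set_functions(is_dynamic_expression=lambda v: False)   # OK
# ex.set_functions(lambda v: False)  # TypeError: positional args rejected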
@staticmethod
def convert_to_list(x):
@@ -90,42 +98,18 @@
return res[0]
return res
-@staticmethod
-def for_constexpr(
-func: Callable,
-start: int,
-stop: int,
-step: int,
-used_args: list,
-iter_args: list,
-):
-log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
-loop_results = iter_args
-log().debug("iter_args [%s]", iter_args)
-for i in range(start, stop, step):
-log().debug("i [%s] iter_args [%s]", i, iter_args)
-loop_results = func(i, *used_args, *loop_results)
-log().debug("loop_results [%s]", loop_results)
-if loop_results is None:
-loop_results = []
-if not isinstance(loop_results, list):
-loop_results = [loop_results]
-log().debug("done loop_results [%s]", loop_results)
-return Executor.converge_ret_val(loop_results)
def for_execute(
self,
func,
start,
stop,
step,
-used_args=[],
-iter_args=[],
-iter_arg_names=[],
+write_args=[],
+full_write_args_count=0,
+write_args_names=[],
unroll=-1,
unroll_full=False,
-pipelining=None,
+prefetch_stages=None,
):
assert (
self._loop_execute_range_dynamic
@@ -137,12 +121,12 @@
start,
stop,
step,
-used_args,
-iter_args,
-iter_arg_names,
+write_args,
+full_write_args_count,
+write_args_names,
unroll,
unroll_full,
-pipelining,
+prefetch_stages,
)
def if_execute(
@@ -150,15 +134,20 @@
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
-used_args=[],
-yield_args=[],
-yield_arg_names=[],
+write_args=[],
+full_write_args_count=0,
+write_args_names=[],
):
assert self._if_dynamic, "Functions must be set before execution."
# MLIR generation
return self._if_dynamic(
-pred, then_block, else_block, used_args, yield_args, yield_arg_names
+pred,
+then_block,
+else_block,
+write_args,
+full_write_args_count,
+write_args_names,
)
def while_execute(
@@ -166,9 +155,9 @@
pred,
while_before_block: Callable,
while_after_block: Callable,
-used_args=[],
-yield_args=[],
-yield_arg_names=[],
+write_args=[],
+full_write_args_count=0,
+write_args_names=[],
):
assert self._while_dynamic, "Functions must be set before execution."
@@ -176,9 +165,9 @@
return self._while_dynamic(
while_before_block,
while_after_block,
-used_args,
-yield_args,
-yield_arg_names,
+write_args,
+full_write_args_count,
+write_args_names,
)
@@ -194,23 +183,24 @@ def loop_selector(
stop,
step,
*,
-used_args=[],
-iter_args=[],
-iter_arg_names=[],
+write_args=[],
+full_write_args_count=0,
+write_args_names=[],
unroll=-1,
unroll_full=False,
-pipelining=None,
+prefetch_stages=None,
):
log().debug(
-"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
+"start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
start,
stop,
step,
-used_args,
-iter_args,
+write_args,
+full_write_args_count,
+write_args_names,
unroll,
unroll_full,
-pipelining,
+prefetch_stages,
)
from .typing import Integer, Numeric
@@ -230,19 +220,19 @@ def loop_selector(
start,
stop,
step,
-used_args,
-iter_args,
-iter_arg_names,
+write_args,
+full_write_args_count,
+write_args_names,
unroll,
unroll_full,
-pipelining,
+prefetch_stages,
)
return ir_loop
-def if_selector(pred, used_args=[], yield_args=[]):
-log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
+def if_selector(pred, write_args=[]):
+log().debug("pred [%s] write_args [%s]", pred, write_args)
# Handle Numeric types here?
from .typing import Numeric
@@ -251,14 +241,14 @@ def if_selector(pred, used_args=[], yield_args=[]):
pred = pred.value
def ir_loop(func):
-return func(pred, *used_args, *yield_args)
+return func(pred, *write_args)
return ir_loop
-def while_selector(pred, used_args=[], yield_args=[]):
+def while_selector(pred, write_args=[]):
def ir_while_loop(func):
-return func(pred, *used_args, *yield_args)
+return func(pred, *write_args)
return ir_while_loop
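For constexpr predicates these selectors build no IR at all: each returns a closure that immediately applies the block function to the predicate and the unpacked `write_args`. A stripped-down, self-contained sketch of that shape:

def if_selector(pred, write_args=[]):
    def ir_loop(func):
        # Constexpr path: the block simply runs in plain Python.
        return func(pred, *write_args)
    return ir_loop

def then_or_else(pred, x, y):
    return x + y if pred else x - y

run = if_selector(True, write_args=[10, 3])
print(run(then_or_else))  # 13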
@@ -267,17 +257,17 @@ def while_executor(
pred,
while_before_block: Callable,
while_after_block: Callable,
-used_args=[],
-yield_args=[],
-yield_arg_names=[],
+write_args=[],
+full_write_args_count=0,
+write_args_names=[],
):
return executor.while_execute(
pred,
while_before_block,
while_after_block,
-used_args,
-yield_args,
-yield_arg_names,
+write_args,
+full_write_args_count,
+write_args_names,
)
@@ -285,12 +275,17 @@ def if_executor(
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
-used_args=[],
-yield_args=[],
-yield_arg_names=[],
+write_args=[],
+full_write_args_count=0,
+write_args_names=[],
):
return executor.if_execute(
-pred, then_block, else_block, used_args, yield_args, yield_arg_names
+pred,
+then_block,
+else_block,
+write_args,
+full_write_args_count,
+write_args_names,
)
@@ -313,14 +308,17 @@ class range:
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
- unroll_full: Whether to fully unroll the loop
- pipelining: Compiler generated pipeline configuration
+- prefetch_stages: Number of prefetch stages to generate
"""
@overload
-def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
+def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None):
pass
@overload
-def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
+def __new__(
+cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
+):
pass
def __new__(cls, *args, **kwargs):
@@ -340,6 +338,7 @@ def range_dynamic(*args, **kwargs):
def range_constexpr(*args):
raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
# =============================================================================
# If expressions
# =============================================================================
@@ -405,7 +404,7 @@ def assert_executor(test, msg=None):
else:
raise DSLRuntimeError(
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
-suggestion = "Please replace with runtime assert."
+suggestion="Please replace with runtime assert.",
)
@@ -413,10 +412,11 @@ def bool_cast(value):
if executor._is_dynamic_expression(value):
raise DSLRuntimeError(
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
-suggestion = "Please explicitly convert to boolean with expressions like comparision."
+suggestion="Please explicitly convert to boolean with expressions like comparision.",
)
return bool(value)
def compare_executor(left, comparators, ops):
"""
Executes comparison operations with a left operand and a list of comparators.
@@ -470,6 +470,19 @@ def all_executor(iterable):
# =============================================================================
# Control flow checks
# =============================================================================
+class DSLOptimizationWarning(Warning):
+"""
+This warning is used to warn the user about the optimization related issues in DSL.
+"""
+def __init__(self, message):
+self.message = message
+super().__init__()
+def __str__(self):
+return self.message
def range_value_check(*args):
"""
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
@@ -495,7 +508,7 @@ def range_value_check(*args):
if range_length >= 64:
warnings.warn(
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
-category=UserWarning,
+category=DSLOptimizationWarning,
stacklevel=2,
)
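Routing these messages through the dedicated `DSLOptimizationWarning` category lets users silence or escalate just the optimization hints while leaving other `UserWarning`s intact. A self-contained sketch of how such a category interacts with `warnings.filterwarnings`:

import warnings

class DSLOptimizationWarning(Warning):
    pass

# Suppress only the optimization hints (mirrors the default set in base_dsl).
warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
warnings.warn("slow loop", category=DSLOptimizationWarning)  # silenced
warnings.warn("something else")                              # still shown

# Or escalate them to hard errors during development:
warnings.filterwarnings("error", category=DSLOptimizationWarning)
try:
    warnings.warn("slow loop", category=DSLOptimizationWarning)
except DSLOptimizationWarning as e:
    print("raised:", e)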
@@ -519,7 +532,50 @@ def range_perf_warning(filename, lineno, *args):
"This loop is no longer unrolled and may cause performance regression. "
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
),
-category=UserWarning,
+category=DSLOptimizationWarning,
filename=filename,
lineno=lineno,
)
+@lru_cache(maxsize=1)
+def _get_self_module():
+"""
+This function is used to get the owning module of this function.
+"""
+return inspect.getmodule(_get_self_module)
+def cf_symbol_check(symbol):
+"""
+Check if the symbol is control flow symbol from current module.
+"""
+failed = False
+name = symbol.__name__
+self_module = _get_self_module()
+if inspect.ismodule(symbol):
+name = "range"
+if not self_module.__name__.startswith(symbol.__name__):
+failed = True
+else:
+owning_module = inspect.getmodule(symbol)
+if owning_module != self_module:
+failed = True
+if failed:
+raise DSLRuntimeError(
+f"Incorrect {symbol.__name__} is used.",
+suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
+)
+def redirect_builtin_function(fcn):
+"""
+This function is used to redirect built-in function call
+to the function defined in DSL package.
+"""
+# Only redirect if it's a built-in
+if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
+return executor._builtin_redirector(fcn)
+return fcn
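`redirect_builtin_function` only intercepts objects of `BuiltinFunctionType`, i.e. C-implemented builtins like `min`; user-defined Python functions pass through untouched. A minimal sketch with a hypothetical redirector table (the `dsl_min` replacement is invented for illustration):

from types import BuiltinFunctionType

def make_redirector(table):
    def redirect(fcn):
        # Only C-level builtins are candidates for redirection.
        if isinstance(fcn, BuiltinFunctionType):
            return table.get(fcn, fcn)
        return fcn
    return redirect

def dsl_min(*args):  # hypothetical DSL-aware replacement
    return ("dsl_min", min(*args))

redirect = make_redirector({min: dsl_min})
print(redirect(min)(3, 5))      # ('dsl_min', 3) -- the builtin was swapped
print(redirect(dsl_min)(3, 5))  # plain function passes through unchanged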

File diff suppressed because it is too large.


@@ -139,8 +139,7 @@ def dump_cache_to_path(
dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
):
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
-if not os.path.exists(path):
-os.makedirs(path)
+os.makedirs(path, exist_ok=True)
original_path = os.getcwd()
try:
os.chdir(path)
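`os.makedirs(path, exist_ok=True)` replaces the check-then-create pair, which had a window where another process could create `path` between the `os.path.exists` test and the `os.makedirs` call and trigger `FileExistsError`. A tiny sketch (the cache path below is made up):

import os, tempfile

path = os.path.join(tempfile.gettempdir(), "jit_cache", "ir")
# Safe whether or not the directory already exists, including when
# several workers race to create it at the same time.
os.makedirs(path, exist_ok=True)
os.makedirs(path, exist_ok=True)  # second call is a no-op, no exception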


@@ -205,6 +205,8 @@ class CompileOptions:
self._parser.add_argument(
"--enable-device-assertions", action="store_true", default=False
)
+self._parser.add_argument("--link-libraries", type=str, default="")
try:
self._options = self._parser.parse_args(options.split())
except SystemExit as e:
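CompileOptions parses a single option string by splitting it on whitespace, so a new flag like `--link-libraries` needs only one `add_argument` line. A reduced sketch of that parsing flow (parser setup shortened; only the two flags shown come from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-device-assertions", action="store_true", default=False)
parser.add_argument("--link-libraries", type=str, default="")

options = "--enable-device-assertions --link-libraries libfoo.so"
args = parser.parse_args(options.split())
print(args.link_libraries)            # libfoo.so
print(args.enable_device_assertions)  # True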


@@ -32,13 +32,14 @@ import hashlib
from functools import lru_cache, wraps
from collections import namedtuple
from abc import ABC, abstractmethod
-from typing import Any, Union, Tuple, get_origin, get_args
-from types import FunctionType
+from typing import Any, Union, Tuple, get_origin, get_args, List
+from types import FunctionType, SimpleNamespace
import warnings
from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions
+from .ast_helpers import DSLOptimizationWarning
# =============================================================================
# CUDA Python
@@ -56,7 +57,7 @@ from .utils.timer import timer
from .utils.logger import setup_log, log
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
-from .runtime.tensor_descriptor import TensorDescriptor
from .ast_preprocessor import DSLPreprocessor
from .common import *
from .typing import (
@@ -73,12 +74,6 @@ from .._mlir import runtime as rt
from .._mlir.extras import types as T
from .._mlir.dialects import arith, math, func
-# =============================================================================
-# cutlass.dlpack_runtime
-# =============================================================================
-from .runtime.dlpack_runtime import dlpack_to_tensor_desc, mark_layout_dynamic
# =============================================================================
# Global Variables
# =============================================================================
@@ -177,6 +172,7 @@ def is_dynamic_expression(value):
return True
return False
def extract_mlir_values(obj):
"""
Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values
@@ -186,6 +182,10 @@ def extract_mlir_values(obj):
res = obj.__extract_mlir_values__()
elif isinstance(obj, (tuple, list)):
res = sum((extract_mlir_values(x) for x in obj), [])
+elif isinstance(obj, SimpleNamespace):
+res = []
+for k, v in obj.__dict__.items():
+res.extend(extract_mlir_values(v))
# Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values
elif isinstance(obj, set):
raise DSLRuntimeError(
@@ -215,6 +215,13 @@ def new_from_mlir_values(obj, values):
values = values[n_items:]
obj_ty = type(obj)
return obj_ty(res)
+elif isinstance(obj, SimpleNamespace):
+res = SimpleNamespace()
+for k, v in obj.__dict__.items():
+n_items = len(get_mlir_types(v))
+res.__dict__[k] = new_from_mlir_values(v, values[:n_items])
+values = values[n_items:]
+return res
elif isinstance(obj, set):
raise DSLRuntimeError(
"Sets are not supported in new_from_mlir_values to ensure order preservation",
@@ -249,8 +256,6 @@ class DSLCallable:
Methods:
__call__(*args, **kwargs): Calls the wrapped function and clears it.
get_arg_spec(): Returns the argument specification of the function.
-get_signature(): Returns the signature of the function.
"""
def __init__(self, func):
@@ -266,23 +271,23 @@ class DSLCallable:
assert self.func is not None, "DSLCallable is already called"
return self.func
+@property
+def __signature__(self):
+return inspect.signature(self.__func__)
+@property
+def __name__(self):
+return self.__func__.__name__
def get_arg_spec(self):
return inspect.getfullargspec(self.__func__)
-def get_signature(self):
-return inspect.signature(self.__func__)
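Exposing `__signature__` and `__name__` as properties lets plain `inspect.signature(obj)` and `obj.__name__` work directly on the wrapper, which is what allows the later hunks to drop their `isinstance(..., DSLCallable)` special cases. A minimal sketch of the mechanism (a simplified wrapper, not the real class):

import inspect

class DSLCallable:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    @property
    def __signature__(self):
        # inspect.signature() consults this attribute before any fallback.
        return inspect.signature(self.func)

    @property
    def __name__(self):
        return self.func.__name__

def kernel(a, b, *, scale=1.0):
    return (a + b) * scale

wrapped = DSLCallable(kernel)
print(inspect.signature(wrapped))  # (a, b, *, scale=1.0)
print(wrapped.__name__)            # kernel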
class BaseDSL:
gpu_module = None
def __init__(
self,
*,
name: str,
+dsl_package_name: List[str],
compiler_provider: Any,
pass_sm_arch_name: str,
device_compilation_only=False,
@@ -293,6 +298,7 @@
Parameters:
- name (str): Name of DSL, used for environment variables and logging.
+- package_name (str): Name of the package, used for the preprocessor.
- compiler_provider (MLIR dialect): Provider for compiler.
- pass_sm_arch_name (str): The keyword name of the SM.
- device_compilation_only (bool) : Only device code, and call it via cuda driver
@@ -330,6 +336,9 @@
self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"
# set warning
+if not self.envar.enable_optimization_warnings:
+# By default, optimization warnings are disabled
+warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
if self.envar.warnings_as_errors:
warnings.filterwarnings("error")
if self.envar.warnings_ignore:
@@ -355,7 +364,7 @@
self.compile_options = CompileOptions()
if preprocess:
-self.preprocessor = DSLPreprocessor()
+self.preprocessor = DSLPreprocessor(dsl_package_name)
log().info(f"Initializing {name} DSL")
log().debug(f"Logger initialized for {self.name}")
@@ -656,7 +665,7 @@
return ir_args, ir_kwargs
@abstractmethod
-def _generate_mlir_type_for_tensor_descriptor(self, tensor: TensorDescriptor):
+def _generate_mlir_type_for_tensor_descriptor(self, tensor):
"""
Generate MLIR type for the tensor descriptor.
"""
@@ -671,13 +680,6 @@
"""
pass
-@abstractmethod
-def _get_module_globals(self):
-"""
-Get the module's globals.
-"""
-pass
def _get_globals(self):
"""
Combines global and local variables from the current context and the
@@ -690,43 +692,21 @@
AST preprocessor generates a new python code, so the resulting globals
dictionary is used to execute the python code.
"""
-all_globals = self._get_module_globals().copy()
+all_globals = {}
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
return all_globals
+@abstractmethod
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
-return isinstance(
-maybe_tensor_descriptor, TensorDescriptor
-) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
+pass
@abstractmethod
def _handle_tensor_descriptor(
self, maybe_tensor, arg_name: str, need_gpu_memory: bool
-) -> TensorDescriptor:
-if self._is_tensor_descriptor(maybe_tensor):
-tensor = (
-maybe_tensor
-if isinstance(maybe_tensor, TensorDescriptor)
-else TensorDescriptor(maybe_tensor)
-)
-if need_gpu_memory and not tensor.is_in_device:
-log().info(
-"FAIL name=[%s] tensor=[%s] in_gpu=[%s]",
-arg_name,
-tensor,
-tensor.is_in_device,
-)
-raise DSLRuntimeError(
-f'Tensor "{arg_name}" is tensor "{tensor}" '
-"is not in the GPU memory. "
-)
-return tensor
-raise DSLRuntimeError(
-f"Argument {arg_name} could not be transformed into a TensorDescriptor."
-)
+) -> Any:
+pass
def _validate_arg(self, arg, arg_index, arg_name, arg_spec):
"""
@@ -882,10 +862,11 @@
cluster: list = None
grid: list = field(default_factory=lambda: [1, 1, 1])
block: list = field(default_factory=lambda: [1, 1, 1])
-smem: int = 0
+smem: int = None
async_deps: list = field(default_factory=list)
has_cluster: bool = False
min_blocks_per_mp: int = 0
+auto_smem: bool = False
def __post_init__(self):
if len(self.grid) != 3:
@@ -893,6 +874,10 @@
if len(self.block) != 3:
raise DSLRuntimeError(f"Expect 3d block!")
+if self.smem is None:
+self.smem = 0
+self.auto_smem = True
self.has_cluster = self.cluster is not None
if self.cluster is None:
self.cluster = [None, None, None]
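Switching `smem` from a default of `0` to a `None` sentinel lets `__post_init__` distinguish "caller asked for 0 bytes" from "caller said nothing", recording the latter in `auto_smem`. A reduced sketch of just that pattern:

from dataclasses import dataclass

@dataclass
class LaunchConfig:
    smem: int = None       # None means "not specified by the caller"
    auto_smem: bool = False

    def __post_init__(self):
        if self.smem is None:
            self.smem = 0          # keep a usable numeric default...
            self.auto_smem = True  # ...but remember it was auto-chosen

print(LaunchConfig())        # LaunchConfig(smem=0, auto_smem=True)
print(LaunchConfig(smem=0))  # LaunchConfig(smem=0, auto_smem=False)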
@@ -1116,8 +1101,6 @@
try:
result = funcBody(*ir_args, **ir_kwargs)
func.ReturnOp([])
-except DSLAstPreprocessorError as pp_error:
-raise pp_error
except NameError as name_error:
raise DSLRuntimeError(
f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥",
@@ -1127,11 +1110,6 @@
except DSLRuntimeError as dsl_error:
# Throw it's already a DSL error
raise dsl_error
-except Exception as general_e:
-# Transform internal error to a DSL error
-raise DSLRuntimeError(
-f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥"
-) from general_e
return module, result
# Build IR module
@@ -1328,10 +1306,8 @@
raise DSLRuntimeError("Function body is not set.")
# Pass the actual function object to inspect.signature to get the signature.
-if isinstance(self.funcBody, DSLCallable):
-sig = self.funcBody.get_signature()
-else:
-sig = inspect.signature(self.funcBody)
+sig = inspect.signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
@@ -1382,10 +1358,7 @@
# Check the number of arguments
sig = self._check_arg_count(*args, **kwargs)
-if isinstance(funcBody, DSLCallable):
-args_spec = funcBody.get_arg_spec()
-else:
-args_spec = inspect.getfullargspec(funcBody)
+args_spec = inspect.getfullargspec(funcBody)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
@@ -1447,7 +1420,7 @@
return cuda_helpers.stream_create()
def _execute_cuda(
-self, fname_cubin, kernel_name, grid_size, block_size, stream=None
+self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None
):
"""
Executes a specified CUDA kernel from a cubin file, handling module loading,
@@ -1471,7 +1444,7 @@
grid_size,
block_size,
stream,
-smem_size=16000,
+smem_size=smem_size,
kernel_args=self.exe_args,
)
@@ -1480,7 +1453,13 @@
cuda_helpers.stream_sync(stream)
def _execute_by_cuda_driver(
-self, kernel_generator, generate_cubin, grid_size, block_size, stream=None
+self,
+kernel_generator,
+generate_cubin,
+grid_size,
+block_size,
+smem_size,
+stream=None,
):
"""
This function builds IR and execute the module using cuda driver.
@@ -1511,10 +1490,9 @@
fname_cubin = generate_cubin(module, kernel_name)
# Execute a cuda kernel from cubin
if block_size is None:
+# The TileIR driver should set this automatically.
block_size = self.block_size
-self._execute_cuda(fname_cubin, kernel_name, grid_size, block_size, stream)
+self._execute_cuda(
+fname_cubin, kernel_name, grid_size, block_size, smem_size, stream
+)
return ret
@@ -1587,10 +1565,7 @@
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
kernel_name = funcBody.__name__
-if isinstance(funcBody, DSLCallable):
-args_spec = funcBody.get_arg_spec()
-else:
-args_spec = inspect.getfullargspec(funcBody)
+args_spec = inspect.getfullargspec(funcBody)
self.funcBody = funcBody
# Give each kernel a unique name. (The same kernel may be


@@ -58,6 +58,11 @@ def get_int_env_var(var_name, default_value=0):
return int(value) if value and value.isdigit() else default_value
+@lru_cache(maxsize=None)
+def has_env_var(var_name):
+return os.getenv(var_name) is not None
def detect_gpu_arch(prefix):
"""
Attempts to detect the machine's GPU architecture.
@@ -256,6 +261,7 @@ class EnvironmentVarManager:
- [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
- [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
- [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
+- [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
@@ -267,7 +273,6 @@
self.prefix = prefix # change if needed
# Printing options
-self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
self.print_after_preprocessor = get_bool_env_var(
f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
)
@@ -275,15 +280,29 @@
self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
# File options
self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
+# Logging options
+self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
+self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
# Other options
+if (
+has_env_var(f"{prefix}_LOG_LEVEL")
+and not self.log_to_console
+and not self.log_to_file
+):
+log().warning(
+f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!"
+)
self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)
# Other options
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
self.warnings_as_errors = get_bool_env_var(
f"{prefix}_WARNINGS_AS_ERRORS", False
)
self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
+self.enable_optimization_warnings = get_bool_env_var(
+f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False
+)
self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
self.disable_file_caching = get_bool_env_var(
f"{prefix}_DISABLE_FILE_CACHING", False


@@ -14,16 +14,12 @@ This module provides a runtime utility functions that are needed for
the DSL.
"""
from . import device_tensor
-from . import dlpack_types
from . import cuda
-from . import tensor_descriptor
from . import jit_arg_adapters
__all__ = [
"device_tensor",
-"dlpack_types",
"cuda",
-"tensor_descriptor",
"jit_arg_adapters",
]


@@ -309,7 +309,7 @@ def get_kernel_function(module, kernel_name):
return kernel
-def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
+def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
"""
Launches the CUDA kernel.
"""


@@ -183,6 +183,13 @@ class TensorDescriptor:
"""
return self.device_type == _dpack.DLDeviceType.kDLGPU
+@staticmethod
+def is_compatible(maybe_tensor_descriptor) -> bool:
+"""Check if the object is a TensorDescriptor or can be converted to one."""
+return isinstance(
+maybe_tensor_descriptor, TensorDescriptor
+) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
def from_tensor(tensor) -> TensorDescriptor:
"""Create a TensorDescriptor from a tensor object."""
@@ -192,10 +199,3 @@ def from_tensor(tensor) -> TensorDescriptor:
def to_tensor(tensor_descriptor: TensorDescriptor):
"""Return tensor object from tensor descriptor."""
return tensor_descriptor.tensor
-def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
-"""Check if the object is a TensorDescriptor."""
-return isinstance(
-maybe_tensor_descriptor, TensorDescriptor
-) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
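Moving the module-level `is_tensor_descriptor` helper onto the class as the `is_compatible` staticmethod keeps the predicate next to the conversion logic it guards. A toy sketch of the same refactor shape (the `DummyTensor` and the `__dlpack__` probe are illustrative only, not the library's real capability check):

class TensorDescriptor:
    def __init__(self, tensor):
        self.tensor = tensor

    @staticmethod
    def can_transformed_to_dlpack(obj) -> bool:
        # Stand-in for the real DLPack capability probe.
        return hasattr(obj, "__dlpack__")

    @staticmethod
    def is_compatible(maybe_tensor_descriptor) -> bool:
        return isinstance(
            maybe_tensor_descriptor, TensorDescriptor
        ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)

class DummyTensor:
    def __dlpack__(self):  # minimal DLPack-style hook
        raise NotImplementedError

print(TensorDescriptor.is_compatible(TensorDescriptor(object())))  # True
print(TensorDescriptor.is_compatible(DummyTensor()))               # True
print(TensorDescriptor.is_compatible(42))                          # False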