v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -38,6 +38,7 @@ import warnings
from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions
# =============================================================================
# CUDA Python
@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values):
return obj
class DSLCallable:
    """
    Single-use wrapper around a callable used within the DSL.

    DSLCallable wraps a function and exposes introspection helpers (argument
    specification and signature). The wrapped function can be invoked
    successfully only once: after a call returns, the reference to the
    function is cleared, and any further access to it (calling, reading
    ``__name__``, or introspecting) raises ``AssertionError``.

    If the wrapped function raises, the reference is *not* cleared, so the
    exception propagates and the wrapper remains usable.

    Attributes:
        func (callable): The wrapped function, or ``None`` once it has been
            called successfully.

    Methods:
        __call__(*args, **kwargs): Calls the wrapped function and clears it.
        get_arg_spec(): Returns the argument specification of the function.
        get_signature(): Returns the signature of the function.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function once and drop the reference to it."""
        ret = self.__func__(*args, **kwargs)
        # Only reached when the call succeeded; a raising call keeps `func`.
        self.func = None
        return ret

    @property
    def __func__(self):
        """Return the wrapped function.

        Raises:
            AssertionError: If the wrapped function has already been called.
        """
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would silently disable this guard.
        if self.func is None:
            raise AssertionError("DSLCallable is already called")
        return self.func

    @property
    def __name__(self):
        """Name of the wrapped function."""
        return self.__func__.__name__

    def get_arg_spec(self):
        """Return ``inspect.getfullargspec`` of the wrapped function."""
        return inspect.getfullargspec(self.__func__)

    def get_signature(self):
        """Return ``inspect.signature`` of the wrapped function."""
        return inspect.signature(self.__func__)
class BaseDSL:
gpu_module = None
@ -306,6 +351,8 @@ class BaseDSL:
self.kernel_symbols = []
# used to generate unique name for gpu.launch
self.launch_inner_count = 0
# initialize default compile options
self.compile_options = CompileOptions()
if preprocess:
self.preprocessor = DSLPreprocessor()
@ -392,26 +439,24 @@ class BaseDSL:
if hasattr(func, "_transformed_ast"):
# If the function ptr is already materialized, use the existing one
func._dsl_object.frame = func._decorator_frame
if func._transformed_ast is None:
func._transformed_ast = func._dsl_object.run_preprocessor(func)
if func._transformed_ast is None:
del func._decorator_frame
del func._transformed_ast
func._dsl_object.frame = None
return func
fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast)
fcn_ptr = func._dsl_object.get_function_ptr(func)
# If the function is decorated, de-decorate it
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
return fcn_ptr
func._dsl_object.frame = None
return DSLCallable(fcn_ptr)
return func
def jit_runner(self, frame, executor, *dargs, **dkwargs):
def jit_runner(self, executor, frame, *dargs, **dkwargs):
"""
Decorator to mark a function for JIT compilation.
"""
# Set the frame so that it can be used by the AST preprocessor
self.frame = frame
log().info("jit_runner")
def jit_runner_decorator(func):
@ -444,7 +489,7 @@ class BaseDSL:
frame = inspect.currentframe().f_back
# Instantiate the DSL Class
main_dsl = cls._get_dsl()
return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs)
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
@classmethod
def kernel(cls, *dargs, **dkwargs):
@ -454,7 +499,7 @@ class BaseDSL:
frame = inspect.currentframe().f_back
# Instantiate the DSL Class
main_dsl = cls._get_dsl()
return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs)
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
@abstractmethod
def _kernel_helper(self, func, *args, **kwargs):
@ -627,6 +672,12 @@ class BaseDSL:
pass
@abstractmethod
def _get_module_globals(self):
"""
Get the module's globals.
"""
pass
def _get_globals(self):
"""
Combines global and local variables from the current context and the
@ -639,7 +690,11 @@ class BaseDSL:
The AST preprocessor generates new Python code, so the resulting globals
dictionary is used to execute that code.
"""
pass
all_globals = self._get_module_globals().copy()
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
return all_globals
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
return isinstance(
@ -881,20 +936,15 @@ class BaseDSL:
Get python location information and generate MLIR location
"""
frame = self.frame
if frame is None:
print("Frame is None")
if self.frame is None:
log().debug("Frame is None")
return None
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
file_loc = ir.Location.file(
self.frame.f_code.co_filename, self.frame.f_lineno, 0
)
def print_all_frames():
for i, frame in enumerate(inspect.stack()):
print(
f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}"
)
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
return loc
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
@ -992,6 +1042,8 @@ class BaseDSL:
for attr, value in self.envar.__dict__.items():
if value is not None:
s.write(str(value).encode())
# Add compile options to the hash
s.write(self.compile_options.to_str().encode())
module_hash = self.get_version().copy()
module_hash.update(s.getvalue())
module_hash = module_hash.hexdigest()
@ -1145,6 +1197,8 @@ class BaseDSL:
self.launch_inner_count = 0
# reset num_kernels to 0 for next compilation.
self.num_kernels = 0
# reset the compile options after the compilation is done.
self.compile_options = CompileOptions()
def generate_mlir(
self,
@ -1226,9 +1280,11 @@ class BaseDSL:
return transformed_ast
return None
def get_function_ptr(self, original_function, transformed_ast):
def get_function_ptr(self, original_function):
file_name = inspect.getsourcefile(original_function)
code_object = compile(transformed_ast, filename=file_name, mode="exec")
code_object = compile(
original_function._transformed_ast, filename=file_name, mode="exec"
)
return self.preprocessor.exec(
original_function.__name__,
original_function,
@ -1236,10 +1292,6 @@ class BaseDSL:
self._get_globals(),
)
@lru_cache(maxsize=None)
def _get_function_signature(self, func):
return inspect.signature(func)
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
"""
Binds provided arguments to a function's signature and applies default values.
@ -1260,12 +1312,11 @@ class BaseDSL:
)
return bound_args
def _canonicalize_args(self, *args, **kwargs):
def _canonicalize_args(self, sig, *args, **kwargs):
"""
Canonicalize the input arguments so that returned args only contain
positional arguments and kwargs only contain keyword arguments.
"""
sig = self._get_function_signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
canonicalized_args = bound_args.args
@ -1276,8 +1327,11 @@ class BaseDSL:
if not self.funcBody:
raise DSLRuntimeError("Function body is not set.")
# Pass the actual function object to _get_function_signature.
sig = self._get_function_signature(self.funcBody)
# Pass the actual function object to inspect.signature to get the signature.
if isinstance(self.funcBody, DSLCallable):
sig = self.funcBody.get_signature()
else:
sig = inspect.signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
@ -1292,6 +1346,8 @@ class BaseDSL:
f"Missing required argument in `{function_name}`: '{param.name}'"
)
return sig
def _func(self, funcBody, *args, **kwargs):
"""Decorator for MLIR functions.
It cuts the boilerplate code, does the following:
@ -1324,13 +1380,16 @@ class BaseDSL:
self.print_warning("Cache is disabled as user wants to compile only.")
# Check the number of arguments
self._check_arg_count(*args, **kwargs)
sig = self._check_arg_count(*args, **kwargs)
args_spec = inspect.getfullargspec(funcBody)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
*args, **kwargs
sig, *args, **kwargs
)
# Simple name mangling
@ -1528,7 +1587,10 @@ class BaseDSL:
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
kernel_name = funcBody.__name__
args_spec = inspect.getfullargspec(funcBody)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
self.funcBody = funcBody
# Give each kernel a unique name. (The same kernel may be
@ -1568,11 +1630,11 @@ class BaseDSL:
), "kernelGenHelper should be explicitly specified!"
# check arguments
self._check_arg_count(*args, **kwargs)
sig = self._check_arg_count(*args, **kwargs)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
*args, **kwargs
sig, *args, **kwargs
)
kernel_operands, kernel_types, kernel_arg_attrs = (