v4.1 release update v2. (#2481)
This commit is contained in:
@ -38,6 +38,7 @@ import warnings
|
||||
|
||||
from . import typing as t
|
||||
from .env_manager import EnvironmentVarManager
|
||||
from .compiler import CompileOptions
|
||||
|
||||
# =============================================================================
|
||||
# CUDA Python
|
||||
@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values):
|
||||
return obj
|
||||
|
||||
|
||||
class DSLCallable:
    """
    One-shot wrapper around a callable used within the DSL.

    The wrapper forwards a call to the wrapped function and, once that call
    has returned, drops its reference to the function so that it cannot be
    invoked through this wrapper again. While the reference is still held,
    the wrapper also exposes the function's name, argument specification
    and signature.

    The reference is cleared only after the wrapped function returns; if the
    call raises, the function is kept and the wrapper stays usable.

    Attributes:
        func (callable): The wrapped function, or ``None`` once it has been
            called successfully.

    Methods:
        __call__(*args, **kwargs): Calls the wrapped function and clears it.
        get_arg_spec(): Returns the argument specification of the function.
        get_signature(): Returns the signature of the function.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        """Call the wrapped function once and return its result.

        Raises:
            AssertionError: If the wrapped function has already been called.
        """
        ret = self.__func__(*args, **kwargs)
        # Drop the reference only after a successful call so that a second
        # invocation is rejected by the ``__func__`` guard.
        self.func = None
        return ret

    @property
    def __func__(self):
        """The wrapped function.

        Raises:
            AssertionError: If the wrapped function has already been called.
        """
        # Raised explicitly instead of using an ``assert`` statement so the
        # guard is still enforced when Python runs with ``-O``; the exception
        # type and message are the same as the assert would produce.
        if self.func is None:
            raise AssertionError("DSLCallable is already called")
        return self.func

    @property
    def __name__(self):
        """Name of the wrapped function."""
        return self.__func__.__name__

    def get_arg_spec(self):
        """Return ``inspect.getfullargspec`` of the wrapped function."""
        return inspect.getfullargspec(self.__func__)

    def get_signature(self):
        """Return ``inspect.signature`` of the wrapped function."""
        return inspect.signature(self.__func__)
|
||||
|
||||
|
||||
class BaseDSL:
|
||||
gpu_module = None
|
||||
|
||||
@ -306,6 +351,8 @@ class BaseDSL:
|
||||
self.kernel_symbols = []
|
||||
# used to generate unique name for gpu.launch
|
||||
self.launch_inner_count = 0
|
||||
# initialize default compile options
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
if preprocess:
|
||||
self.preprocessor = DSLPreprocessor()
|
||||
@ -392,26 +439,24 @@ class BaseDSL:
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
# If the function ptr is already materialized, use the existing one
|
||||
func._dsl_object.frame = func._decorator_frame
|
||||
|
||||
if func._transformed_ast is None:
|
||||
func._transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
if func._transformed_ast is None:
|
||||
del func._decorator_frame
|
||||
del func._transformed_ast
|
||||
func._dsl_object.frame = None
|
||||
return func
|
||||
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast)
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func)
|
||||
# If the function is decorated, de-decorate it
|
||||
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
|
||||
return fcn_ptr
|
||||
func._dsl_object.frame = None
|
||||
return DSLCallable(fcn_ptr)
|
||||
return func
|
||||
|
||||
def jit_runner(self, frame, executor, *dargs, **dkwargs):
|
||||
def jit_runner(self, executor, frame, *dargs, **dkwargs):
|
||||
"""
|
||||
Decorator to mark a function for JIT compilation.
|
||||
"""
|
||||
# Set the frame so that it can be used by the AST preprocessor
|
||||
self.frame = frame
|
||||
log().info("jit_runner")
|
||||
|
||||
def jit_runner_decorator(func):
|
||||
@ -444,7 +489,7 @@ class BaseDSL:
|
||||
frame = inspect.currentframe().f_back
|
||||
# Instantiate the DSL Class
|
||||
main_dsl = cls._get_dsl()
|
||||
return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs)
|
||||
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
|
||||
|
||||
@classmethod
|
||||
def kernel(cls, *dargs, **dkwargs):
|
||||
@ -454,7 +499,7 @@ class BaseDSL:
|
||||
frame = inspect.currentframe().f_back
|
||||
# Instantiate the DSL Class
|
||||
main_dsl = cls._get_dsl()
|
||||
return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs)
|
||||
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _kernel_helper(self, func, *args, **kwargs):
|
||||
@ -627,6 +672,12 @@ class BaseDSL:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_module_globals(self):
|
||||
"""
|
||||
Get the module's globals.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _get_globals(self):
|
||||
"""
|
||||
Combines global and local variables from the current context and the
|
||||
@ -639,7 +690,11 @@ class BaseDSL:
|
||||
AST preprocessor generates a new python code, so the resulting globals
|
||||
dictionary is used to execute the python code.
|
||||
"""
|
||||
pass
|
||||
all_globals = self._get_module_globals().copy()
|
||||
if self.frame:
|
||||
all_globals.update(self.frame.f_globals)
|
||||
all_globals.update(self.frame.f_locals)
|
||||
return all_globals
|
||||
|
||||
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
|
||||
return isinstance(
|
||||
@ -881,20 +936,15 @@ class BaseDSL:
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
|
||||
frame = self.frame
|
||||
if frame is None:
|
||||
print("Frame is None")
|
||||
if self.frame is None:
|
||||
log().debug("Frame is None")
|
||||
return None
|
||||
|
||||
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
|
||||
file_loc = ir.Location.file(
|
||||
self.frame.f_code.co_filename, self.frame.f_lineno, 0
|
||||
)
|
||||
|
||||
def print_all_frames():
|
||||
for i, frame in enumerate(inspect.stack()):
|
||||
print(
|
||||
f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}"
|
||||
)
|
||||
|
||||
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
|
||||
loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
|
||||
return loc
|
||||
|
||||
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
|
||||
@ -992,6 +1042,8 @@ class BaseDSL:
|
||||
for attr, value in self.envar.__dict__.items():
|
||||
if value is not None:
|
||||
s.write(str(value).encode())
|
||||
# Add compile options to the hash
|
||||
s.write(self.compile_options.to_str().encode())
|
||||
module_hash = self.get_version().copy()
|
||||
module_hash.update(s.getvalue())
|
||||
module_hash = module_hash.hexdigest()
|
||||
@ -1145,6 +1197,8 @@ class BaseDSL:
|
||||
self.launch_inner_count = 0
|
||||
# reset num_kernels to 0 for next compilation.
|
||||
self.num_kernels = 0
|
||||
# reset the compile options after the compilation is done.
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
def generate_mlir(
|
||||
self,
|
||||
@ -1226,9 +1280,11 @@ class BaseDSL:
|
||||
return transformed_ast
|
||||
return None
|
||||
|
||||
def get_function_ptr(self, original_function, transformed_ast):
|
||||
def get_function_ptr(self, original_function):
|
||||
file_name = inspect.getsourcefile(original_function)
|
||||
code_object = compile(transformed_ast, filename=file_name, mode="exec")
|
||||
code_object = compile(
|
||||
original_function._transformed_ast, filename=file_name, mode="exec"
|
||||
)
|
||||
return self.preprocessor.exec(
|
||||
original_function.__name__,
|
||||
original_function,
|
||||
@ -1236,10 +1292,6 @@ class BaseDSL:
|
||||
self._get_globals(),
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _get_function_signature(self, func):
|
||||
return inspect.signature(func)
|
||||
|
||||
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
|
||||
"""
|
||||
Binds provided arguments to a function's signature and applies default values.
|
||||
@ -1260,12 +1312,11 @@ class BaseDSL:
|
||||
)
|
||||
return bound_args
|
||||
|
||||
def _canonicalize_args(self, *args, **kwargs):
|
||||
def _canonicalize_args(self, sig, *args, **kwargs):
|
||||
"""
|
||||
Canonicalize the input arguments so that returned args only contain
|
||||
positional arguments and kwargs only contain keyword arguments.
|
||||
"""
|
||||
sig = self._get_function_signature(self.funcBody)
|
||||
function_name = self.funcBody.__name__
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
canonicalized_args = bound_args.args
|
||||
@ -1276,8 +1327,11 @@ class BaseDSL:
|
||||
if not self.funcBody:
|
||||
raise DSLRuntimeError("Function body is not set.")
|
||||
|
||||
# Pass the actual function object to _get_function_signature.
|
||||
sig = self._get_function_signature(self.funcBody)
|
||||
# Pass the actual function object to inspect.signature to get the signature.
|
||||
if isinstance(self.funcBody, DSLCallable):
|
||||
sig = self.funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(self.funcBody)
|
||||
function_name = self.funcBody.__name__
|
||||
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
@ -1292,6 +1346,8 @@ class BaseDSL:
|
||||
f"Missing required argument in `{function_name}`: '{param.name}'"
|
||||
)
|
||||
|
||||
return sig
|
||||
|
||||
def _func(self, funcBody, *args, **kwargs):
|
||||
"""Decorator for MLIR functions.
|
||||
It cuts the boilerplate code, does the following:
|
||||
@ -1324,13 +1380,16 @@ class BaseDSL:
|
||||
self.print_warning("Cache is disabled as user wants to compile only.")
|
||||
|
||||
# Check the number of arguments
|
||||
self._check_arg_count(*args, **kwargs)
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
*args, **kwargs
|
||||
sig, *args, **kwargs
|
||||
)
|
||||
|
||||
# Simple name mangling
|
||||
@ -1528,7 +1587,10 @@ class BaseDSL:
|
||||
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
|
||||
|
||||
kernel_name = funcBody.__name__
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
self.funcBody = funcBody
|
||||
|
||||
# Give each kernel a unique name. (The same kernel may be
|
||||
@ -1568,11 +1630,11 @@ class BaseDSL:
|
||||
), "kernelGenHelper should be explicitly specified!"
|
||||
|
||||
# check arguments
|
||||
self._check_arg_count(*args, **kwargs)
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
*args, **kwargs
|
||||
sig, *args, **kwargs
|
||||
)
|
||||
|
||||
kernel_operands, kernel_types, kernel_arg_attrs = (
|
||||
|
||||
Reference in New Issue
Block a user