v4.1 release
This commit is contained in:
@ -164,16 +164,17 @@ def _mlir_type_to_numpy_type(type):
|
||||
|
||||
def is_dynamic_expression(value):
|
||||
"""
|
||||
Check if the value is an MLIR's SSA value.
|
||||
Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value
|
||||
"""
|
||||
# Case 1: If the value has MLIR's SSA value, return True
|
||||
# Case 2: If the value supports __extract_mlir_values__ then it's possible to get SSA value
|
||||
return (
|
||||
isinstance(value, ir.Value)
|
||||
or hasattr(value, "__extract_mlir_values__")
|
||||
or len(extract_mlir_values(value)) > 0
|
||||
)
|
||||
|
||||
if isinstance(value, (tuple, list)):
|
||||
for x in value:
|
||||
if is_dynamic_expression(x):
|
||||
return True
|
||||
elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr(
|
||||
value, "__extract_mlir_values__"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_mlir_values(obj):
|
||||
"""
|
||||
@ -726,6 +727,7 @@ class BaseDSL:
|
||||
)
|
||||
|
||||
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
|
||||
jit_adapted_args = []
|
||||
default_attr = ir.DictAttr.get({})
|
||||
|
||||
input_args = [*args, *kwargs.values()]
|
||||
@ -759,7 +761,9 @@ class BaseDSL:
|
||||
# If not any known type, try JIT argument adapter
|
||||
# to convert the argument
|
||||
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
|
||||
arg = adapter(arg) if adapter else arg
|
||||
if adapter:
|
||||
arg = adapter(arg)
|
||||
jit_adapted_args.append(arg)
|
||||
|
||||
if is_host:
|
||||
jit_exec_arg.extend(get_c_pointers(arg))
|
||||
@ -798,14 +802,14 @@ class BaseDSL:
|
||||
jit_arg_types.extend(jit_arg_type)
|
||||
jit_arg_attrs.extend(jit_arg_attr)
|
||||
|
||||
return jit_exec_args, jit_arg_types, jit_arg_attrs
|
||||
return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args
|
||||
|
||||
def generate_mlir_function_types(
|
||||
self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec
|
||||
):
|
||||
"""Convert input arguments to MLIR function signature also convert numpy arrays to memref."""
|
||||
|
||||
exe_args, types, _ = self._generate_jit_func_args(
|
||||
exe_args, types, attrs, adapted_args = self._generate_jit_func_args(
|
||||
func, function_name, input_args, kwargs, args_spec, is_host=True
|
||||
)
|
||||
|
||||
@ -816,7 +820,7 @@ class BaseDSL:
|
||||
types
|
||||
), "expects the same number of arguments and function parameters"
|
||||
|
||||
return exe_args, types
|
||||
return exe_args, types, adapted_args
|
||||
|
||||
@dataclass
|
||||
class LaunchConfig:
|
||||
@ -1158,7 +1162,7 @@ class BaseDSL:
|
||||
"""Generate MLIR module and compile iself.T_provider."""
|
||||
with ir.Context(), ir.Location.unknown():
|
||||
# Convert input arguments to MLIR arguments
|
||||
exe_args, func_types = self.generate_mlir_function_types(
|
||||
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
|
||||
funcBody, function_name, args, kwargs, args_spec
|
||||
)
|
||||
|
||||
@ -1476,7 +1480,7 @@ class BaseDSL:
|
||||
if self.device_compilation_only:
|
||||
return kernel_operands, kernel_arg_types, kernel_arg_attrs
|
||||
|
||||
kernel_operands, kernel_arg_types, kernel_arg_attrs = (
|
||||
kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = (
|
||||
self._generate_jit_func_args(
|
||||
kernel_func, kernel_name, args, kwargs, args_spec, is_host=False
|
||||
)
|
||||
@ -1586,12 +1590,14 @@ class BaseDSL:
|
||||
if self.device_compilation_only:
|
||||
log().debug("Generating cuda-python arguments")
|
||||
# Convert input arguments to MLIR arguments
|
||||
self.exe_args, kernel_types = self.generate_mlir_function_types(
|
||||
funcBody,
|
||||
kernel_name,
|
||||
canonicalized_args,
|
||||
canonicalized_kwargs,
|
||||
args_spec,
|
||||
self.exe_args, kernel_types, _ = (
|
||||
self.generate_mlir_function_types(
|
||||
funcBody,
|
||||
kernel_name,
|
||||
canonicalized_args,
|
||||
canonicalized_kwargs,
|
||||
args_spec,
|
||||
)
|
||||
)
|
||||
|
||||
helper = kernelGenHelper()
|
||||
|
||||
Reference in New Issue
Block a user