v4.1 release

This commit is contained in:
Junkai-Wu
2025-07-03 20:07:53 +08:00
committed by GitHub
parent b995f93317
commit a1aaf2300a
155 changed files with 18407 additions and 6068 deletions

View File

@ -164,16 +164,17 @@ def _mlir_type_to_numpy_type(type):
def is_dynamic_expression(value):
"""
Check if the value is an MLIR's SSA value.
Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value
"""
# Case 1: If the value has MLIR's SSA value, return True
# Case 2: If the value supports __extract_mlir_values__ then it's possible to get SSA value
return (
isinstance(value, ir.Value)
or hasattr(value, "__extract_mlir_values__")
or len(extract_mlir_values(value)) > 0
)
if isinstance(value, (tuple, list)):
for x in value:
if is_dynamic_expression(x):
return True
elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr(
value, "__extract_mlir_values__"
):
return True
return False
def extract_mlir_values(obj):
"""
@ -726,6 +727,7 @@ class BaseDSL:
)
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
jit_adapted_args = []
default_attr = ir.DictAttr.get({})
input_args = [*args, *kwargs.values()]
@ -759,7 +761,9 @@ class BaseDSL:
# If not any known type, try JIT argument adapter
# to convert the argument
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
arg = adapter(arg) if adapter else arg
if adapter:
arg = adapter(arg)
jit_adapted_args.append(arg)
if is_host:
jit_exec_arg.extend(get_c_pointers(arg))
@ -798,14 +802,14 @@ class BaseDSL:
jit_arg_types.extend(jit_arg_type)
jit_arg_attrs.extend(jit_arg_attr)
return jit_exec_args, jit_arg_types, jit_arg_attrs
return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args
def generate_mlir_function_types(
self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec
):
"""Convert input arguments to MLIR function signature also convert numpy arrays to memref."""
exe_args, types, _ = self._generate_jit_func_args(
exe_args, types, attrs, adapted_args = self._generate_jit_func_args(
func, function_name, input_args, kwargs, args_spec, is_host=True
)
@ -816,7 +820,7 @@ class BaseDSL:
types
), "expects the same number of arguments and function parameters"
return exe_args, types
return exe_args, types, adapted_args
@dataclass
class LaunchConfig:
@ -1158,7 +1162,7 @@ class BaseDSL:
"""Generate MLIR module and compile iself.T_provider."""
with ir.Context(), ir.Location.unknown():
# Convert input arguments to MLIR arguments
exe_args, func_types = self.generate_mlir_function_types(
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
funcBody, function_name, args, kwargs, args_spec
)
@ -1476,7 +1480,7 @@ class BaseDSL:
if self.device_compilation_only:
return kernel_operands, kernel_arg_types, kernel_arg_attrs
kernel_operands, kernel_arg_types, kernel_arg_attrs = (
kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = (
self._generate_jit_func_args(
kernel_func, kernel_name, args, kwargs, args_spec, is_host=False
)
@ -1586,12 +1590,14 @@ class BaseDSL:
if self.device_compilation_only:
log().debug("Generating cuda-python arguments")
# Convert input arguments to MLIR arguments
self.exe_args, kernel_types = self.generate_mlir_function_types(
funcBody,
kernel_name,
canonicalized_args,
canonicalized_kwargs,
args_spec,
self.exe_args, kernel_types, _ = (
self.generate_mlir_function_types(
funcBody,
kernel_name,
canonicalized_args,
canonicalized_kwargs,
args_spec,
)
)
helper = kernelGenHelper()