v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -566,7 +566,9 @@ class BaseDSL:
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec)
# Implicit cast to NumericMeta
if isinstance(arg_spec, t.NumericMeta):
if isinstance(arg_spec, t.NumericMeta) and not isinstance(
arg, arg_spec
):
arg = t.cast(arg, arg_spec)
ir_arg, iv_block_args = (
@ -589,15 +591,17 @@ class BaseDSL:
self.log_additions(ir_arg)
ir_args.extend(ir_arg)
return ir_args
return ir_args, iv_block_args
fop_args = list(fop.regions[0].blocks[0].arguments)
ir_args = gen_exec_args(args, args_spec.args, args_spec.annotations, fop_args)
ir_kwargs = gen_exec_args(
ir_args, iv_block_args = gen_exec_args(
args, args_spec.args, args_spec.annotations, fop_args
)
ir_kwargs, _ = gen_exec_args(
[kwargs[arg] for arg in args_spec.kwonlyargs],
args_spec.kwonlyargs,
args_spec.annotations,
fop_args[len(ir_args) :],
fop_args[iv_block_args:],
)
ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)}
@ -716,8 +720,10 @@ class BaseDSL:
assert len(args) == len(args_spec.args) and len(kwargs) == len(
args_spec.kwonlyargs
), f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
), (
f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
)
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
default_attr = ir.DictAttr.get({})
@ -729,7 +735,7 @@ class BaseDSL:
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty)
# Implicitly convert into Numeric type if possible
if isinstance(spec_ty, t.NumericMeta):
if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty):
arg = t.cast(arg, spec_ty)
# Type safety check