v4.0 update. (#2371)
This commit is contained in:
@ -566,7 +566,9 @@ class BaseDSL:
|
||||
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec)
|
||||
|
||||
# Implicit cast to NumericMeta
|
||||
if isinstance(arg_spec, t.NumericMeta):
|
||||
if isinstance(arg_spec, t.NumericMeta) and not isinstance(
|
||||
arg, arg_spec
|
||||
):
|
||||
arg = t.cast(arg, arg_spec)
|
||||
|
||||
ir_arg, iv_block_args = (
|
||||
@ -589,15 +591,17 @@ class BaseDSL:
|
||||
self.log_additions(ir_arg)
|
||||
ir_args.extend(ir_arg)
|
||||
|
||||
return ir_args
|
||||
return ir_args, iv_block_args
|
||||
|
||||
fop_args = list(fop.regions[0].blocks[0].arguments)
|
||||
ir_args = gen_exec_args(args, args_spec.args, args_spec.annotations, fop_args)
|
||||
ir_kwargs = gen_exec_args(
|
||||
ir_args, iv_block_args = gen_exec_args(
|
||||
args, args_spec.args, args_spec.annotations, fop_args
|
||||
)
|
||||
ir_kwargs, _ = gen_exec_args(
|
||||
[kwargs[arg] for arg in args_spec.kwonlyargs],
|
||||
args_spec.kwonlyargs,
|
||||
args_spec.annotations,
|
||||
fop_args[len(ir_args) :],
|
||||
fop_args[iv_block_args:],
|
||||
)
|
||||
ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)}
|
||||
|
||||
@ -716,8 +720,10 @@ class BaseDSL:
|
||||
|
||||
assert len(args) == len(args_spec.args) and len(kwargs) == len(
|
||||
args_spec.kwonlyargs
|
||||
), f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
|
||||
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
|
||||
), (
|
||||
f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
|
||||
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
|
||||
)
|
||||
|
||||
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
|
||||
default_attr = ir.DictAttr.get({})
|
||||
@ -729,7 +735,7 @@ class BaseDSL:
|
||||
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty)
|
||||
|
||||
# Implicitly convert into Numeric type if possible
|
||||
if isinstance(spec_ty, t.NumericMeta):
|
||||
if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty):
|
||||
arg = t.cast(arg, spec_ty)
|
||||
|
||||
# Type safety check
|
||||
|
||||
Reference in New Issue
Block a user