v4.1 release update v2. (#2481)
This commit is contained in:
@ -159,7 +159,11 @@ class CutlassBaseDSL(BaseDSL):
|
||||
pipeline = super()._get_pipeline(pipeline)
|
||||
if pipeline == None:
|
||||
# cubin format is required to be cubin as we launch cuda module at python level.
|
||||
return "builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3})"
|
||||
return (
|
||||
"builtin.module(cute-to-nvvm{cubin-format=bin "
|
||||
+ self.compile_options.to_str()
|
||||
+ "})"
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
@ -294,13 +298,8 @@ class CutlassBaseDSL(BaseDSL):
|
||||
self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
|
||||
)
|
||||
|
||||
def _get_globals(self):
|
||||
caller_globals = self.frame.f_globals
|
||||
caller_locals = self.frame.f_locals
|
||||
all_globals = globals().copy()
|
||||
all_globals.update(caller_globals)
|
||||
all_globals.update(caller_locals)
|
||||
return all_globals
|
||||
def _get_module_globals(self):
|
||||
return globals()
|
||||
|
||||
def _preprocess_launch_config_args(self, args, kwargs):
|
||||
"""Helper to preprocess args and kwargs for LaunchConfig"""
|
||||
@ -459,7 +458,10 @@ class KernelLauncher:
|
||||
|
||||
def _check_func_args(self, funcBody, *func_args, **func_kwargs):
|
||||
# Get function signature
|
||||
sig = inspect.signature(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
sig = funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(funcBody)
|
||||
|
||||
# func_args and func_kwargs should match funcBody's signature,
|
||||
# no extra or missing arguments.
|
||||
@ -485,6 +487,7 @@ class KernelLauncher:
|
||||
|
||||
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
|
||||
self.dsl.kernel_symbols.append(name)
|
||||
self.dsl.frame = None
|
||||
return ret.launch_op_ret
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
@ -537,14 +540,18 @@ def pack_from_irvalue(
|
||||
mixed_values[idx] = obj
|
||||
elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
|
||||
mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
|
||||
elif isinstance(chunk, list) and chunk[0] is None:
|
||||
mixed_values[idx] = class_types[idx]
|
||||
else:
|
||||
try:
|
||||
if isinstance(chunk, list) and chunk[0] is None:
|
||||
mixed_values[idx] = class_types[idx]
|
||||
else:
|
||||
if len(chunk) == 1:
|
||||
try:
|
||||
mixed_values[idx] = t.as_numeric(chunk[0])
|
||||
except DSLRuntimeError as e:
|
||||
mixed_values[idx] = chunk[0]
|
||||
except ValueError:
|
||||
# Suppress the conversion error and try new_from_mlir_values below
|
||||
pass
|
||||
|
||||
if mixed_values[idx] is None:
|
||||
mixed_values[idx] = new_from_mlir_values(obj, chunk)
|
||||
|
||||
log().debug("------------------ ")
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
|
||||
Reference in New Issue
Block a user