v4.1 release update v2. (#2481)
This commit is contained in:
@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
import argparse
|
||||
from .common import DSLRuntimeError
|
||||
from .utils.logger import log
|
||||
|
||||
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(_SCRIPT_PATH)
|
||||
@ -182,7 +184,67 @@ class Compiler:
|
||||
return self.jit(module, opt_level, shared_libs)
|
||||
|
||||
|
||||
class CompileOptions:
|
||||
def __init__(self, options: str = ""):
|
||||
"""
|
||||
This class encapsulates all compilation options relevant to function compilation.
|
||||
It provides a convenient way to manage and pass compilation options,
|
||||
particularly for controlling compilation settings.
|
||||
By centralizing these options, it ensures consistent and flexible configuration of
|
||||
compilation parameters such as optimization level, debugging control, etc.
|
||||
|
||||
:param options: The options for the function. Will be parsed by argparse.
|
||||
:type options: str
|
||||
"""
|
||||
if not isinstance(options, str):
|
||||
raise DSLRuntimeError(
|
||||
f"Invalid compilation `options`: {options}, it should be a string"
|
||||
)
|
||||
self._parser = argparse.ArgumentParser()
|
||||
self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
|
||||
self._parser.add_argument(
|
||||
"--enable-device-assertions", action="store_true", default=False
|
||||
)
|
||||
try:
|
||||
self._options = self._parser.parse_args(options.split())
|
||||
except SystemExit as e:
|
||||
# catch argparse error and raise as DSLRuntimeError
|
||||
raise DSLRuntimeError(
|
||||
f"Invalid compile options: '{options}'. Please check the option values and format."
|
||||
)
|
||||
log().info("`cute.compile` CompileOptions: options=" + options)
|
||||
|
||||
def to_str(self):
|
||||
"""
|
||||
Generate a string representation of all compilation options
|
||||
which will be used in pipeline options.
|
||||
"""
|
||||
option_strings = []
|
||||
for key, value in vars(self._options).items():
|
||||
hyphen_key = key.replace("_", "-")
|
||||
if isinstance(value, bool):
|
||||
formatted_value = "true" if value else "false"
|
||||
else:
|
||||
formatted_value = str(value)
|
||||
option_strings.append(f"{hyphen_key}={formatted_value}")
|
||||
|
||||
return " ".join(option_strings)
|
||||
|
||||
|
||||
def compile(func, *args, **kwargs):
|
||||
"""
|
||||
This function is used to compile a `cute.jit` decorated function.
|
||||
It will process the compile options and input parameters, do explicit compilation and return the jit executor.
|
||||
|
||||
:param func: The function to compile. It can be a regular function, a method or a class instance.
|
||||
:param args: The arguments to pass to the function.
|
||||
:param kwargs: The keyword arguments to pass to the function. It can contain `options` like
|
||||
`opt_level` to control the compilation flags.
|
||||
|
||||
:return: The jit executor.
|
||||
|
||||
:raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
|
||||
"""
|
||||
if func is None:
|
||||
raise DSLRuntimeError("Function is not set or invalid.")
|
||||
|
||||
@ -217,5 +279,8 @@ def compile(func, *args, **kwargs):
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
raise DSLRuntimeError("Function is not decorated with jit decorator.")
|
||||
|
||||
# process compile options, extract the options and remove them from the kwargs
|
||||
options = kwargs.pop("options", "")
|
||||
func._dsl_object.compile_options = CompileOptions(options)
|
||||
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
|
||||
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user