v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple
import os
import sys
import inspect
import argparse
from .common import DSLRuntimeError
from .utils.logger import log
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
@ -182,7 +184,67 @@ class Compiler:
return self.jit(module, opt_level, shared_libs)
class CompileOptions:
def __init__(self, options: str = ""):
"""
This class encapsulates all compilation options relevant to function compilation.
It provides a convenient way to manage and pass compilation options,
particularly for controlling compilation settings.
By centralizing these options, it ensures consistent and flexible configuration of
compilation parameters such as optimization level, debugging control, etc.
:param options: The options for the function. Will be parsed by argparse.
:type options: str
"""
if not isinstance(options, str):
raise DSLRuntimeError(
f"Invalid compilation `options`: {options}, it should be a string"
)
self._parser = argparse.ArgumentParser()
self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
self._parser.add_argument(
"--enable-device-assertions", action="store_true", default=False
)
try:
self._options = self._parser.parse_args(options.split())
except SystemExit as e:
# catch argparse error and raise as DSLRuntimeError
raise DSLRuntimeError(
f"Invalid compile options: '{options}'. Please check the option values and format."
)
log().info("`cute.compile` CompileOptions: options=" + options)
def to_str(self):
"""
Generate a string representation of all compilation options
which will be used in pipeline options.
"""
option_strings = []
for key, value in vars(self._options).items():
hyphen_key = key.replace("_", "-")
if isinstance(value, bool):
formatted_value = "true" if value else "false"
else:
formatted_value = str(value)
option_strings.append(f"{hyphen_key}={formatted_value}")
return " ".join(option_strings)
def compile(func, *args, **kwargs):
"""
This function is used to compile a `cute.jit` decorated function.
It will process the compile options and input parameters, do explicit compilation and return the jit executor.
:param func: The function to compile. It can be a regular function, a method or a class instance.
:param args: The arguments to pass to the function.
:param kwargs: The keyword arguments to pass to the function. It can contain `options` like
`opt_level` to control the compilation flags.
:return: The jit executor.
:raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
"""
if func is None:
raise DSLRuntimeError("Function is not set or invalid.")
@ -217,5 +279,8 @@ def compile(func, *args, **kwargs):
if not hasattr(func, "_dsl_object"):
raise DSLRuntimeError("Function is not decorated with jit decorator.")
# process compile options, extract the options and remove them from the kwargs
options = kwargs.pop("options", "")
func._dsl_object.compile_options = CompileOptions(options)
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
return func._dsl_object._func(fcn_ptr, *args, **kwargs)