287 lines
9.9 KiB
Python
287 lines
9.9 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides a class that compiles generated IR using MLIR's PassManager
|
|
and executes it using MLIR's ExecutionEngine.
|
|
|
|
"""
|
|
|
|
from typing import Sequence, Optional, Tuple
|
|
import os
|
|
import sys
|
|
import inspect
|
|
import argparse
|
|
from .common import DSLRuntimeError
|
|
from .utils.logger import log
|
|
|
|
# Make modules that live next to this file importable by absolute name.
# The membership check keeps repeated imports (e.g. importlib.reload) from
# appending the same directory to sys.path again and again.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_PATH not in sys.path:
    sys.path.append(_SCRIPT_PATH)
|
|
|
|
from .._mlir import ir
|
|
|
|
|
|
# =============================================================================
|
|
# Compiler Class
|
|
# =============================================================================
|
|
|
|
|
|
class CompilationError(RuntimeError):
    """Custom error class for compilation failures.

    When ``nvvm_error`` is supplied, ``str()`` and ``repr()`` render a colored
    NVVM diagnostic report: the CUDA toolkit path, the target architecture,
    the truncated IR context and troubleshooting hints. Otherwise they return
    ``message`` unchanged.

    :param message: Raw error message.
    :param nvvm_error: Extracted NVVM error text; when truthy, selects the
        detailed report format.
    :param ir_context: IR snippet embedded in the report.
    :param cuda_toolkit: CUDA toolkit path shown in the report ("Not Set" if
        empty or ``None``).
    :param arch: Target architecture shown in the report.
    """

    # ANSI escape codes used to colorize the NVVM report.
    RED = "\033[91m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    BOLD = "\033[1m"
    RESET = "\033[0m"

    def __init__(
        self,
        message: str,
        nvvm_error: Optional[str] = None,
        ir_context: Optional[str] = None,
        cuda_toolkit: Optional[str] = None,
        arch: Optional[str] = None,
    ):
        self.message = message
        self.nvvm_error = nvvm_error
        self.ir_context = ir_context
        self.cuda_toolkit = cuda_toolkit
        self.arch = arch
        # Keep the raw message in ``args``. Passing "" here (as before) made
        # str() return an empty string for plain, non-NVVM errors. The
        # user-visible text comes from the overridden __str__/__repr__.
        super().__init__(message)
        # Rendered once at construction; later attribute changes are not
        # reflected in str()/repr().
        self._formatted_error = self._format_error()

    def __str__(self) -> str:
        """Return the pre-rendered error text (without the class name)."""
        return self._formatted_error

    def __repr__(self) -> str:
        """Return the pre-rendered error text (without the class name)."""
        return self._formatted_error

    def _format_error(self) -> str:
        """Build the user-facing text for this error."""
        if not self.nvvm_error:
            return str(self.message)

        return f"""NVVM Compilation Error:
----------------------

{self.BLUE}⚙️ Current Settings:{self.RESET}
{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
- Target Architecture: {self.arch}{self.RESET}

IR Context (truncated):
{self.ir_context}

{self.YELLOW}💡 Possible Solutions:{self.RESET}
{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""
|
|
|
|
|
|
class Compiler:
    """Compiler class for compiling and building MLIR modules.

    The MLIR Python binding modules are supplied by the caller:
    ``passmanager.PassManager`` is used to parse and run pass pipelines, and
    ``execution_engine.ExecutionEngine`` is used to JIT the result.
    """

    def __init__(self, passmanager, execution_engine):
        self.passmanager = passmanager
        self.execution_engine = execution_engine

    def __call__(self, module, *args, **kwargs):
        """Convenience application method; forwards everything to :meth:`compile`.

        ``pipeline`` is required by :meth:`compile`, so it must be passed
        here as well (positionally or by keyword).
        """
        return self.compile(module, *args, **kwargs)

    def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
        """Split a pass-manager error message into NVVM error and IR context.

        :param error_msg: ``str()`` of the exception raised while running the
            pipeline.
        :return: ``(nvvm_error, ir_msg)``. ``nvvm_error`` is ``None`` unless
            the message contains ``NVVM_ERROR``. In that case it is the
            stripped text following ``libNVVM extra log:`` (up to any repeat
            of that marker) when the marker is present, otherwise the whole
            message. ``ir_msg`` is the IR that follows
            ``see current operation:``, cut before the duplicated
            ``Failed translating`` section and truncated to its first and
            last 5 lines when longer than 10 lines; it is ``""`` when no IR
            is present.
        """
        nvvm_error = None
        ir_msg = ""

        if "NVVM_ERROR" in error_msg:
            nvvm_marker = "libNVVM extra log:"
            if nvvm_marker in error_msg:
                nvvm_error = error_msg.split(nvvm_marker)[1].strip()
            else:
                nvvm_error = error_msg

        ir_marker = "see current operation:"
        if ir_marker in error_msg:
            ir_section = error_msg.split(ir_marker)[1].strip()
            # The translator prints the IR a second time after this marker;
            # keep only the first copy.
            ir_section = ir_section.split("error: unknown: Failed translating")[
                0
            ].strip()

            ir_lines = ir_section.split("\n")
            if len(ir_lines) > 10:
                ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:])
            else:
                ir_msg = ir_section

        return nvvm_error, ir_msg

    def compile(
        self,
        module,
        pipeline: str,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compile ``module`` in place by running the textual pass ``pipeline``.

        :param module: MLIR module; its ``operation`` is handed to the pass
            manager.
        :param pipeline: Pass pipeline string parsed by ``PassManager.parse``.
        :param cuda_toolkit: Toolkit path, only used in error reports.
        :param arch: Target architecture, only used in error reports.
        :param enable_verifier: Forwarded to ``PassManager.enable_verifier``.
        :raises CompilationError: if the failure message mentions
            ``NVVM_ERROR``.
        :raises Exception: any other failure is re-raised unchanged.
        """
        try:
            pm = self.passmanager.PassManager.parse(pipeline)
            pm.enable_verifier(enable_verifier)
            pm.run(module.operation)
        except Exception as e:
            error_msg = str(e)
            nvvm_error, ir_msg = self._process_error(error_msg)

            if nvvm_error:
                raise CompilationError(
                    error_msg,
                    nvvm_error=nvvm_error,
                    ir_context=ir_msg,
                    cuda_toolkit=cuda_toolkit,
                    arch=arch,
                ) from e
            # Bare raise re-raises the original exception with its traceback.
            raise

    def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
        """Wrap ``module`` in a JIT ``ExecutionEngine`` and return the engine."""
        return self.execution_engine.ExecutionEngine(
            module, opt_level=opt_level, shared_libs=shared_libs
        )

    def compile_and_jit(
        self,
        module,
        pipeline: str,
        shared_libs: Sequence[str] = (),
        opt_level: int = 2,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier: bool = False,
    ):
        """Run :meth:`compile` on ``module``, then return :meth:`jit` of it.

        ``enable_verifier`` is forwarded to :meth:`compile` (default ``False``,
        matching the previous behavior).
        """
        self.compile(
            module,
            pipeline,
            cuda_toolkit=cuda_toolkit,
            arch=arch,
            enable_verifier=enable_verifier,
        )
        return self.jit(module, opt_level, shared_libs)
|
|
|
|
|
|
class CompileOptions:
    """Compilation options for ``cute.compile``, parsed from a CLI-style string.

    Recognized options are ``--opt-level [N]`` (int, default 3) and
    ``--enable-device-assertions`` (flag). :meth:`to_str` renders the parsed
    values as space-separated ``key=value`` pairs for pipeline options.
    """

    def __init__(self, options: str = ""):
        """
        This class encapsulates all compilation options relevant to function compilation.
        It provides a convenient way to manage and pass compilation options,
        particularly for controlling compilation settings.
        By centralizing these options, it ensures consistent and flexible configuration of
        compilation parameters such as optimization level, debugging control, etc.

        :param options: The options for the function. Will be parsed by argparse.
        :type options: str
        :raises DSLRuntimeError: if ``options`` is not a string or cannot be
            parsed.
        """
        if not isinstance(options, str):
            raise DSLRuntimeError(
                f"Invalid compilation `options`: {options}, it should be a string"
            )
        self._parser = argparse.ArgumentParser()
        # ``const`` matters because of ``nargs="?"``: without it a bare
        # ``--opt-level`` parses to None and to_str() emits "opt-level=None".
        # A bare flag now means "use the default level".
        self._parser.add_argument(
            "--opt-level", nargs="?", type=int, default=3, const=3
        )
        self._parser.add_argument(
            "--enable-device-assertions", action="store_true", default=False
        )
        try:
            self._options = self._parser.parse_args(options.split())
        except SystemExit as e:
            # argparse reports bad input by printing usage and calling
            # sys.exit(); convert that into a DSL error instead of exiting.
            raise DSLRuntimeError(
                f"Invalid compile options: '{options}'. Please check the option values and format."
            ) from e
        log().info("`cute.compile` CompileOptions: options=" + options)

    def to_str(self):
        """
        Generate a string representation of all compilation options
        which will be used in pipeline options.

        Each parsed option becomes ``name=value``, with underscores in the
        name replaced by hyphens and booleans written as ``true``/``false``;
        entries are joined by single spaces.
        """

        def _format_value(value):
            if isinstance(value, bool):
                return "true" if value else "false"
            return str(value)

        return " ".join(
            f"{key.replace('_', '-')}={_format_value(value)}"
            for key, value in vars(self._options).items()
        )
|
|
|
|
|
|
def compile(func, *args, **kwargs):
    """
    This function is used to compile a `cute.jit` decorated function.
    It will process the compile options and input parameters, do explicit compilation and return the jit executor.

    :param func: The function to compile. It can be a regular function, a method or a class instance
        whose class defines ``__call__``.
    :param args: The arguments to pass to the function.
    :param kwargs: The keyword arguments to pass to the function. The special ``options`` keyword
        (a string such as ``"--opt-level 2"``) is removed and parsed by :class:`CompileOptions`.

    :return: The jit executor (the result of ``func._dsl_object._func``).

    :raises: DSLRuntimeError if the function is not decorated with `cute.jit`, is not callable,
        or is a callable kind that cannot be compiled (e.g. a builtin).
    """
    if func is None:
        raise DSLRuntimeError("Function is not set or invalid.")

    if not callable(func):
        raise DSLRuntimeError("Object is not callable.")

    # Flags forwarded to ``_dsl_object._func`` below.
    kwargs["compile_only"] = True
    kwargs["no_cache"] = True

    if inspect.isfunction(func):
        # Regular function: arguments are used as given.
        pass
    elif inspect.ismethod(func):
        # Bound method: pass the bound instance explicitly as the first argument.
        args = [func.__self__] + list(args)
        func = func.__func__
    elif inspect.ismethod(getattr(func, "__call__", None)):
        # Class instance whose ``__call__`` is a Python-level method: compile
        # the underlying function with the instance as the first argument.
        # (The previous check was always true for callables, so builtins,
        # classes and partials crashed with AttributeError on ``__func__``.)
        args = [func] + list(args)
        func = func.__call__.__func__
    else:
        raise DSLRuntimeError(
            "Invalid function type, only function, method and module are supported, but got",
            func,
        )

    # If it's a wrapped function created by jit decorator, get the original function
    if hasattr(func, "__wrapped__"):
        func = func.__wrapped__

    if not hasattr(func, "_dsl_object"):
        raise DSLRuntimeError("Function is not decorated with jit decorator.")

    # process compile options, extract the options and remove them from the kwargs
    options = kwargs.pop("options", "")
    func._dsl_object.compile_options = CompileOptions(options)
    fcn_ptr = func._dsl_object._preprocess_and_execute(func)
    return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|