222 lines
7.1 KiB
Python
222 lines
7.1 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
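This module provides a ``Compiler`` class that compiles generated IR using MLIR's
PassManager and executes it using MLIR's ExecutionEngine, a ``CompilationError``
exception used to report NVVM compilation failures, and a ``compile`` helper that
invokes a jit-decorated callable in compile-only mode.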
|
|
|
|
"""
|
|
|
|
from typing import Sequence, Optional, Tuple
|
|
import os
|
|
import sys
|
|
import inspect
|
|
from .common import DSLRuntimeError
|
|
|
|
# Directory that contains this file; appended to the module search path so
# that modules living next to this one can also be imported by bare name.
_SCRIPT_PATH = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(_SCRIPT_PATH)
|
|
|
|
from .._mlir import ir
|
|
|
|
|
|
# =============================================================================
|
|
# Compiler Class
|
|
# =============================================================================
|
|
|
|
|
|
class CompilationError(RuntimeError):
    """Custom error class for compilation failures.

    Without a (non-empty) ``nvvm_error`` the exception renders as ``message``.
    With one, it renders a multi-line, ANSI-colored report containing the NVVM
    error text, the CUDA toolkit / target architecture settings, the supplied
    IR context and some troubleshooting hints.

    The rendered text is computed once in ``__init__`` and returned by both
    ``__str__`` and ``__repr__``.
    """

    # ANSI color codes used in the formatted report
    RED = "\033[91m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    BOLD = "\033[1m"
    RESET = "\033[0m"

    def __init__(
        self,
        message: str,
        nvvm_error: Optional[str] = None,
        ir_context: Optional[str] = None,
        cuda_toolkit: Optional[str] = None,
        arch: Optional[str] = None,
    ):
        """
        :param message: Plain error message; rendered as-is when there is no
            ``nvvm_error``.
        :param nvvm_error: Extracted NVVM error text; enables the full report.
        :param ir_context: IR excerpt shown in the report.
        :param cuda_toolkit: CUDA toolkit path shown in the report.
        :param arch: Target architecture shown in the report.
        """
        self.message = message
        self.nvvm_error = nvvm_error
        self.ir_context = ir_context
        self.cuda_toolkit = cuda_toolkit
        self.arch = arch
        # Pass the real message so ``args[0]`` carries it; the rendered text
        # shown to users comes from ``__str__`` regardless.
        super().__init__(message)
        # Render once; __str__/__repr__ return this cached text.
        self._formatted_error = self._format_error()

    def __str__(self) -> str:
        """Return the formatted error text (without the class name)."""
        return self._formatted_error

    def __repr__(self) -> str:
        """Return the same formatted error text as ``__str__``."""
        return self._formatted_error

    def _format_error(self) -> str:
        """Build the text returned by ``__str__`` and ``__repr__``."""
        if not self.nvvm_error:
            return str(self.message)

        return f"""NVVM Compilation Error:
----------------------

{self.RED}{self.nvvm_error}{self.RESET}

{self.BLUE}⚙️ Current Settings:{self.RESET}
{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
- Target Architecture: {self.arch}{self.RESET}

IR Context (truncated):
{self.ir_context}

{self.YELLOW}💡 Possible Solutions:{self.RESET}
{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""
|
|
|
|
|
|
class Compiler:
    """Compiler class for compiling and building MLIR modules.

    The MLIR ``passmanager`` and ``execution_engine`` binding modules are
    injected at construction time. ``compile`` runs a textual pass pipeline
    through ``passmanager.PassManager``; ``jit`` wraps a module in
    ``execution_engine.ExecutionEngine``.
    """

    # Number of leading / trailing IR lines kept when ``_process_error``
    # truncates the IR context attached to an error message.
    _IR_HEAD_LINES = 5
    _IR_TAIL_LINES = 5

    def __init__(self, passmanager, execution_engine):
        """
        :param passmanager: Module exposing ``PassManager`` (with ``parse``).
        :param execution_engine: Module exposing ``ExecutionEngine``.
        """
        self.passmanager = passmanager
        self.execution_engine = execution_engine

    def __call__(self, module, *args, **kwargs):
        """Convenience application method; forwards everything to ``compile``.

        ``compile`` requires a ``pipeline``, so it must be supplied here too
        (positionally or by keyword).
        """
        return self.compile(module, *args, **kwargs)

    def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
        """Process error message to extract NVVM error and IR context.

        :param error_msg: Text of the exception raised by the pass manager.
        :return: ``(nvvm_error, ir_msg)``. ``nvvm_error`` is ``None`` unless
            ``error_msg`` contains ``"NVVM_ERROR"``; it is the text after
            ``"libNVVM extra log:"`` when that marker exists, else the whole
            message. ``ir_msg`` is the IR following
            ``"see current operation:"`` (trimmed to a head/tail excerpt when
            long), or ``""`` when that marker is absent.
        """
        nvvm_error = None
        ir_msg = ""

        log_marker = "libNVVM extra log:"
        if "NVVM_ERROR" in error_msg:
            # Extract the specific NVVM error
            nvvm_error = (
                error_msg.split(log_marker)[1].strip()
                if log_marker in error_msg
                else error_msg
            )

        op_marker = "see current operation:"
        if op_marker in error_msg:
            # Get the IR section
            ir_section = error_msg.split(op_marker)[1].strip()
            # Remove the duplicated IR that follows the translation failure
            ir_section = ir_section.split("error: unknown: Failed translating")[
                0
            ].strip()

            # Keep only the first and last few lines of long IR
            ir_lines = ir_section.split("\n")
            head, tail = self._IR_HEAD_LINES, self._IR_TAIL_LINES
            if len(ir_lines) > head + tail:
                # Slice by explicit start index so a tail of 0 stays empty.
                ir_msg = "\n".join(
                    ir_lines[:head] + [" ..."] + ir_lines[len(ir_lines) - tail :]
                )
            else:
                ir_msg = ir_section

        return nvvm_error, ir_msg

    def compile(
        self,
        module,
        pipeline: str,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compiles the module by invoking the pipeline.

        :param module: Object whose ``.operation`` is passed to ``PassManager.run``.
        :param pipeline: Textual pass pipeline given to ``PassManager.parse``.
        :param cuda_toolkit: Toolkit path reported in a ``CompilationError``.
        :param arch: Target architecture reported in a ``CompilationError``.
        :param enable_verifier: Forwarded to ``PassManager.enable_verifier``.
        :raises CompilationError: If the failure message yields an NVVM error
            (see ``_process_error``); chained to the original exception.
        :raises Exception: Any other failure is re-raised unchanged.
        """
        try:
            pm = self.passmanager.PassManager.parse(pipeline)
            pm.enable_verifier(enable_verifier)
            pm.run(module.operation)
        except Exception as e:
            error_msg = str(e)
            nvvm_error, ir_msg = self._process_error(error_msg)

            if nvvm_error:
                raise CompilationError(
                    error_msg,
                    nvvm_error=nvvm_error,
                    ir_context=ir_msg,
                    cuda_toolkit=cuda_toolkit,
                    arch=arch,
                ) from e
            # Not an NVVM failure: propagate the original exception as-is.
            raise

    def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
        """Wraps the module in a JIT execution engine.

        :return: ``execution_engine.ExecutionEngine(module, opt_level=...,
            shared_libs=...)``.
        """
        return self.execution_engine.ExecutionEngine(
            module, opt_level=opt_level, shared_libs=shared_libs
        )

    def compile_and_jit(
        self,
        module,
        pipeline: str,
        shared_libs: Sequence[str] = (),
        opt_level: int = 2,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compiles and jits the module.

        Runs ``compile`` (including its error handling) and then returns the
        result of ``jit`` on the same module.
        """
        # Keyword arguments keep this call correct even if the parameter
        # order of ``compile``/``jit`` changes.
        self.compile(
            module,
            pipeline,
            cuda_toolkit=cuda_toolkit,
            arch=arch,
            enable_verifier=enable_verifier,
        )
        return self.jit(module, opt_level=opt_level, shared_libs=shared_libs)
|
|
|
|
|
|
def compile(func, *args, **kwargs):
    """Invoke a jit-decorated callable in compile-only mode.

    ``func`` may be a plain function, a bound method, or an instance whose
    ``__call__`` is a bound Python method. For methods and callable
    instances, the bound object is prepended to ``args`` and the underlying
    function is used instead. A ``__wrapped__`` attribute (as set by
    ``functools.wraps``) is followed one level.

    ``compile_only=True`` and ``no_cache=True`` are forced into ``kwargs``,
    overriding any values the caller passed for those keys.

    :param func: The jit-decorated callable.
    :param args: Positional arguments forwarded to ``_dsl_object._func``.
    :param kwargs: Keyword arguments forwarded to ``_dsl_object._func``.
    :return: Whatever ``func._dsl_object._func`` returns.
    :raises DSLRuntimeError: If ``func`` is ``None``, not callable, of an
        unsupported kind, or not decorated with the jit decorator.
    """
    if func is None:
        raise DSLRuntimeError("Function is not set or invalid.")

    if not callable(func):
        raise DSLRuntimeError("Object is not callable.")

    kwargs["compile_only"] = True
    kwargs["no_cache"] = True

    if inspect.isfunction(func):
        # regular function: nothing to unwrap
        pass
    elif inspect.ismethod(func):
        # bound method: pass the instance explicitly as the first argument
        args = [func.__self__, *args]
        func = func.__func__
    else:
        # Callable instance: use the function behind its bound ``__call__``.
        # Builtins, ``functools.partial`` objects, classes, etc. expose a
        # ``__call__`` without ``__func__``; reject them with a DSL error
        # instead of letting an AttributeError escape.
        unbound_call = getattr(getattr(func, "__call__", None), "__func__", None)
        if unbound_call is None:
            raise DSLRuntimeError(
                "Invalid function type, only function, method and module are supported, but got",
                func,
            )
        args = [func, *args]
        func = unbound_call

    # If it's a wrapped function created by jit decorator, get the original function
    if hasattr(func, "__wrapped__"):
        func = func.__wrapped__

    if not hasattr(func, "_dsl_object"):
        raise DSLRuntimeError("Function is not decorated with jit decorator.")

    fcn_ptr = func._dsl_object._preprocess_and_execute(func)
    return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|