358 lines
14 KiB
Python
358 lines
14 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides jit executor related classes
|
|
"""
|
|
import ctypes
|
|
import inspect
|
|
import io
|
|
from typing import get_origin
|
|
|
|
import numpy as np
|
|
|
|
# MLIR modules imports
|
|
from .._mlir import ir
|
|
|
|
# Local modules imports
|
|
from . import typing as t
|
|
from .common import DSLRuntimeError
|
|
from .runtime import cuda as cuda_helpers
|
|
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
|
|
from .typing import get_c_pointers
|
|
from .utils.logger import log
|
|
from .utils.timer import timer
|
|
|
|
|
|
class CudaSingleModule:
    """Pairs one loaded CUDA module handle with a kernel function pointer resolved from it."""

    def __init__(self, cuda_module, kernel_ptr):
        """Record the module handle and the kernel pointer for later launch/unload."""
        self.kernel_ptr = kernel_ptr
        self.cuda_module = cuda_module
|
|
|
|
|
|
class CudaModules:
    """Groups the preloaded CUDA modules with the extra launch arguments derived from them."""

    def __init__(self, modules, args):
        """
        :param modules: iterable of CudaSingleModule instances.
        :param args: extra kernel-pointer arguments appended at launch time.
        """
        self.args = args
        self.modules = modules
|
|
|
|
|
|
class JitExecutor:
|
|
def __init__(
    self,
    dsl,
    engine,
    capi_func,
    ir_module,
    args_spec,
    function_name,
    cuda_modules: CudaModules = None,
    jit_time_profiling=False,
):
    """
    Container for an execution-ready jit-compiled function.

    :param dsl: owning DSL instance that produced this executor.
    :param engine: execution engine holding the compiled artifact.
    :param capi_func: C API entry point invoked with packed pointer args.
    :param ir_module: MLIR module the function was compiled from.
    :param args_spec: full signature (``inspect.FullArgSpec``) of the jit
        function, or ``None`` when no signature information is available.
    :param function_name: name of the jit-compiled function.
    :param cuda_modules: preloaded CUDA modules plus extra kernel-pointer
        launch arguments, if any.
    :param jit_time_profiling: when True, time packing and invocation.
    """
    self.dsl = dsl
    self.engine = engine
    self.capi_func = capi_func
    self.ir_module = ir_module
    self.function_name = function_name
    # Always keep the unfiltered signature: get_constexpr_args reads
    # self.original_args_spec unconditionally, so leaving it unset when
    # args_spec is None (as the original code did) raised AttributeError.
    self.original_args_spec = args_spec
    self.args_spec = args_spec
    if args_spec is not None:
        # Expose only the runtime (non-constexpr) arguments to callers.
        self.args_spec = self.filter_runtime_arg_spec(args_spec)
    # cuda kernels
    self.cuda_modules = cuda_modules
    self.jit_time_profiling = jit_time_profiling
|
|
|
|
def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
    """
    Return a copy of ``arg_spec`` with compile-time (constexpr) arguments
    removed, keeping only the arguments that must be supplied at runtime.

    Positional defaults, keyword-only defaults, and annotations are filtered
    alongside their arguments so the returned spec stays self-consistent.

    :param arg_spec: full signature of the original jit function.
    :return: a new ``inspect.FullArgSpec`` containing only runtime arguments.
    """
    runtime_args = []
    runtime_annotations = {}
    runtime_defaults = []

    # Index in the original positional list where defaults begin.
    if arg_spec.defaults:
        defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
    else:
        defaults_start_idx = len(arg_spec.args)

    # Filter positional arguments while keeping defaults/annotations aligned.
    for i, arg_name in enumerate(arg_spec.args):
        arg_type = arg_spec.annotations.get(arg_name, None)

        # Skip compile-time arguments
        if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
            continue

        # Keep runtime arguments
        runtime_args.append(arg_name)
        if arg_name in arg_spec.annotations:
            runtime_annotations[arg_name] = arg_type

        # Keep corresponding default if it exists
        if i >= defaults_start_idx:
            runtime_defaults.append(arg_spec.defaults[i - defaults_start_idx])

    # Filter kwonlyargs and their defaults
    runtime_kwonlyargs = []
    runtime_kwonlydefaults = {}

    if arg_spec.kwonlyargs:
        # Continue argument numbering after the positional arguments. The
        # original code reused the stale loop variable ``i`` here, which
        # passed the last positional index for every keyword-only argument
        # and raised NameError when there were no positional args at all.
        for kw_idx, kwarg in enumerate(
            arg_spec.kwonlyargs, start=len(arg_spec.args)
        ):
            arg_type = arg_spec.annotations.get(kwarg, None)

            # Apply same filtering logic as for positional arguments
            if is_arg_spec_constexpr(arg_type, kwarg, kw_idx, self.function_name):
                continue

            runtime_kwonlyargs.append(kwarg)
            if kwarg in arg_spec.annotations:
                runtime_annotations[kwarg] = arg_type
            if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
                runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg]

    # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec)
    runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None

    return inspect.FullArgSpec(
        args=runtime_args,
        varargs=arg_spec.varargs,  # Keep original varargs
        varkw=arg_spec.varkw,  # Keep original varkw
        defaults=runtime_defaults,
        kwonlyargs=runtime_kwonlyargs,
        kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None,
        annotations=runtime_annotations,
    )
|
|
|
|
def __del__(self):
    """Unload every distinct CUDA module held by this executor on teardown."""
    if not self.cuda_modules:
        return
    # De-duplicate first: several kernels may share one loaded module.
    unique_handles = {single.cuda_module for single in self.cuda_modules.modules}
    for handle in unique_handles:
        cuda_helpers.unload_cubin_module(handle)
|
|
|
|
def get_constexpr_args(self) -> list[dict[str, int | str]]:
    """
    This function returns the constexpr args that have been pruned from the original function signature.
    The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).

    :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
    :rtype: list[dict[str, int | str]]
    """
    if self.original_args_spec is None:
        return []

    # Positional constexprs: anything present originally but pruned from the
    # runtime spec, reported with its original position.
    pruned = [
        {"argument_index": index, "argument_name": name}
        for index, name in enumerate(self.original_args_spec.args)
        if name not in self.args_spec.args
    ]

    # Keyword-only constexprs carry no positional index.
    for kwarg in self.original_args_spec.kwonlyargs or []:
        if kwarg not in self.args_spec.kwonlyargs:
            pruned.append({"argument_index": None, "argument_name": kwarg})

    return pruned
|
|
|
|
def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
    """
    This function is the prune version of `generate_mlir_function_types` which only generates execution args
    to get rid of mlir context.

    :param args: positional arguments supplied by the caller.
    :param kwargs: keyword arguments supplied by the caller.
    :param args_spec: runtime function signature (already constexpr-pruned).
    :return: tuple ``(exe_args, adapted_args)`` — the flat list of C pointers
        to invoke with, and the adapter results that must stay alive for the
        duration of the call.
    :raises DSLRuntimeError: if args/kwargs do not match the signature.
    """
    num_positional = len(args_spec.args)
    defaults = args_spec.defaults or ()
    defaults_start_idx = num_positional - len(defaults)

    # Resolve each positional slot not covered by *args from kwargs or the
    # declared default. The original implementation appended keyword values
    # in caller order and sliced defaults from the wrong offset, which could
    # misalign arguments (e.g. f(1, b=2) with a defaulted third parameter
    # falsely failed the length check below).
    rectified_args = list(args)
    for idx in range(len(args), num_positional):
        name = args_spec.args[idx]
        if name in kwargs:
            rectified_args.append(kwargs[name])
        elif idx >= defaults_start_idx:
            rectified_args.append(defaults[idx - defaults_start_idx])
        # else: missing required argument -> caught by the length check below
    # Preserve the original behavior of letting a keyword override a value
    # that was also passed positionally.
    for name, value in kwargs.items():
        if name in args_spec.args:
            idx = args_spec.args.index(name)
            if idx < len(args):
                rectified_args[idx] = value

    # Process keyword-only arguments in signature order so their values line
    # up with args_spec.kwonlyargs when zipped below; fill missing ones from
    # kwonlydefaults without overwriting caller-supplied values (the original
    # dict.update clobbered user values with defaults).
    rectified_kwargs = {}
    for name in args_spec.kwonlyargs:
        if name in kwargs:
            rectified_kwargs[name] = kwargs[name]
        elif args_spec.kwonlydefaults and name in args_spec.kwonlydefaults:
            rectified_kwargs[name] = args_spec.kwonlydefaults[name]
    # Keep unexpected keywords so the signature check below reports them.
    for name, value in kwargs.items():
        if name not in args_spec.args and name not in rectified_kwargs:
            rectified_kwargs[name] = value

    # args/kwargs must match arg_specs
    if len(rectified_args) != num_positional or len(rectified_kwargs) != len(
        args_spec.kwonlyargs
    ):
        raise DSLRuntimeError(
            "input args/kwargs length does not match runtime function signature!",
            context={
                "input args length": len(rectified_args),
                "input kwargs length": len(rectified_kwargs),
                "function signature args length": num_positional,
                "function signature kwonlyargs length": len(args_spec.kwonlyargs),
            },
        )

    exe_args = []
    adapted_args = []
    input_args = rectified_args + list(rectified_kwargs.values())
    input_arg_names = args_spec.args + args_spec.kwonlyargs
    for arg, arg_name in zip(input_args, input_arg_names):
        # short-cut for args already converted
        if hasattr(arg, "__c_pointers__"):
            exe_args.extend(arg.__c_pointers__())
            continue

        arg_type = args_spec.annotations.get(arg_name, None)

        # Implicit cast to NumericMeta
        if isinstance(arg_type, t.NumericMeta):
            arg = t.cast(arg, arg_type)
        else:
            # If not any known type, try registered adapter to do the conversion
            adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
            if adapter:
                arg = adapter(arg)
                # Keep a reference so the adapted object outlives the launch.
                adapted_args.append(arg)

        exe_args.extend(get_c_pointers(arg))

    return exe_args, adapted_args
|
|
|
|
def __call__(self, *args, **kwargs):
    """Invoke the compiled program with the caller's arguments."""
    # adapted_args is kept alive in this frame until the launch returns.
    exe_args, adapted_args = self.generate_execution_args(
        args, kwargs, self.args_spec
    )
    self.run_compiled_program(exe_args)
|
|
|
|
def get_invoke_packed_args(self, exe_args):
    """
    Pack execution arguments into a ctypes array for the C API invocation.

    Each entry is assumed to already be a ``ctypes.c_void_p`` to reduce the
    overhead of ``ctypes.cast``. The extra kernel-pointer arguments from the
    preloaded CUDA modules are appended without mutating the caller's list
    (the original ``+=`` grew the caller's list in place on every call).

    :param exe_args: list of ``ctypes.c_void_p`` execution arguments.
    :return: a ``ctypes.c_void_p`` array ready to pass to the C API function.
    """
    if self.cuda_modules:
        exe_args = exe_args + self.cuda_modules.args
    packed_args = (ctypes.c_void_p * len(exe_args))()
    for idx, value in enumerate(exe_args):
        packed_args[idx] = value
    return packed_args
|
|
|
|
def run_compiled_program(self, exe_args):
    """
    Pack ``exe_args`` and invoke the compiled C API function.

    When ``jit_time_profiling`` is enabled, both the packing step and the
    C API call are wrapped with the profiling timer; otherwise they run
    undecorated. Any failure is re-raised as a DSLRuntimeError.
    """
    # An identity wrapper when profiling is off keeps a single code path;
    # the original duplicated the whole try/except for each mode.
    wrap = timer(enable=True) if self.jit_time_profiling else (lambda func: func)
    try:
        packed_args = wrap(self.get_invoke_packed_args)(exe_args)
        wrap(self.capi_func)(packed_args)
    except Exception as e:
        raise DSLRuntimeError("💥💥💥 Runtime Crash 💥💥💥", cause=e)
|
|
|
|
def update_jit_cuda_modules(self, kernel_symbols):
    """
    Preload CUDA modules for the given kernel symbols and cache the results.

    For each symbol, the IR module is walked to find its embedded cubin,
    the cubin is loaded as a CUDA module, and the resolved kernel function
    pointer is collected into ``self.cuda_modules`` as an extra launch
    argument. Returns ``self`` for chaining.

    :param kernel_symbols: symbols of the kernels to preload.
    :return: this executor instance.
    """
    # preload cuda module from compiled cubin in ir and store to jit_executor.kernels.
    if len(kernel_symbols) > 0:
        extra_args = []
        module = self.ir_module
        # symbol -> CudaSingleModule; avoids reloading a module when the same
        # symbol appears more than once in kernel_symbols.
        cuda_kernel_cache = dict()
        cuda_driver_version = cuda_helpers.get_driver_version()
        for sym in kernel_symbols:
            if sym not in cuda_kernel_cache:
                log().debug(f"Loading CUDA module for symbol: {sym}")

                # load cuda module/get function pointer from module and cache
                # NOTE: invoked synchronously by walk_module_and_get_cubin_data
                # below, so capturing the loop variables here is safe.
                def walk_callback(sym, func_sym, cubin_data):
                    cubin_module = cuda_helpers.load_cubin_module_data(cubin_data)
                    kernel_ptr = cuda_helpers.get_kernel_function(
                        cubin_module, func_sym
                    )
                    # Enable non-portable cluster size for CUDA version 11.8 or higher.
                    if cuda_driver_version >= 11080:
                        cuda_helpers.set_kernel_attribute(
                            kernel_ptr,
                            cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
                            1,
                        )
                    cuda_kernel_cache[sym] = CudaSingleModule(
                        cubin_module, kernel_ptr
                    )

                self.walk_module_and_get_cubin_data(module, sym, walk_callback)
            else:
                log().debug(f"Symbol {sym} already in cache")
            # check if kernel is empty.
            # (the walk may not have found a cubin for this symbol, in which
            # case the cache has no entry and no launch arg is added)
            if sym in cuda_kernel_cache:
                extra_args.append(
                    ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr())
                )
        # store to the jit result if jit result is cached.
        self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args)

    return self
|
|
|
|
def _get_escaped_cubin_bytes(self, cubin_data):
    """
    This function escapes cubin data from mlir raw bytecode to executable binary bytes.

    The escaped form encodes an arbitrary byte as ``\\XY`` (two hex digits)
    and a literal backslash as ``\\\\``; every other byte is copied through.

    :param cubin_data: escaped cubin payload as ``bytes``.
    :return: the unescaped binary payload as ``bytes``.
    """

    def ishex(inp):
        # ASCII '0'-'9', 'a'-'f', 'A'-'F'
        return (
            inp in range(0x30, 0x3A)
            or inp in range(0x61, 0x67)
            or inp in range(0x41, 0x47)
        )

    converted = bytearray()
    idx = 0
    size = len(cubin_data)
    while idx < size:
        byte = cubin_data[idx]
        if (
            byte == 0x5C  # b'\\'
            and idx + 2 < size
            and ishex(cubin_data[idx + 1])
            and ishex(cubin_data[idx + 2])
        ):
            # "\XY" hex escape -> single byte
            converted += bytearray.fromhex(cubin_data[idx + 1 : idx + 3].decode())
            idx += 3
        elif byte == 0x5C and idx + 1 < size and cubin_data[idx + 1] == 0x5C:
            # "\\" -> single backslash
            converted.append(byte)
            idx += 2
        else:
            # No escape: copy the byte through. This branch now also covers a
            # backslash followed by a non-escape byte or truncated at the end
            # of the buffer; the original matched no branch in that case and
            # never advanced idx (infinite loop) or indexed past the end.
            converted.append(byte)
            idx += 1
    return bytes(converted)
|
|
|
|
def walk_module_and_get_cubin_data(self, module, sym, callback):
    """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.

    :param module: MLIR module to walk.
    :param sym: kernel symbol to look for inside gpu.binary ops.
    :param callback: called as ``callback(sym, func_sym, cubin_data)`` with
        the unescaped cubin bytes of the matching op.
    """

    def walk_gpu_binary_op(op):
        # Only gpu.binary ops carry embedded cubins.
        if op.name != "gpu.binary":
            return ir.WalkResult.ADVANCE
        # Serialize the op to bytecode so the cubin payload can be searched
        # and sliced out of the textual attribute below.
        s = io.BytesIO()
        op.write_bytecode(s)
        cubin_data = s.getvalue()
        if sym.encode() not in cubin_data:
            return ir.WalkResult.ADVANCE

        # Accept only the shared "kernels" container op or the op named
        # exactly after the requested symbol.
        if (
            "kernels" != op.opview.sym_name.value
            and sym != op.opview.sym_name.value
        ):
            return ir.WalkResult.ADVANCE
        # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir
        func_sym = sym
        # NOTE(review): when the op is named after the symbol but lacks the
        # "_kernel" suffix, the trailing "_<n>" looks like a disambiguating
        # suffix that must be stripped to recover the function symbol —
        # confirm against the kernel-naming convention used at compile time.
        if sym == op.opview.sym_name.value and not sym.endswith("_kernel"):
            func_sym = sym.rsplit("_", 1)[0]

        # Slice the escaped cubin payload out of the `bin = "..."` attribute
        # and unescape it into raw binary bytes.
        cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0]
        cubin_data = self._get_escaped_cubin_bytes(cubin_data)
        callback(sym, func_sym, cubin_data)
        return ir.WalkResult.ADVANCE

    module.operation.walk(walk_gpu_binary_op)
|