189 lines
5.8 KiB
Python
189 lines
5.8 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides runtime utilities for JIT argument conversion in DSL.
|
|
"""
|
|
|
|
from functools import wraps
|
|
from typing import get_origin
|
|
|
|
# Local modules imports
|
|
from ..common import DSLRuntimeError
|
|
from ..typing import (
|
|
Constexpr,
|
|
Int32,
|
|
Float32,
|
|
Boolean,
|
|
)
|
|
|
|
|
|
def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
    """
    Decide from the annotation alone whether an argument is a constexpr.

    The answer is ``True`` when any of the following holds, checked in order:

    - the argument is the implicit receiver of a Python callable, i.e. the
      first positional argument named ``self``, or named ``cls`` while
      ``owning_func`` is (or wraps, via ``__func__``) a ``classmethod``;
    - ``arg_spec`` is a class deriving from ``Constexpr``;
    - ``arg_spec`` is a subscripted annotation whose origin is ``Constexpr``.
    """

    def _is_implicit_receiver(position, name, func):
        """
        Report whether (position, name) denotes ``self`` or a classmethod's ``cls``.
        """
        # Only the very first positional slot can hold the receiver.
        if position == 0:
            if name == "self":
                return True

            targets_classmethod = isinstance(func, classmethod) or isinstance(
                getattr(func, "__func__", None), classmethod
            )
            return targets_classmethod and name == "cls"

        return False

    if _is_implicit_receiver(arg_index, arg_name, owning_func):
        return True

    if isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr):
        return True

    return get_origin(arg_spec) is Constexpr
|
|
|
|
|
|
def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func):
    """
    Decide whether a concrete argument value is treated as a constexpr.

    Beyond what ``is_arg_spec_constexpr`` reports for the annotation, a value
    also counts as constexpr when:

    - it is itself a class and the annotation is either missing (``None``) or
      a ``Type[X]``-style annotation (its origin is ``type``);
    - it is ``None``.
    """
    # The annotation-based check comes first.
    if is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
        return True

    # A class passed as a value, e.g. for a parameter annotated Type[X].
    if isinstance(arg, type):
        if arg_spec is None or get_origin(arg_spec) is type:
            return True

    return arg is None
|
|
|
|
|
|
class JitArgAdapterRegistry:
    """
    A registry to keep track of the JIT argument adapters.

    An adapter is a callable that converts a Python type to a type with following protocols supported:
    - JitArgument
    - DynamicExpression
    The converted type can then be further processed by DSL to generate arguments for JIT functions.
    """

    # A dictionary with key=type and value=callable.
    # It lives on the class, so every user of the registry shares one table.
    jit_arg_adapter_registry = {}

    @classmethod
    def register_jit_arg_adapter(cls, *dargs, **dkwargs):
        """
        Register a JIT argument adapter callable

        This can be used as a decorator on any callable like:

            @register_jit_arg_adapter(my_py_type)
            def my_adapter_for_my_py_type(arg):
                ...

            @register_jit_arg_adapter(my_py_type)
            class MyAdapterForMyPythonType:
                ...

        The adapters are registered per type. If a type is already registered, an error will be raised.

        Only the first positional argument is used as the registry key; any
        further positional or keyword arguments are accepted and ignored.

        Returns a decorator that stores the decorated callable in the registry
        and hands it back unchanged, so the decorated name still refers to the
        original callable.

        Raises DSLRuntimeError when no Python type is given, when the decorator
        is not applied to exactly one callable, or when the type already has an
        adapter.
        """
        if len(dargs) == 0:
            raise DSLRuntimeError(
                "a Python type must be provided for registering JIT argument adapter"
            )

        python_ty = dargs[0]

        def register(*args, **kwargs):
            if len(args) != 1 or not callable(args[0]):
                raise DSLRuntimeError(
                    "a callable must be provided for registering JIT argument adapter"
                )
            adapter = args[0]

            # Registration is first-come-first-served: never overwrite silently.
            if python_ty in cls.jit_arg_adapter_registry:
                raise DSLRuntimeError(
                    f"JIT argument adapter for {python_ty} is already registered!",
                    context={
                        "Registered adapter": cls.jit_arg_adapter_registry[python_ty],
                        "Adapter to be registered": adapter,
                    },
                )

            cls.jit_arg_adapter_registry[python_ty] = adapter
            return adapter

        return register

    @classmethod
    def get_registered_adapter(cls, ty):
        """
        Get the registered JIT argument adapter for the given type.

        The lookup uses ``ty`` itself as the key (no base-class search) and
        returns ``None`` when nothing is registered for it.
        """
        return cls.jit_arg_adapter_registry.get(ty, None)
|
|
|
|
|
|
# =============================================================================
|
|
# JIT Argument Adapters
|
|
# =============================================================================
|
|
|
|
|
|
@JitArgAdapterRegistry.register_jit_arg_adapter(int)
@JitArgAdapterRegistry.register_jit_arg_adapter(float)
@JitArgAdapterRegistry.register_jit_arg_adapter(bool)
def _convert_python_scalar(arg):
    """
    Convert a Python scalar to a DSL type.

    The exact runtime type of ``arg`` selects the DSL type:
    ``int`` -> ``Int32``, ``float`` -> ``Float32``, ``bool`` -> ``Boolean``.
    Because the lookup uses ``type(arg)`` rather than ``isinstance``, a
    ``bool`` (which is a subclass of ``int``) maps to ``Boolean``.

    Raises TypeError when ``type(arg)`` is not one of the three supported
    scalar types.
    """
    conversion_map = {
        int: Int32,
        float: Float32,
        bool: Boolean,
    }
    dsl_type = conversion_map.get(type(arg))
    if dsl_type is None:
        # Fail with a readable message instead of trying to call ``None``.
        raise TypeError(
            f"cannot convert Python scalar of type {type(arg).__name__} to a DSL type; "
            "expected int, float or bool"
        )
    return dsl_type(arg)
|
|
|
|
|
|
@JitArgAdapterRegistry.register_jit_arg_adapter(tuple)
@JitArgAdapterRegistry.register_jit_arg_adapter(list)
def _convert_python_sequence(arg):
    """
    Adapt every element of a sequence and rebuild a sequence of the same type.

    Each element is looked up in the adapter registry by its own type; when an
    adapter exists the element is replaced by the adapter's result, so that DSL
    can later turn it into the corresponding JIT argument(s).
    """

    def _adapt(item):
        convert = JitArgAdapterRegistry.get_registered_adapter(type(item))
        # Elements without a registered adapter pass through untouched.
        return item if convert is None else convert(item)

    converted = [_adapt(item) for item in arg]

    assert len(converted) == len(arg)
    return type(arg)(converted)
|