Files
cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py
2025-05-13 15:55:29 -04:00

189 lines
5.8 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides runtime utilities for JIT argument conversion in DSL.
"""
from functools import wraps
from typing import get_origin
# Local modules imports
from ..common import DSLRuntimeError
from ..typing import (
Constexpr,
Int32,
Float32,
Boolean,
)
def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
    """
    Check if the argument spec is a constexpr.

    An argument is treated as a constexpr when any of the following holds:

    - it is the reserved first argument of a method (``self``, or ``cls`` of a
      classmethod);
    - its annotation is the ``Constexpr`` class or a subclass of it;
    - its annotation's ``typing.get_origin`` is ``Constexpr``
      (a parameterized ``Constexpr[...]`` form).

    :param arg_spec: The annotation attached to the argument (may be ``None``).
    :param arg_name: The argument's name.
    :param arg_index: Zero-based position of the argument in the signature.
    :param owning_func: The callable the argument belongs to.
    :return: ``True`` if the argument should be treated as a constexpr.
    """

    def _is_reserved_python_func_arg(arg_index, arg_name, func):
        """
        Check if the argument is a reserved python function argument, i.e. the
        position-0 ``self``, or the position-0 ``cls`` of a classmethod.
        """
        if arg_index != 0:
            return False
        if arg_name == "self":
            return True
        if arg_name != "cls":
            return False
        # A raw ``classmethod`` object (e.g. fetched from a class ``__dict__``).
        if isinstance(func, classmethod):
            return True
        inner = getattr(func, "__func__", None)
        if inner is None:
            return False
        # A method object whose underlying callable is a ``classmethod``.
        if isinstance(inner, classmethod):
            return True
        # A classmethod accessed via its class or an instance (``Cls.meth``) is
        # a bound method whose ``__func__`` is the plain function, so the check
        # above never matches it; it is recognized by ``__self__`` being a class.
        return isinstance(getattr(func, "__self__", None), type)

    return (
        _is_reserved_python_func_arg(arg_index, arg_name, owning_func)
        or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr))
        or (get_origin(arg_spec) is Constexpr)
    )
def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func):
    """
    Decide whether a concrete argument value should be handled as a constexpr.

    The checks run in this order, and the first match wins:

    1. the argument's spec marks it as constexpr (``is_arg_spec_constexpr``);
    2. the value is itself a class and the annotation is either absent or a
       form whose ``typing.get_origin`` is ``type`` (e.g. ``Type[X]``);
    3. the value is ``None``.

    :return: ``True`` if any check matches, ``False`` otherwise.
    """
    if is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
        return True
    # A class passed as a value, e.g. for a ``Type[X]``-annotated parameter.
    if isinstance(arg, type):
        if arg_spec is None or get_origin(arg_spec) is type:
            return True
    return arg is None
class JitArgAdapterRegistry:
    """
    A registry to keep track of the JIT argument adapters.

    An adapter is a callable that converts a Python type to a type with the
    following protocols supported:

    - JitArgument
    - DynamicExpression

    The converted type can then be further processed by DSL to generate
    arguments for JIT functions.
    """

    # Maps a Python type (exact key; no subclass matching) to its adapter
    # callable. It is a class attribute, so every caller shares one registry.
    jit_arg_adapter_registry = {}

    @classmethod
    def register_jit_arg_adapter(cls, *dargs, **dkwargs):
        """
        Register a JIT argument adapter callable.

        This is used as a decorator factory on any callable, e.g.::

            @register_jit_arg_adapter(my_py_type)
            def my_adapter_for_my_py_type(arg):
                ...

            @register_jit_arg_adapter(my_py_type)
            class MyAdapterForMyPythonType:
                ...

        Adapters are registered per type. If a type is already registered,
        an error is raised.

        :param dargs: The first positional argument is the Python type the
            adapter handles; further positional arguments are ignored.
        :param dkwargs: Accepted for compatibility; currently ignored.
        :return: A decorator that registers the decorated callable and returns
            it unchanged.
        :raises DSLRuntimeError: If no Python type is given, if the returned
            decorator is not applied to exactly one callable, or if the type
            already has an adapter registered.
        """
        if not dargs:
            raise DSLRuntimeError(
                "a Python type must be provided for registering JIT argument adapter"
            )
        python_ty = dargs[0]

        def decorator(*args, **kwargs):
            # Extra keyword arguments are accepted and ignored.
            if len(args) != 1 or not callable(args[0]):
                raise DSLRuntimeError(
                    "a callable must be provided for registering JIT argument adapter"
                )
            adapter = args[0]
            registry = cls.jit_arg_adapter_registry
            if python_ty in registry:
                raise DSLRuntimeError(
                    f"JIT argument adapter for {python_ty} is already registered!",
                    context={
                        "Registered adapter": registry[python_ty],
                        "Adapter to be registered": adapter,
                    },
                )
            registry[python_ty] = adapter
            # Return the adapter itself so the decorated name stays usable.
            return adapter

        return decorator

    @classmethod
    def get_registered_adapter(cls, ty):
        """
        Get the registered JIT argument adapter for the given type.

        Lookup is by exact type: an adapter registered for a base class is not
        returned for its subclasses.

        :param ty: The Python type to look up.
        :return: The registered adapter, or ``None`` if there is none.
        """
        return cls.jit_arg_adapter_registry.get(ty)
# =============================================================================
# JIT Argument Adapters
# =============================================================================
@JitArgAdapterRegistry.register_jit_arg_adapter(int)
@JitArgAdapterRegistry.register_jit_arg_adapter(float)
@JitArgAdapterRegistry.register_jit_arg_adapter(bool)
def _convert_python_scalar(arg):
    """
    Convert a Python scalar to a DSL type.

    ``int`` becomes ``Int32``, ``float`` becomes ``Float32`` and ``bool``
    becomes ``Boolean``. The lookup uses the exact type of ``arg``, matching
    how the registry dispatches. As a result, a ``bool`` maps to ``Boolean``
    even though ``bool`` subclasses ``int``.

    :param arg: A value whose type is exactly ``int``, ``float`` or ``bool``.
    :return: The value wrapped in the corresponding DSL scalar type.
    :raises TypeError: If ``type(arg)`` is not exactly one of those types.
    """
    conversion_map = {
        int: Int32,
        float: Float32,
        bool: Boolean,
    }
    dsl_type = conversion_map.get(type(arg))
    if dsl_type is None:
        # Fail with a descriptive message instead of attempting to call None.
        raise TypeError(
            f"cannot convert Python scalar of type {type(arg).__name__!r} "
            "to a DSL type; expected exactly int, float or bool"
        )
    return dsl_type(arg)
@JitArgAdapterRegistry.register_jit_arg_adapter(tuple)
@JitArgAdapterRegistry.register_jit_arg_adapter(list)
def _convert_python_sequence(arg):
    """
    Adapt every element of a list or tuple for JIT argument generation.

    Each element is passed through the adapter registered for its exact type.
    Because this function is itself registered for ``list`` and ``tuple``,
    nested sequences are adapted recursively. Elements whose type has no
    registered adapter are kept as-is. The result is rebuilt with the same
    container type as ``arg``.
    """

    def _adapt_one(item):
        # Fall back to the element itself when its type has no adapter.
        convert = JitArgAdapterRegistry.get_registered_adapter(type(item))
        return item if convert is None else convert(item)

    converted = [_adapt_one(item) for item in arg]
    assert len(converted) == len(arg)
    return type(arg)(converted)