1963 lines
63 KiB
Python
1963 lines
63 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
import ctypes
|
|
import numpy as np
|
|
import operator
|
|
from typing_extensions import deprecated
|
|
from functools import reduce
|
|
from typing import (
|
|
Generic,
|
|
Protocol,
|
|
Union,
|
|
Any,
|
|
List,
|
|
Type,
|
|
TypeVar,
|
|
overload,
|
|
runtime_checkable,
|
|
get_origin,
|
|
)
|
|
from types import FunctionType
|
|
from dataclasses import dataclass
|
|
from abc import ABC, abstractmethod
|
|
|
|
from .common import *
|
|
from .ast_helpers import const_expr
|
|
from ._mlir_helpers import arith as arith_helper, lru_cache_ir
|
|
from ._mlir_helpers.arith import ArithValue
|
|
|
|
from .._mlir import ir
|
|
from .._mlir.extras import types as T
|
|
from .._mlir.dialects import arith, math
|
|
|
|
# =============================================================================
|
|
# Dynamic Expression Protocol
|
|
# =============================================================================
|
|
|
|
|
|
@runtime_checkable
class DynamicExpression(Protocol):
    """Protocol defining the interface for objects holding dynamic values in the DSL.

    This protocol enables classes to represent dynamic values in the DSL. Classes
    implementing this protocol can be used in JIT-compiled functions and dynamic
    value generation.

    It is required for custom data types to work correctly with the following JIT
    features:

    * as a function argument when calling another JIT function from a JIT function
    * as a return value from a JIT function
    * for constructions like if-else, while-loop, etc.

    :param value: The MLIR operation result value to initialize the object with
    :type value: ir.Value

    **Required Methods**

    * ``__extract_mlir_values__``: Extract MLIR values from the object
    * ``__new_from_mlir_values__``: Create new instance from MLIR values

    **Implementation Example**

    To implement a custom data type that works with the DSL:

    .. code-block:: python

        class CustomData(metaclass=DslType):
            def __init__(self, int_value):
                self.int_value = int_value

            def __extract_mlir_values__(self):
                return [self.int_value]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0])

    **Usage in JIT Functions**

    When used in JIT-compiled functions, the DSL automatically extracts MLIR values:

    .. code-block:: python

        @jit
        def caller():
            x = CustomData(1)
            return foo(x)

    This generates MLIR like (illustrative):

    .. code-block:: mlir

        func @caller() -> i32 {
          %c1 = arith.constant 1 : i32
          %0 = func.call @foo(%c1) : (i32) -> i32
          return %0 : i32
        }
    """

    def __extract_mlir_values__(self):
        """Extract MLIR values from this object.

        :return: List of MLIR values representing this object's data
        :rtype: List[ir.Value]
        """
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """Create a new instance from MLIR values.

        :param values: List of MLIR values to construct the object from
        :type values: List[ir.Value]
        :return: New instance of the implementing class
        :rtype: Any
        """
        raise NotImplementedError
|
|
|
|
|
|
@runtime_checkable
class JitArgument(Protocol):
    """
    Structural interface for objects that can be passed as arguments to JIT functions.

    An implementing class tells the DSL three things: which MLIR types describe it
    in a function signature, how to rebuild it from the corresponding MLIR block
    arguments while tracing, and which ctypes pointers to hand to the compiled
    function when it is actually invoked from Python.

    **Required Methods**

    * ``__c_pointers__``: ctypes pointers used for runtime execution
    * ``__get_mlir_types__``: MLIR types used for the function definition
    * ``__new_from_mlir_values__``: builds a new instance from MLIR values

    **Example**

    .. code-block:: python

        class CustomData:
            def __init__(self, int_value, ...):
                self.int_value = int_value
                ...

            def __c_pointers__(self):
                return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]

            def __get_mlir_types__(self):
                return [ir.IntegerType.get(32), ...]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0], ...)

        @jit
        def foo(x: CustomData):
            a = x.int_value + 1
            ...

        # `CustomData` is an argument of `foo`
        foo(CustomData(1, ...))

    Calling ``y = foo(x)`` proceeds as follows:

    1. The JIT compiler emits the MLIR function definition from ``__get_mlir_types__``:

       .. code-block:: mlir

           func.func @foo(%arg0: i32, ...) {
             ...
             return
           }

    2. Python values cannot be used inside the JIT function, so the object is rebuilt
       from the MLIR block arguments (``%arg0`` ...) with ``__new_from_mlir_values__``
       and the rebuilt object is what the traced Python body of ``foo`` receives:

       .. code-block:: python

           # Implementation of IR tracing
           new_x = CustomData(ir.Value(%arg0), ...)
           y = foo(new_x)
           # `x.int_value` is %arg0 rather than `c1` defined by Python.

    3. At runtime the JIT engine invokes the compiled function with the pointers from
       ``__c_pointers__``, which reference the underlying data of the argument:

       .. code-block:: python

           jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...]))
    """

    def __c_pointers__(self):
        """
        Return the ctypes pointers that represent this object at runtime.

        :return: ctypes pointers, in argument order
        :rtype: List[ctypes.c_void_p]
        """
        raise NotImplementedError

    def __get_mlir_types__(self):
        """
        Return the MLIR types that represent this object in a function signature.

        :return: MLIR types, in argument order
        :rtype: List[ir.Type]
        """
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """
        Build an object of this kind backed by the given MLIR values.

        :param values: MLIR values, in the order given by ``__get_mlir_types__``
        :type values: List[ir.Value]
        :return: A new object representing ``values``
        :rtype: Any
        """
        raise NotImplementedError
|
|
|
|
|
|
def get_c_pointers(obj):
    """
    Recursively collect all ctypes pointers contained in ``obj``.

    Objects implementing ``__c_pointers__`` contribute the pointers it returns;
    tuples and lists are traversed in order; any other object contributes nothing.

    :param obj: Object, or nested tuples/lists of objects, to inspect
    :return: Flat list of ctypes pointers in traversal order
    :rtype: list
    :raises DSLRuntimeError: If a ``set`` is encountered; sets have no defined element
        order, and JIT argument order must be preserved
    """
    if hasattr(obj, "__c_pointers__"):
        return obj.__c_pointers__()
    if isinstance(obj, (tuple, list)):
        # Flatten with a comprehension: ``sum(lists, [])`` copies the accumulator on
        # every step and is quadratic in the number of elements.
        return [ptr for item in obj for ptr in get_c_pointers(item)]
    if isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_c_pointers to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    return []
|
|
|
|
|
|
def get_mlir_types(obj):
    """
    Recursively collect all MLIR types contained in ``obj``.

    Lookup order per object:

    1. ``__get_mlir_types__()`` if implemented
    2. otherwise the ``.type`` of each value from ``__extract_mlir_values__()``
    3. an ``ir.Value`` contributes its own type
    4. tuples and lists are traversed in order
    5. anything else contributes nothing

    :param obj: Object, or nested tuples/lists of objects, to inspect
    :return: Flat list of MLIR types in traversal order
    :rtype: list
    :raises DSLRuntimeError: If a ``set`` is encountered; sets have no defined element
        order, and JIT argument order must be preserved
    """
    if hasattr(obj, "__get_mlir_types__"):
        return obj.__get_mlir_types__()
    if hasattr(obj, "__extract_mlir_values__"):
        return [v.type for v in obj.__extract_mlir_values__()]
    if isinstance(obj, ir.Value):
        return [obj.type]
    if isinstance(obj, (tuple, list)):
        # Flatten with a comprehension: ``sum(lists, [])`` copies the accumulator on
        # every step and is quadratic in the number of elements.
        return [ty for item in obj for ty in get_mlir_types(item)]
    if isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_mlir_types to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    return []
|
|
|
|
|
|
class DslType(type):
    """Metaclass for all DSL types in the system.

    It records whether a class is abstract (via the ``is_abstract`` class keyword)
    and is the common root of the DSL's type metaclasses. Numeric metaclasses such
    as ``NumericMeta`` extend it with MLIR type and NumPy dtype information.

    Classes built with this metaclass are expected to provide the following methods
    so they can take part in JIT compilation:

    * ``__str__``: Return string representation of the type
    * ``__c_pointers__`` (optional): Return list of ctypes pointers of data used to
      invoke the JIT function
    * ``__get_mlir_types__``: Return list of MLIR types of the MLIR values contained
      in the instance
    * ``__extract_mlir_values__``: Return list of MLIR values contained in the instance
    * ``__new_from_mlir_values__``: Return a new instance from list of MLIR values

    :param is_abstract: Class keyword marking the class as abstract, defaults to False
    :type is_abstract: bool, optional

    **Properties**

    :property is_abstract: Whether the class was declared abstract
    """

    _is_abstract: bool

    def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs):
        # Extra class keywords are accepted but not forwarded to ``type.__new__``,
        # so they do not reach ``__init_subclass__``.
        new_cls = super().__new__(cls, name, bases, attrs)

        new_cls._is_abstract = is_abstract

        return new_cls

    @property
    def is_abstract(cls):
        """Whether this class was created with ``is_abstract=True``."""
        return cls._is_abstract
|
|
|
|
|
|
class NumericMeta(DslType):
    """Metaclass for numeric types providing width and NumPy dtype information.

    Accepted class keywords:

    :param width: Bit width of the numeric type, defaults to 8
    :type width: int
    :param np_dtype: Corresponding NumPy dtype
    :type np_dtype: numpy.dtype, optional
    :param mlir_type: Zero-argument callable returning the corresponding MLIR type
    :type mlir_type: Any, optional
    :param is_abstract: Whether the type is abstract, defaults to False
    :type is_abstract: bool, optional

    :ivar width: Bit width of the numeric type
    :vartype width: int
    :ivar _np_dtype: Corresponding NumPy dtype
    :vartype _np_dtype: Union[numpy.dtype, None]

    :property numpy_dtype: Returns the corresponding NumPy dtype
    :property mlir_type: Returns the MLIR type obtained by calling ``mlir_type``
    """

    width: int

    # Placeholder until a concrete ``mlir_type`` is supplied; calling it through the
    # ``mlir_type`` property calls ``typing.Any`` (expected to raise TypeError).
    _mlir_type = Any
    _np_dtype: Union[np.dtype, None]

    def __new__(
        cls,
        name,
        bases,
        attrs,
        width=8,
        np_dtype=None,
        mlir_type=None,
        is_abstract=False,
        **kwargs,
    ):
        def _extract_mlir_values(self):
            return [self.ir_value()]

        def _new_from_mlir_values(self, values: list) -> "Numeric":
            res_ty = type(self)
            return res_ty(values[0])

        # Default DynamicExpression hooks; definitions in the class body win because
        # ``attrs`` is the right-hand operand of the dict merge.
        new_attrs = {
            "__extract_mlir_values__": _extract_mlir_values,
            "__new_from_mlir_values__": _new_from_mlir_values,
        }
        new_cls = super().__new__(
            cls,
            name,
            bases,
            new_attrs | attrs,
            is_abstract=is_abstract,
            **kwargs,
        )

        if mlir_type is not None:
            # Wrapped in staticmethod so it is never bound as a method.
            new_cls._mlir_type = staticmethod(mlir_type)

        new_cls.width = width
        new_cls._np_dtype = np_dtype
        return new_cls

    @property
    def numpy_dtype(cls):
        """The NumPy dtype given at class creation (may be None)."""
        return cls._np_dtype

    @property
    def is_integer(cls) -> bool:
        """Whether this is an integer type; placeholder here (returns None)."""
        ...

    @property
    def is_float(cls) -> bool:
        """Whether this is a floating-point type; placeholder here (returns None)."""
        ...

    def is_same_kind(cls, other: Type) -> bool:
        """True if both types share the ``is_integer`` flag or the ``is_float`` flag."""
        return cls.is_integer == other.is_integer or cls.is_float == other.is_float

    @staticmethod
    def from_python(value: Any) -> Type["Numeric"]:
        """
        Deduce the DSL type from a Python value.

        ``bool`` must be tested before ``int``: ``bool`` is a subclass of ``int``,
        so testing ``int`` first would deduce ``Int32`` for ``True``/``False`` and
        make the ``Boolean`` case unreachable.

        :param value: Python scalar to inspect
        :return: ``Boolean`` for bool, ``Int32`` for int, ``Float32`` for float
        :raises DSLRuntimeError: If ``value`` is none of the above
        """
        if isinstance(value, bool):
            return Boolean
        elif isinstance(value, int):
            return Int32
        elif isinstance(value, float):
            return Float32
        raise DSLRuntimeError(
            f"Could not deduce Type[Numeric] from python value: {value} :{type(value)}"
        )

    @property
    def mlir_type(cls):
        """The MLIR type of this DSL type, built by calling the stored ``mlir_type``."""
        return cls._mlir_type()  # type: ignore
|
|
|
|
|
|
# Type variable used in annotations (e.g. ``cast`` and ``Numeric.__init__``) for a
# non-scalar value; presumably an MLIR value — confirm against callers.
Value = TypeVar("Value")
|
|
|
|
|
|
def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric":
    """Convert ``obj`` to the numeric type ``type_``.

    For a concrete target type, ``obj`` is passed to the type's constructor. An
    abstract target type cannot be constructed, so ``obj`` must already be an
    instance of it and is then returned unchanged.

    :param obj: Object to be cast
    :type obj: Union[bool, int, float, Value]
    :param type_: Target numeric type
    :type type_: Type[Numeric]
    :raises TypeError: If ``type_`` is abstract and ``obj`` is not an instance of it
    :return: ``obj`` converted to (or already of) the target type
    :rtype: Numeric

    Example::

        >>> x = cast(5, Int32)  # Cast integer to Int32
        >>> y = cast(3.14, Float32)  # Cast float to Float32
    """
    if not type_.is_abstract:
        # Implicit cast driven by the annotation type
        return type_(obj)

    if isinstance(obj, type_):
        # Already an instance of the abstract target: nothing to convert
        return obj

    raise TypeError(
        f"can't cast {obj} to {type_}. Pass in concrete type instead, "
        "e.g. Int32, Float32, etc."
    )
|
|
|
|
|
|
# Alternative design considered: basing integer types on ``ir.Value`` via
# ``class IntegerMeta(DslType, type(ir.Value))``.
class IntegerMeta(NumericMeta):
    """Metaclass for DSL integer types; adds signedness and value ranges.

    Accepted class keywords:

    :param width: Bit width of the integer type, defaults to 32
    :type width: int
    :param signed: Whether the integer type is signed, defaults to True
    :type signed: bool
    :param mlir_type: Callable returning the corresponding MLIR type, defaults to None
    :type mlir_type: Any, optional
    :param is_abstract: Whether the type is abstract, defaults to False
    :type is_abstract: bool, optional

    :ivar signed: Whether the integer type is signed
    :vartype signed: bool

    The NumPy dtype is derived from ``width`` and ``signed`` (``np.bool_`` for 1-bit,
    ``None`` for 128-bit). Every class also receives a ``__c_pointers__`` that boxes
    the instance's ``value`` in the matching ctypes scalar.
    """

    signed: bool

    def __new__(
        cls,
        name,
        bases,
        attrs,
        width=32,
        signed=True,
        mlir_type=None,
        is_abstract=False,
    ):
        if width == 1:
            np_dtype = np.bool_
        elif width == 128:
            np_dtype = None
        else:
            np_name = f"int{width}" if signed else f"uint{width}"
            np_dtype = getattr(np, np_name)

        def _c_pointers(self):
            # The ctypes scalar is resolved lazily, when pointers are requested.
            # NOTE(review): for width 128 there appears to be no matching ctypes
            # type, so this would raise AttributeError — confirm.
            if width == 1:
                c_value = ctypes.c_bool(self.value)
            else:
                ctype_name = f"c_int{width}" if signed else f"c_uint{width}"
                c_value = getattr(ctypes, ctype_name)(self.value)
            return [ctypes.cast(ctypes.pointer(c_value), ctypes.c_void_p)]

        # The generated __c_pointers__ overrides any definition in the class body.
        merged_attrs = attrs | {"__c_pointers__": _c_pointers}
        new_cls = super().__new__(
            cls, name, bases, merged_attrs, width, np_dtype, mlir_type, is_abstract
        )
        new_cls.signed = signed
        return new_cls

    def __str__(cls):
        return cls.__name__

    @property
    def is_integer(cls) -> bool:
        return True

    @property
    def is_float(cls) -> bool:
        return False

    @property
    def zero(cls) -> int:
        return 0

    @property
    def min(cls) -> int:
        """Smallest representable value of this integer type."""
        return -(2 ** (cls.width - 1)) if cls.signed else 0

    @property
    def max(cls) -> int:
        """Largest representable value of this integer type."""
        return 2 ** (cls.width - 1) - 1 if cls.signed else 2**cls.width - 1

    def recast_width(cls, width):
        """Return the signed integer type of the given bit width.

        :raises TypeError: If no integer type of that width exists
        """
        result = {8: Int8, 16: Int16, 32: Int32, 64: Int64, 128: Int128}.get(width)
        if result is None:
            raise TypeError(f"Unsupported width: {width}")
        return result
|
|
|
|
|
|
class FloatMeta(NumericMeta):
    """Metaclass for DSL floating-point types.

    Accepted class keywords:

    :param width: Bit width of the float type, defaults to 32
    :type width: int
    :param mlir_type: Callable returning the corresponding MLIR type, defaults to None
    :type mlir_type: Any, optional
    :param is_abstract: Whether this is an abstract base class, defaults to False
    :type is_abstract: bool, optional

    The NumPy dtype is looked up by the lower-cased class name (e.g. ``Float32`` ->
    ``np.float32``) and is None when NumPy has no such attribute. Exponent and
    mantissa widths are recorded only for concrete classes whose name matches
    ``Float<W>E<exp>M<mant>...``.
    """

    _exponent_width: int
    _mantissa_width: int

    def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False):
        import re

        # Narrow-precision types (e.g. bfloat16, tfloat32) have no 1-to-1 NumPy dtype.
        np_dtype = getattr(np, name.lower(), None)
        new_cls = super().__new__(
            cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract
        )
        if is_abstract:
            return new_cls

        # Class names like Float8E4M3 encode the layout: exponent=4, mantissa=3.
        layout = re.match(r"Float(\d+)E(\d+)M(\d+)(?:.*)", name)
        if layout is not None:
            new_cls._exponent_width = int(layout.group(2))
            new_cls._mantissa_width = int(layout.group(3))
        return new_cls

    def __str__(cls):
        return cls.__name__

    @property
    def is_integer(cls) -> bool:
        return False

    @property
    def is_float(cls) -> bool:
        return True

    @property
    def zero(cls) -> float:
        return 0.0

    @property
    def inf(cls) -> float:
        return float("inf")

    @property
    def nan(cls) -> float:
        return float("nan")

    @property
    def exponent_width(cls) -> int:
        """Exponent bit count (set only for Float<W>E<e>M<m>-style names)."""
        return cls._exponent_width

    @property
    def mantissa_width(cls) -> int:
        """Mantissa bit count (set only for Float<W>E<e>M<m>-style names)."""
        return cls._mantissa_width

    def recast_width(cls, width):
        """Return the standard float type of the given bit width.

        :raises TypeError: If the width is not 16, 32 or 64
        """
        result = {16: Float16, 32: Float32, 64: Float64}.get(width)
        if result is None:
            raise TypeError(f"Unsupported width: {width}")
        return result
|
|
|
|
|
|
def _arith_signless_to_int(a, target_type):
|
|
# is_signed: sign of result type
|
|
if target_type.width > a.type.width:
|
|
# arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL
|
|
if target_type.signed and a.type.width > 1:
|
|
return arith.extsi(target_type.mlir_type, a)
|
|
else:
|
|
return arith.extui(target_type.mlir_type, a)
|
|
elif target_type.width < a.type.width:
|
|
return arith.trunci(target_type.mlir_type, a)
|
|
else:
|
|
return a
|
|
|
|
|
|
def _binary_op_type_promote(a, b, promote_bool: bool = False):
    """Promote two numeric operands following type promotion rules.

    :param a: First numeric operand
    :type a: Numeric
    :param b: Second numeric operand
    :type b: Numeric
    :param promote_bool: Whether to promote boolean types to Int32 for arithmetic operations, defaults to False
    :type promote_bool: bool, optional
    :raises ValueError: If implicit float promotion is not supported between the given types
    :raises TypeError: From ``recast_width`` when an integer operand is mixed with a
        float and the larger of the two widths has no standard float type
        (only 16, 32 and 64 are mapped)
    :return: Tuple containing promoted operands and their resulting type
    :rtype: tuple[Numeric, Numeric, Type[Numeric]]

    Type promotion rules:

    1. If operands are same type and not bools needing promotion:
       - No promotion needed, return original types
    2. If either operand is float:
       a. If one is float and one is int:
          - Convert int to the float type of width max(float width, int width)
       b. Then, using the ORIGINAL operand widths:
          - Promote to the wider type if that width is >= 16
          - For same width, promote to the more general type
            (Float64 > Float32 > Float16, e.g. Float32 over TFloat32)
          - Otherwise raise ValueError for unsupported promotion
    3. Otherwise, both operands are integers. Integer promotion rules:
       a. If promote_bool is True and both operands are bool:
          - Promote both to Int32 for arithmetic operations
       b. Mixed signedness: the unsigned type wins if its width is >= the signed
          type's width, otherwise the signed type wins
       c. Same signedness: the wider type wins

    Exceptions for numpy dtype casting:

    - array(dtype=np.bool_) + array(dtype=np.bool_) -> array(dtype=np.bool_)

    What is not supported:

    - promotion with narrow precision float types which requires explicit cast by user
    """
    a_type = a.dtype
    b_type = b.dtype

    # Early return for same types (except when they're bools that need promotion)
    if a_type == b_type and not (promote_bool and a_type is Boolean):
        return a, b, a_type

    # Handle floating point promotions
    if a_type.is_float or b_type.is_float:
        # Get highest precision float type based on bitwidth
        a_width = getattr(a_type, "width", 0)
        b_width = getattr(b_type, "width", 0)

        # If one type is integer, convert it to the float type
        if a_type.is_float and not b_type.is_float:
            b_type = a_type.recast_width(max(a_width, b_width))
        elif b_type.is_float and not a_type.is_float:
            a_type = b_type.recast_width(max(a_width, b_width))

        # Both are float types - handle precision promotion
        # NOTE: a_width/b_width still hold the ORIGINAL operand widths; they are
        # intentionally not refreshed after the recast above.
        if a_width > b_width and a_width >= 16:
            res_type = a_type
        elif b_width > a_width and b_width >= 16:
            res_type = b_type
        elif a_width == b_width:
            # Same bitwidth - handle special cases like TFloat32 -> Float32 and BFloat16 -> Float16
            if a_type is Float64 or b_type is Float64:
                res_type = Float64
            elif a_type is Float32 or b_type is Float32:
                res_type = Float32
            elif a_type is Float16 or b_type is Float16:
                res_type = Float16
            else:
                raise ValueError(
                    f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly"
                )
        else:
            # Different widths where the wider one is narrower than 16 bits
            raise ValueError(
                f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly"
            )

        # Only convert if type is different
        new_a = a.to(res_type) if a.dtype != res_type else a
        new_b = b.to(res_type) if b.dtype != res_type else b
        return new_a, new_b, res_type

    # Handle bool promotion for arithmetic operations
    if promote_bool:
        if a_type is Boolean and b_type is Boolean:
            # Only promote to Int32 when both are bool
            a = a.to(Int32)
            b = b.to(Int32)
            a_type = b_type = a.dtype

        # If both were bools, they're now same type (Int32)
        if a_type == b_type:
            return a, b, a_type

    # Same type, no promotion needed
    if a_type == b_type:
        return a, b, a_type

    a_signed = a_type.signed
    b_signed = b_type.signed
    a_width = a_type.width
    b_width = b_type.width

    # Mixed signedness case
    if a_signed != b_signed:
        unsigned_type = a_type if not a_signed else b_type
        signed_type = a_type if a_signed else b_type
        unsigned_width = a_width if not a_signed else b_width

        if unsigned_width >= signed_type.width:
            # Promote both to unsigned of larger width
            res_type = unsigned_type
        else:
            # Promote both to signed of larger width
            res_type = signed_type

        new_a = a.to(res_type) if a.dtype != res_type else a
        new_b = b.to(res_type) if b.dtype != res_type else b
        return new_a, new_b, res_type

    # Same signedness, different width - promote to larger width
    if a_width >= b_width:
        return a, b.to(a.dtype), a.dtype
    else:
        return a.to(b.dtype), b, b.dtype
|
|
|
|
|
|
def _binary_op(op, promote_operand=True, promote_bool=False, flip=False):
    """Wrap a binary operator so it can be applied to Numeric operands.

    The returned ``wrapper(lhs, rhs, *, loc=None, ip=None)`` does the following:

    1. Canonicalizes ``rhs`` to a ``Numeric``.
    2. Optionally promotes both operands to a common type.
    3. Applies ``op`` to the underlying values.
    4. Wraps the result in the result type.

    :param op: The binary operation to perform (e.g., operator.add, operator.sub)
    :type op: callable
    :param promote_operand: Whether to promote both operands to a common type via
        ``_binary_op_type_promote``; when False, ``rhs`` is converted to ``type(lhs)``.
        Defaults to True
    :type promote_operand: bool, optional
    :param promote_bool: Forwarded to ``_binary_op_type_promote``, which computes two
        Boolean operands as Int32. When both original operands are ``Boolean``, the
        result is wrapped back as ``Boolean``. Defaults to False
    :type promote_bool: bool, optional
    :param flip: Whether to swap the operands when calling ``op``, defaults to False
    :type flip: bool, optional
    :return: The wrapper function. The wrapper returns ``NotImplemented`` in two cases,
        so that Python can try the reflected operator on ``rhs``: when ``rhs`` is not a
        Numeric, ArithValue, int, float or bool, and when ``rhs`` is a vector-typed
        ArithValue.

    Result type selection:

    - comparisons (``<``, ``<=``, ``>``, ``>=``, ``==``, ``!=``) produce ``Boolean``
    - ``truediv`` whose (promoted) ``lhs`` is an ``Integer`` produces ``Float32``
    - otherwise the promoted operand type (or ``type(lhs)`` without promotion)

    :raises ValueError: From ``_binary_op_type_promote`` when implicit float promotion
        is not supported
    :raises TypeError: Propagated from type promotion or from ``op`` when the operands
        are not supported. NOTE(review): earlier notes said that subtraction,
        multiplication, floor division and modulo have limited support for some types;
        that support is not visible here — confirm.
    """

    def wrapper(lhs, rhs, *, loc=None, ip=None):
        orig_lhs_type = type(lhs)
        orig_rhs_type = type(rhs)

        # Canonicalize rhs to a Numeric so it can take part in promotion
        if not isinstance(rhs, Numeric):
            # ``bool`` is covered by ``int`` (it is a subclass)
            if not isinstance(rhs, (ArithValue, int, float)):
                # This allows the rhs class to implement e.g. __rmul__
                return NotImplemented
            if isinstance(rhs, ArithValue) and isinstance(rhs.type, ir.VectorType):
                return NotImplemented
            rhs = as_numeric(rhs)

        # Default result type is the left-hand side's type
        res_type = orig_lhs_type

        if promote_operand:
            lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool)
        else:
            rhs = orig_lhs_type(rhs)

        if op in (
            operator.lt,
            operator.le,
            operator.gt,
            operator.ge,
            operator.eq,
            operator.ne,
        ):
            res_type = Boolean
        elif op == operator.truediv and isinstance(lhs, Integer):
            res_type = Float32
        elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean:
            # Booleans were computed as Int32; restore the Boolean result type
            res_type = Boolean

        # Attach signedness to integer IR values so ``op`` emits the right variant
        if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer):
            lhs_val = lhs.value.with_signedness(lhs.signed)
        else:
            lhs_val = lhs.value

        if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer):
            rhs_val = rhs.value.with_signedness(rhs.signed)
        else:
            rhs_val = rhs.value

        if flip:
            lhs_val, rhs_val = rhs_val, lhs_val

        # ``op`` itself raises if the operation is unsupported for these operands
        res_val = op(lhs_val, rhs_val)
        return res_type(res_val, loc=loc, ip=ip)

    return wrapper
|
|
|
|
|
|
class Numeric(metaclass=NumericMeta, is_abstract=True):
|
|
"""Base class for all numeric types in the DSL.
|
|
|
|
This class provides the foundation for both Integer and Float types,
|
|
implementing basic arithmetic operations.
|
|
|
|
:param value: The value to store in the numeric type
|
|
:type value: Union[bool, int, float, Value]
|
|
|
|
:ivar value: The stored numeric value
|
|
:vartype value: Union[bool, int, float, Value]
|
|
"""
|
|
|
|
def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None):
    """Store ``value``, either a static Python scalar or a dynamic value.

    ``loc`` and ``ip`` are accepted but unused by this base implementation.
    """
    self.value = value
|
|
|
|
def __str__(self) -> str:
    """Return the wrapped value's ``pretty_str()`` if it has one, otherwise ``"?"``."""
    # Delegating to the member leaves room for richer display later (IDE, notebooks).
    formatter = getattr(self.value, "pretty_str", None)
    return "?" if formatter is None else formatter()
|
|
|
|
def __repr__(self) -> str:
    """Return ``ClassName(<repr of value>)``."""
    type_name = self.__class__.__name__
    return f"{type_name}({self.value!r})"
|
|
|
|
def __hash__(self):
    """Hash combining the type component and the wrapped value."""
    # The type component is the *metaclass* of this value's class (e.g. IntegerMeta),
    # not the class itself. NOTE(review): presumably intentional, so that equal values
    # of different integer widths (which compare equal after promotion) hash alike —
    # confirm.
    metaclass = type(self).__class__
    return hash(metaclass) ^ hash(self.value)
|
|
|
|
@property
def dtype(self) -> Type["Numeric"]:
    """The DSL numeric type of this value, i.e. its concrete class."""
    cls = type(self)
    return cls
|
|
|
|
@overload
def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ...

@overload
def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ...

@overload
def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ...

@overload
def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ...

@overload
def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ...

def to(self, dtype: Type, *, loc=None, ip=None):
    """Convert this numeric value to another type.

    Conversion depends on ``dtype``:

    - this value's own type: ``self`` is returned unchanged
    - another DSL numeric type: a new instance of that type is built from ``self``
    - ``ir.Value``: returns the underlying IR value, materializing a constant for
      static values
    - ``int``/``float``/``bool``: returns a static Python value

    :param dtype: The target type
    :type dtype: Union[Type["Numeric"], Type[ir.Value], Type[int], Type[float], Type[bool]]
    :return: The converted value, or ``self`` if the types match
    :rtype: Union[Numeric, ir.Value, int, float, bool]
    :raises TypeError: If ``dtype`` is an unsupported destination float type
    :raises ValueError: In any of these cases:

        - an IR (dynamic) value is converted to a static Python type
        - the value cannot be produced as an ``ArithValue``
        - ``dtype`` is not a supported target

    .. note::

        Unsupported destination float types:

        - Float8E4M3
        - Float8E4M3B11FNUZ
        - Float4E2M1FN
        - Float6E3M2FN
        - Float6E2M3FN

    Example::

    .. code-block:: python

        # Convert between DSL numeric types
        x = Int32(5)
        y = x.to(Float32)  # Converts to Float32(5.0)

        # Convert to Python primitive types
        # They are considered as static values at JIT time
        z = x.to(int)  # Returns Python int 5
        w = y.to(float)  # Returns Python float 5.0

        # This will raise a ValueError
        mlir_val = arith.constant(T.i32(), 42)
        num = Int32(mlir_val)
        num.to(int)  # ValueError: unable to convert MLIR value to static type: <class 'int'>
    """
    if dtype in _unsupported_dst_float_types:
        raise TypeError(f"Unsupported destination float type: {dtype}")

    # ``dtype`` is a class, so compare identity; ``isinstance(dtype, type(self))``
    # would never be true and the same-type shortcut would be unreachable.
    if dtype is type(self):
        return self
    elif isinstance(dtype, NumericMeta):
        return dtype(self)
    elif dtype is ir.Value:
        if isinstance(self.value, (int, float, bool)):
            # Static Python value: materialize it as an IR constant
            res = arith_helper.const(
                self.value, self.dtype.mlir_type, loc=loc, ip=ip
            )
        elif isinstance(self.value, ir.Value):
            res = self.value
        else:
            # Report the value's Python type: it is neither a scalar nor an IR value,
            # so it is not expected to have a ``.type`` attribute.
            raise ValueError(
                f"cannot convert {type(self)} to {dtype}, "
                f"self.value is {type(self.value)}"
            )

        if not isinstance(res, ArithValue):
            raise ValueError(f"Expected ArithValue, got {type(res)} as {res.type}")

        # Integer types carry signedness; other types pass None
        return res.with_signedness(getattr(type(self), "signed", None))
    elif dtype in (int, float, bool):
        if isinstance(self.value, ir.Value):
            raise ValueError(
                f"unable to convert {self.value} to static type: {dtype}"
            )
        return dtype(self.value)
    else:
        raise ValueError(f"unable to convert {type(self)} to {dtype}")
|
|
|
|
def ir_value(self, *, loc=None, ip=None) -> ir.Value:
    """Return this value as an IR value; shorthand for ``self.to(ir.Value, ...)``."""
    target = ir.Value
    return self.to(target, loc=loc, ip=ip)
|
|
|
|
@property
def zero(self) -> "Numeric":
    """Placeholder declaration; this base body is ``...`` and so returns None.

    Concrete zero values used by the DSL operators come from the metaclass ``zero``
    properties (e.g. ``IntegerMeta.zero`` and ``FloatMeta.zero``) accessed as
    ``type(self).zero``.
    """
    ...
|
|
|
|
def __dsl_not__(self, *, loc=None, ip=None):
    """DSL implementation of Python's `not` operator.

    Returns True if the value is equal to zero, False otherwise, matching Python's
    rule that any non-zero number is truthy. A static Python value yields a Python
    ``bool``; a dynamic value yields the ``Boolean`` result of comparing with zero.

    :param loc: The source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: The insertion point for the operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: The result of the logical not operation
    :rtype: Union[bool, Boolean]
    """
    if isinstance(self.value, (int, float, bool)):
        # Static value: fold at trace time
        return not self.value

    ty = type(self)
    # Emit the zero constant at the same location/insertion point as the comparison
    # that uses it, so it cannot end up in a different place than its user.
    zero_val = arith.constant(ty.mlir_type, ty.zero, loc=loc, ip=ip)
    return self.__eq__(ty(zero_val), loc=loc, ip=ip)
|
|
|
|
def __dsl_and__(self, other, *, loc=None, ip=None):
    """DSL implementation of Python's `and` operator.

    Returns the second operand if the first is truthy, otherwise the first operand.
    A numeric value is considered truthy if it is non-zero. With two static operands
    the Python ``and`` is evaluated directly; otherwise an ``arith.select`` is emitted.

    :param other: The right-hand operand
    :type other: Numeric
    :param loc: The source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: The insertion point for the operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: ``other`` if ``self`` is truthy, else ``self``, after operand promotion
        (``Boolean`` when both operands are ``Boolean``)
    :rtype: Numeric

    Example::

        5 and 3 -> 3
        0 and 3 -> 0
        3 and 0 and ... -> 0
    """
    is_true = self.__dsl_bool__(loc=loc, ip=ip)

    def and_op(lhs, rhs):
        lhs_static = isinstance(lhs, (int, float, bool))
        rhs_static = isinstance(rhs, (int, float, bool))
        if lhs_static and rhs_static:
            # Both static: plain Python semantics
            return lhs and rhs
        # At least one side is dynamic: materialize the static side (if any) with the
        # dynamic side's type, at the same location/insertion point as the select.
        if lhs_static:
            lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
        elif rhs_static:
            rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
        return arith.select(is_true.ir_value(loc=loc, ip=ip), rhs, lhs, loc=loc, ip=ip)

    return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip)
|
|
|
|
def __dsl_or__(self, other, *, loc=None, ip=None):
    """DSL implementation of Python's `or` operator.

    Returns the first operand if it is truthy, otherwise returns the second
    operand. A numeric value is considered truthy if it is non-zero.

    :param other: The right-hand operand
    :type other: Numeric
    :param loc: The source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: The insertion point for the operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: ``self`` when it is truthy, otherwise ``other`` -- the selected
        operand, not a Boolean
    :rtype: Numeric

    Example::

        5 or 3 -> 5
        0 or 3 -> 3
        3 or 0 -> 3
    """
    is_true = self.__dsl_bool__(loc=loc, ip=ip)

    def or_op(lhs, rhs):
        # Both operands known at compile time: Python's own `or` decides.
        if isinstance(lhs, (int, float, bool)) and isinstance(rhs, (int, float, bool)):
            return lhs or rhs
        # Materialize a remaining Python scalar as an MLIR constant of the
        # other operand's type so both `select` arms have the same type.
        if isinstance(lhs, (int, float, bool)):
            lhs = arith.constant(rhs.type, lhs)
        elif isinstance(rhs, (int, float, bool)):
            rhs = arith.constant(lhs.type, rhs)
        # truthy self -> lhs, falsy self -> rhs
        return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)

    return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip)
|
|
|
|
def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean":
    """Truthiness of this value inside DSL code (Python's ``bool()``).

    A numeric value is truthy exactly when it differs from its type's
    ``zero``; the comparison is delegated to ``__ne__``.

    :param loc: The source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: The insertion point for the operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: True if this value is non-zero, False otherwise
    :rtype: Boolean
    """
    return self.__ne__(type(self).zero, loc=loc, ip=ip)
|
|
|
|
def __bool__(self):
    """Python truthiness; only defined for compile-time constant values.

    :raises DSLRuntimeError: if the value is a dynamic (MLIR) value, whose
        truthiness is unknown until run time
    """
    value = self.value
    if not isinstance(value, (int, float, bool)):
        raise DSLRuntimeError(
            f"Unable to convert dynamic `{type(self).__name__}` value to bool at compile time.",
            suggestion=[
                "Decorate the parent function with `jit` decorator and with `preprocess` enabled.",
                "Ensure not using patterns that DSL does not support.",
                "Otherwise, please file a bug report.",
            ],
        )
    return bool(value)
|
|
|
|
def __index__(self):
    """Return this value as a plain Python ``int`` (used by ``range``, slicing, ...).

    Only compile-time constant integer values (including booleans) can be
    used as an index.

    :return: The constant value converted to ``int``
    :rtype: int
    :raises TypeError: if the constant value is a float
    :raises DSLRuntimeError: if the value is dynamic (an MLIR value)
    """
    value = self.value
    if isinstance(value, float):
        # Python requires __index__ to return an int. Returning the float would
        # only surface as an opaque "__index__ returned non-int" TypeError.
        raise TypeError(
            f"'{type(self).__name__}' object cannot be interpreted as an integer"
        )
    if isinstance(value, int):
        # int() also strips bool down to a plain int: returning a bool (an int
        # subclass) from __index__ is deprecated by CPython.
        return int(value)
    raise DSLRuntimeError(
        f"'{type(self.value)}' object cannot be interpreted as an integer",
        suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator",
    )
|
|
|
|
def __neg__(self, *, loc=None, ip=None):
    """Return ``-self`` wrapped in this value's own numeric type.

    The previous version tested ``isinstance(self, (bool, int, float))``, which
    is never true for a Numeric wrapper, so its "constant" branch was dead code.
    Every value already went through the single path below, and both constant
    and dynamic values accept ``loc``/``ip`` in the constructor.

    :param loc: The source location information, defaults to None
    :param ip: The insertion point for the operation, defaults to None
    :return: A new instance of ``type(self)`` holding the negated value
    """
    return type(self)(-self.value, loc=loc, ip=ip)  # type: ignore
|
|
|
|
@staticmethod
def _from_python_value(value):
    """Wrap a Python scalar or ``ArithValue`` in the matching Numeric type.

    ``Numeric`` inputs are returned unchanged. Python ``bool`` maps to
    ``Boolean``, ``int`` to ``Int32`` and ``float`` to ``Float32``. An
    ``ArithValue`` is wrapped in the class that ``from_mlir_type`` picks for
    its MLIR type.

    :raises ValueError: for any other kind of input
    """
    if isinstance(value, Numeric):
        return value
    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return Boolean(value)
    if isinstance(value, int):
        return Int32(value)
    if isinstance(value, float):
        return Float32(value)
    if isinstance(value, ArithValue):
        return Numeric.from_mlir_type(value.type)(value)
    raise ValueError(
        f"unable to convert {value} in type {type(value)} to Numeric"
    )
|
|
|
|
def __add__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self + other`` via ``_binary_op`` with ``promote_bool=True``."""
    add = _binary_op(operator.add, promote_bool=True)
    return add(self, other, loc=loc, ip=ip)


def __sub__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self - other`` via ``_binary_op`` with ``promote_bool=True``."""
    sub = _binary_op(operator.sub, promote_bool=True)
    return sub(self, other, loc=loc, ip=ip)


def __mul__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self * other`` via ``_binary_op`` with ``promote_bool=True``."""
    mul = _binary_op(operator.mul, promote_bool=True)
    return mul(self, other, loc=loc, ip=ip)


def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self // other`` via ``_binary_op`` with ``promote_bool=True``."""
    floordiv = _binary_op(operator.floordiv, promote_bool=True)
    return floordiv(self, other, loc=loc, ip=ip)


def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self / other`` via ``_binary_op`` with ``promote_bool=True``."""
    truediv = _binary_op(operator.truediv, promote_bool=True)
    return truediv(self, other, loc=loc, ip=ip)


def __mod__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self % other`` via ``_binary_op`` with ``promote_bool=True``."""
    mod = _binary_op(operator.mod, promote_bool=True)
    return mod(self, other, loc=loc, ip=ip)


def __radd__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``other + self``; addition is commutative, so reuse ``__add__``."""
    return self.__add__(other, loc=loc, ip=ip)


def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``other - self`` (operands flipped inside ``_binary_op``)."""
    rsub = _binary_op(operator.sub, promote_bool=True, flip=True)
    return rsub(self, other, loc=loc, ip=ip)


def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``other * self``; multiplication is commutative, so reuse ``__mul__``."""
    return self.__mul__(other, loc=loc, ip=ip)


def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``other // self`` (operands flipped inside ``_binary_op``)."""
    rfloordiv = _binary_op(operator.floordiv, promote_bool=True, flip=True)
    return rfloordiv(self, other, loc=loc, ip=ip)


def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``other / self`` (operands flipped inside ``_binary_op``)."""
    rtruediv = _binary_op(operator.truediv, promote_bool=True, flip=True)
    return rtruediv(self, other, loc=loc, ip=ip)


def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``other % self`` (operands flipped inside ``_binary_op``)."""
    rmod = _binary_op(operator.mod, promote_bool=True, flip=True)
    return rmod(self, other, loc=loc, ip=ip)
|
|
|
|
def __eq__(self, other, *, loc=None, ip=None) -> "Boolean":
    """Element comparison ``self == other``; yields a Boolean."""
    eq = _binary_op(operator.eq)
    return eq(self, other, loc=loc, ip=ip)  # type: ignore


def __ne__(self, other, *, loc=None, ip=None) -> "Boolean":
    """Element comparison ``self != other``; yields a Boolean."""
    ne = _binary_op(operator.ne)
    return ne(self, other, loc=loc, ip=ip)  # type: ignore


def __lt__(self, other, *, loc=None, ip=None) -> "Boolean":
    """Element comparison ``self < other``; yields a Boolean."""
    lt = _binary_op(operator.lt)
    return lt(self, other, loc=loc, ip=ip)  # type: ignore


def __le__(self, other, *, loc=None, ip=None) -> "Boolean":
    """Element comparison ``self <= other``; yields a Boolean."""
    le = _binary_op(operator.le)
    return le(self, other, loc=loc, ip=ip)  # type: ignore


def __gt__(self, other, *, loc=None, ip=None) -> "Boolean":
    """Element comparison ``self > other``; yields a Boolean."""
    gt = _binary_op(operator.gt)
    return gt(self, other, loc=loc, ip=ip)  # type: ignore


def __ge__(self, other, *, loc=None, ip=None) -> "Boolean":
    """Element comparison ``self >= other``; yields a Boolean."""
    ge = _binary_op(operator.ge)
    return ge(self, other, loc=loc, ip=ip)  # type: ignore


def __pow__(self, other, *, loc=None, ip=None) -> "Numeric":
    """``self ** other``; unlike the arithmetic ops, no ``promote_bool``."""
    power = _binary_op(operator.pow)
    return power(self, other, loc=loc, ip=ip)  # type: ignore
|
|
|
|
def __c_pointers__(self):
    """Return ctypes pointers used to pass this value to host code.

    This base implementation supports no type and always raises. Concrete
    subclasses that can be marshalled (for example the float types) override it.

    :raises ValueError: always
    """
    # Braces are doubled on purpose. Unescaped, ``{8, 16, 32, 64}`` would be
    # evaluated as a set display, and the message would depend on set
    # iteration order.
    raise ValueError(
        f"only support built-in types: bool, (u)int{{8, 16, 32, 64}}, "
        f"float{{32, 64}}, but got {type(self)}"
    )
|
|
|
|
def __get_mlir_types__(self):
    """Return a one-element list holding this value's class-level MLIR type."""
    mlir_ty = type(self).mlir_type
    return [mlir_ty]
|
|
|
|
@staticmethod
def from_mlir_type(mlir_type):
    """Map an MLIR type to the DSL ``Numeric`` class that represents it.

    Signless (``i*``) and signed (``si*``) integer types both map to the
    signed ``Int*`` classes. Unsigned (``ui*``) types map to ``Uint*``.

    :param mlir_type: The MLIR type to look up
    :return: The matching ``Numeric`` subclass
    :raises DSLRuntimeError: If ``mlir_type`` has no DSL counterpart
    """
    # NOTE(review): the table is rebuilt on every call, as before. Creating
    # T.*() types presumably needs an active MLIR context, which may be why
    # this is not a module-level constant. Confirm before hoisting it.
    known_types = {
        T.bool(): Boolean,
        T.f64(): Float64,
        T.f32(): Float32,
        T.tf32(): TFloat32,
        T.f16(): Float16,
        T.bf16(): BFloat16,
        T.i(128): Int128,
        T.i64(): Int64,
        T.i32(): Int32,
        T.i16(): Int16,
        T.i8(): Int8,
        T.si(128): Int128,
        T.si64(): Int64,
        T.si32(): Int32,
        T.si16(): Int16,
        T.si8(): Int8,
        T.ui(128): Uint128,
        T.ui64(): Uint64,
        T.ui32(): Uint32,
        T.ui16(): Uint16,
        T.ui8(): Uint8,
        T.f8E5M2(): Float8E5M2,
        T.f8E4M3(): Float8E4M3,
        T.f8E4M3FN(): Float8E4M3FN,
        T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ,
        T.f4E2M1FN(): Float4E2M1FN,
        T.f6E2M3FN(): Float6E2M3FN,
        T.f6E3M2FN(): Float6E3M2FN,
        T.f8E8M0FNU(): Float8E8M0FNU,
    }
    dsl_type = known_types.get(mlir_type)
    if dsl_type is None:
        raise DSLRuntimeError(f"Unsupported DSL type: {mlir_type}")
    return dsl_type
|
|
|
|
|
|
def as_numeric(obj: Union[bool, int, float, ir.Value, Numeric]) -> Numeric:
    """Wrap ``obj`` as a DSL ``Numeric``; ``Numeric`` inputs pass through unchanged.

    Python ``bool``/``int``/``float`` become ``Boolean``/``Int32``/``Float32``.
    An ``ArithValue`` is wrapped according to its MLIR type.

    :param obj: The value to convert
    :type obj: Union[bool, int, float, ir.Value, Numeric]
    :return: The converted Numeric object
    :rtype: Numeric
    :raises ValueError: If ``obj`` cannot be represented as a Numeric

    Example::

        .. code-block:: python

            x = as_numeric(5)  # Converts to Int32
            y = as_numeric(3.14)  # Converts to Float32
            z = as_numeric(True)  # Converts to Boolean
    """
    return obj if isinstance(obj, Numeric) else Numeric._from_python_value(obj)
|
|
|
|
|
|
class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True):
    """A class representing integer values with specific width and signedness.

    This class provides functionality to create and manipulate integer values with
    configurable width and signedness. It supports conversion from various input types
    including Python scalars, MLIR Values, and other numeric types.

    :param x: The input value to convert to this integer type
    :type x: Union[bool, int, float, ir.Value, Integer, Float]
    :param loc: The source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: The insertion point for the operation, defaults to None
    :type ip: Optional[InsertionPoint]

    :return: A new Integer instance with the converted value
    :rtype: Integer

    :raises AssertionError: If the type's numpy_dtype is None
    :raises ValueError: If converting a Python float NaN to integer
    :raises OverflowError: If converting a Python float infinity to integer
    :raises DSLRuntimeError: If the input type is not supported for conversion

    Type conversion behavior:

    * Python scalars (bool, int, float):
        * Converted through numpy dtype casting: out-of-range values wrap
          around and floats are truncated toward zero
        * NaN and infinity values are rejected
        * Example: Int8(256) -> 0, Int8(255) -> -1 (overflow behavior)

    * MLIR Value with IntegerType:
        * Used as-is when its width equals the target width
        * Otherwise converted by the signless to signed/unsigned helper
        * Example: i8 -> i8/ui8 depending on target type

    * MLIR Value with FloatType:
        * Uses MLIR float-to-int conversion with the target's signedness
        * NaN and infinity values is undefined behavior
        * Example: f32 -> i32/ui32 depending on target type

    * MLIR Value of any other type is stored unchanged

    * Integer:
        * Dynamic values use MLIR integer-to-integer conversion; constants
          use numpy dtype casting from the source dtype
        * Example: Int32(Int32(5)) => 5

    * Float:
        * The Float's underlying value is converted with the rules above
        * Example: Int32(Float32(5.7)) -> 5

    Example usage:

    .. code-block:: python

        x = Int32(5)  # From integer
        y = Int32(True)  # From boolean
        z = Int32(3.7)  # From float (truncates)
        w = Int32(x)  # From same Integer type
        c5 = arith.constant(5, T.i32())
        a = Int32(c5)  # Treat c5 as int32 bitwise
    """

    def __init__(self, x, *, loc=None, ip=None):
        ty = type(self)

        if isinstance(x, (bool, int, float)):
            # Reject NaN/inf before numpy casting, which would otherwise
            # silently produce an arbitrary integer.
            if isinstance(x, float):
                if np.isnan(x):
                    raise ValueError("Cannot convert float NaN to integer")
                elif np.isinf(x):
                    raise OverflowError("Cannot convert float infinity to integer")

            np_dtype = ty.numpy_dtype
            assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}"
            x_val = int(np.array(x).astype(np_dtype))
        elif type(x) == ty:
            x_val = x.value
        elif isinstance(x, ir.Value):  # type: ignore
            x_val = x
            if isinstance(x.type, ir.IntegerType):  # type: ignore
                if x.type.width != ty.width:
                    # signless -> (u)int
                    x_val = _arith_signless_to_int(x, ty)
            elif isinstance(x.type, ir.FloatType):  # type: ignore
                # float -> (u)int
                x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip)
        elif isinstance(x, Integer):
            if isinstance(x.value, ir.Value):
                x_val = arith_helper.int_to_int(x.ir_value(loc=loc, ip=ip), ty)
            else:
                # For non-MLIR values, use numpy casting
                src_val = np.array(x.value, dtype=type(x).numpy_dtype)
                x_val = int(src_val.astype(ty.numpy_dtype))
        elif isinstance(x, Float):
            # float -> int is handled by Integer.__init__ recursively. loc/ip
            # are forwarded so a dynamic conversion keeps its source location.
            Integer.__init__(self, x.value, loc=loc, ip=ip)
            return
        else:
            raise DSLRuntimeError(f"{x} to integer conversion is not supported")

        super().__init__(x_val)

    def __invert__(self, *, loc=None, ip=None):
        """Bitwise NOT, returned in the same integer type."""
        res_type = type(self)
        return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip))

    def __lshift__(self, other, *, loc=None, ip=None):
        """``self << other``."""
        return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip)

    def __rlshift__(self, other, *, loc=None, ip=None):
        """``other << self``; ``other`` must convert to an Integer."""
        other_ = as_numeric(other)
        if not isinstance(other_, Integer):
            raise ValueError(f"Cannot left shift {other_} with {self}")
        return other_.__lshift__(self, loc=loc, ip=ip)

    def __rshift__(self, other, *, loc=None, ip=None):
        """``self >> other``."""
        return _binary_op(operator.rshift)(self, other, loc=loc, ip=ip)

    def __rrshift__(self, other, *, loc=None, ip=None):
        """``other >> self``; ``other`` must convert to an Integer."""
        other_ = as_numeric(other)
        if not isinstance(other_, Integer):
            raise ValueError(f"Cannot right shift {other_} with {self}")
        return other_.__rshift__(self, loc=loc, ip=ip)

    def __and__(self, other, *, loc=None, ip=None):
        """Bitwise ``self & other``."""
        return _binary_op(operator.and_)(self, other, loc=loc, ip=ip)

    def __rand__(self, other, *, loc=None, ip=None):
        """Bitwise ``other & self`` (commutative)."""
        return self.__and__(other, loc=loc, ip=ip)

    def __or__(self, other, *, loc=None, ip=None):
        """Bitwise ``self | other``."""
        return _binary_op(operator.or_)(self, other, loc=loc, ip=ip)

    def __ror__(self, other, *, loc=None, ip=None):
        """Bitwise ``other | self`` (commutative)."""
        return self.__or__(other, loc=loc, ip=ip)

    def __xor__(self, other, *, loc=None, ip=None):
        """Bitwise ``self ^ other``."""
        return _binary_op(operator.xor)(self, other, loc=loc, ip=ip)

    def __rxor__(self, other, *, loc=None, ip=None):
        """Bitwise ``other ^ self`` (commutative)."""
        return self.__xor__(other, loc=loc, ip=ip)
|
|
|
|
|
class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True):
    """A class representing floating-point values.

    :param x: The input value to convert to this float type.
    :type x: Union[bool, int, float, ir.Value, Integer, Float]
    :param loc: The source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: The insertion point for the operation, defaults to None
    :type ip: Optional[InsertionPoint]

    Type conversion behavior:

    1. Python scalars (bool, int, float):
        - Stored as a Python ``float``; no rounding to the target precision
          is applied at construction time
        - Example: Float32(1.7) -> 1.7

    2. MLIR Value with FloatType:
        - If its type differs from the target type: converted between float types
        - Example: f16 -> f32

    3. MLIR Value with IntegerType:
        - Not supported, raises DSLRuntimeError

    4. Integer:
        - Dynamic values use MLIR int-to-float conversion honoring the source
          signedness; constants are converted with ``float()``
        - Example: Float32(Int32(5)) -> 5.0

    5. Float:
        - The source's underlying value is converted with the rules above
        - Example: Float32(Float32(1.5)) -> 1.5

    .. note::
        The following narrow precision types are only supported in device code:

        8-bit float types:
            - Float8E5M2
            - Float8E4M3
            - Float8E4M3FN
            - Float8E8M0FNU
            - Float8E4M3B11FNUZ

        6-bit float types:
            - Float6E3M2FN
            - Float6E2M3FN

        4-bit float types:
            - Float4E2M1FN

    :raises DSLRuntimeError: If conversion from the input type is not supported
    """

    def __init__(self, x, *, loc=None, ip=None):
        ty = type(self)

        if isinstance(x, (bool, int, float)):  # type: ignore
            # Kept at Python float precision. Rounding through
            # ty.numpy_dtype was previously considered and disabled.
            super().__init__(float(x))
        elif isinstance(x, ir.Value):  # type: ignore
            if isinstance(x.type, ir.IntegerType):  # type: ignore
                raise DSLRuntimeError("signless to float conversion is not implemented")
            elif isinstance(x.type, ir.FloatType):  # type: ignore
                if x.type != ty.mlir_type:
                    x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip)
            super().__init__(x)
        elif isinstance(x, Integer):
            if isinstance(x.value, ir.Value):  # type: ignore
                x = arith_helper.itofp(
                    x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip
                )
            else:
                x = float(x.value)
            super().__init__(x)
        elif isinstance(x, Float):
            # Forward loc/ip so a dynamic float-to-float conversion (cvtf)
            # keeps the caller's source location and insertion point.
            Float.__init__(self, x.value, loc=loc, ip=ip)
        else:
            raise DSLRuntimeError(f"{x} to Float conversion is not supported")
|
|
|
|
|
|
class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.bool):
    """Boolean type representation in the DSL.

    This class represents boolean values in the DSL, with a width of 1 bit.
    It supports conversion from various types to boolean values.

    :param a: Value to convert to Boolean
    :type a: Union[bool, int, float, "Value", Numeric]
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location], optional
    :param ip: Insertion point for MLIR operations, defaults to None
    :type ip: Optional[InsertionPoint], optional
    :raises DSLRuntimeError: If the input value cannot be converted to Boolean

    Conversion rules:

    1. Python bool/int/float:
        - Converted using Python's bool() function
        - Example: Boolean(1) -> True, Boolean(0) -> False

    2. Numeric:
        - Uses the Numeric.value to construct Boolean recursively

    3. MLIR ``ArithValue`` whose type is ``T.bool()``:
        - Used directly

    4. Any other MLIR ``ArithValue``:
        - Compared with ``!=`` against a zero constant of the same type.
          NOTE(review): NaN handling for float inputs is decided by
          ``ArithValue.__ne__``; confirm it uses an unordered comparison.

    An ``ir.Value`` that is not an ``ArithValue`` is rejected.
    """

    def __init__(
        self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None
    ):
        value = None
        if isinstance(a, (bool, int, float)):
            value = bool(a)
        elif isinstance(a, Numeric):
            Boolean.__init__(self, a.value, loc=loc, ip=ip)
            return
        elif isinstance(a, ArithValue):
            if a.type == T.bool():
                value = a
            else:
                value = a != arith_helper.const(0, a.type, loc=loc, ip=ip)
        if value is None:
            raise DSLRuntimeError(f"Cannot convert {a} to Boolean")
        super().__init__(value, loc=loc, ip=ip)
        # Lazily populated cache for ir_value_int8().
        self._value_int8 = None

    def ir_value_int8(self, *, loc=None, ip=None):
        """
        Returns int8 ir value of Boolean.
        When we need to store Boolean tensor element, use ir_value_int8().

        The result is cached on the instance after the first call.
        NOTE(review): later calls return the cached value even when they pass a
        different ``loc``/``ip``. Confirm callers never need it re-materialized
        at another insertion point.

        :param loc: Source location information, defaults to None
        :type loc: Optional[Location], optional
        :param ip: Insertion point for MLIR operations, defaults to None
        :type ip: Optional[InsertionPoint], optional
        :return: The int8 value of this Boolean
        :rtype: ir.Value
        """
        if self._value_int8 is not None:
            return self._value_int8
        # Forward loc/ip to ir_value() as well, so a constant Boolean is
        # materialized at the requested location/insertion point instead of
        # with default arguments.
        self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
        return self._value_int8

    def __neg__(self, *, loc=None, ip=None):
        """Negation operator is not supported for boolean type.

        :param loc: Source location information, defaults to None
        :type loc: Optional[Location], optional
        :param ip: Insertion point for MLIR operations, defaults to None
        :type ip: Optional[InsertionPoint], optional
        :raises TypeError: Always raises this error as negation is not supported
        """
        raise TypeError("Negation, the operator `-` is not supported for boolean type")
|
|
|
|
|
|
class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8):
    """Signed 8-bit integer (MLIR ``i8``)."""


class Int16(Integer, metaclass=IntegerMeta, width=16, signed=True, mlir_type=T.i16):
    """Signed 16-bit integer (MLIR ``i16``)."""


class Int32(Integer, metaclass=IntegerMeta, width=32, signed=True, mlir_type=T.i32):
    """Signed 32-bit integer (MLIR ``i32``)."""


class Int64(Integer, metaclass=IntegerMeta, width=64, signed=True, mlir_type=T.i64):
    """Signed 64-bit integer (MLIR ``i64``)."""


class Int128(
    Integer, metaclass=IntegerMeta, width=128, signed=True, mlir_type=lambda: T.i(128)
):
    """Signed 128-bit integer (MLIR ``i128``, built lazily)."""


class Uint8(Integer, metaclass=IntegerMeta, width=8, signed=False, mlir_type=T.i8):
    """Unsigned 8-bit integer; stored in the signless MLIR ``i8`` type."""


class Uint16(
    Integer, metaclass=IntegerMeta, width=16, signed=False, mlir_type=T.i16
):
    """Unsigned 16-bit integer; stored in the signless MLIR ``i16`` type."""


class Uint32(
    Integer, metaclass=IntegerMeta, width=32, signed=False, mlir_type=T.i32
):
    """Unsigned 32-bit integer; stored in the signless MLIR ``i32`` type."""


class Uint64(
    Integer, metaclass=IntegerMeta, width=64, signed=False, mlir_type=T.i64
):
    """Unsigned 64-bit integer; stored in the signless MLIR ``i64`` type."""


class Uint128(
    Integer, metaclass=IntegerMeta, width=128, signed=False, mlir_type=lambda: T.i(128)
):
    """Unsigned 128-bit integer; stored in the signless MLIR ``i128`` type."""
|
|
|
|
|
|
class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64):
    """64-bit float (MLIR ``f64``); constants marshal to a C ``double``."""

    def __c_pointers__(self):
        value = self.value
        if not isinstance(value, float):
            raise ValueError("only float is supported")

        boxed = ctypes.c_double(value)
        return [ctypes.cast(ctypes.pointer(boxed), ctypes.c_void_p)]
|
|
|
|
|
|
class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32):
    """32-bit float (MLIR ``f32``); constants marshal to a C ``float``."""

    @staticmethod
    def _get_c_pointer(value: float):
        """Return a ``c_void_p`` pointing at a C ``float`` holding ``value``."""
        boxed = ctypes.c_float(value)
        return ctypes.cast(ctypes.pointer(boxed), ctypes.c_void_p)

    def __c_pointers__(self):
        value = self.value
        if not isinstance(value, float):
            raise ValueError("only float is supported")

        return [Float32._get_c_pointer(value)]
|
|
|
|
|
|
class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32):
    """TF32 float (MLIR ``tf32``).

    Host-side constants are marshalled exactly like ``Float32``, as a 32-bit
    C ``float``.
    """

    def __c_pointers__(self):
        value = self.value
        if not isinstance(value, float):
            raise ValueError("only float is supported")
        return [Float32._get_c_pointer(value)]
|
|
|
|
|
|
class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16):
    """16-bit IEEE half-precision float (MLIR ``f16``)."""

    @staticmethod
    def _get_c_pointer(value: float):
        """Return a ``c_void_p`` to a 16-bit ctypes cell holding ``value``'s fp16 bits.

        numpy performs the float -> binary16 rounding. The raw bit pattern is
        then stored in a ``c_short`` so the host sees exactly those 16 bits.
        """
        half_bits = np.float16(value).view(np.uint16)
        cell = ctypes.c_short(half_bits)
        return ctypes.cast(ctypes.pointer(cell), ctypes.c_void_p)

    def __c_pointers__(self):
        value = self.value
        if not isinstance(value, float):
            raise ValueError("only float is supported")
        return [Float16._get_c_pointer(value)]
|
|
|
|
|
|
class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16):
    """Brain float 16 (MLIR ``bf16``)."""

    def __c_pointers__(self):
        value = self.value
        if not isinstance(value, float):
            raise ValueError("only float is supported")

        # Float itself defines no __c_pointers__, so this resolves to the
        # Numeric base implementation, which raises ValueError: host-side
        # bf16 marshalling is not implemented here.
        return Float.__c_pointers__(self)
|
|
|
|
|
|
class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2):
    """8-bit float backed by MLIR ``f8E5M2`` (device code only)."""


class Float8E4M3FN(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3FN):
    """8-bit float backed by MLIR ``f8E4M3FN`` (device code only)."""


class Float8E4M3B11FNUZ(
    Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3B11FNUZ
):
    """8-bit float backed by MLIR ``f8E4M3B11FNUZ`` (device code only)."""


class Float8E4M3(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3):
    """8-bit float backed by MLIR ``f8E4M3`` (device code only)."""


class Float8E8M0FNU(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E8M0FNU):
    """8-bit float backed by MLIR ``f8E8M0FNU`` (device code only)."""


class Float4E2M1FN(Float, metaclass=FloatMeta, width=4, mlir_type=T.f4E2M1FN):
    """4-bit float backed by MLIR ``f4E2M1FN`` (device code only)."""


class Float6E3M2FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E3M2FN):
    """6-bit float backed by MLIR ``f6E3M2FN`` (device code only)."""


class Float6E2M3FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E2M3FN):
    """6-bit float backed by MLIR ``f6E2M3FN`` (device code only)."""
|
|
|
|
|
|
# Narrow float types named as unsupported conversion destinations.
# NOTE(review): consumed elsewhere, presumably by conversion code; confirm.
_unsupported_dst_float_types = [
    Float8E4M3,
    Float8E4M3B11FNUZ,
    Float4E2M1FN,
    Float6E3M2FN,
    Float6E2M3FN,
]
|
|
|
|
|
|
# Every concrete numeric dtype that dtype() can resolve by class name.
# NOTE: Boolean is not a member, so dtype("Boolean") is rejected.
ALL_DTYPES = {
    # signed / unsigned integers
    Int8, Int16, Int32, Int64, Int128,
    Uint8, Uint16, Uint32, Uint64, Uint128,
    # 16/32/64-bit floats
    BFloat16, Float16, Float32, TFloat32, Float64,
    # narrow-precision floats
    Float8E5M2, Float8E4M3, Float8E4M3FN, Float8E8M0FNU, Float8E4M3B11FNUZ,
    Float4E2M1FN, Float6E2M3FN, Float6E3M2FN,
}

# Class-name -> class lookup table used by dtype().
__STR_TO_DTYPE__ = {cls.__name__: cls for cls in ALL_DTYPES}
|
|
|
|
|
|
def dtype(dtype_) -> Type[Numeric]:
    """Resolve ``dtype_`` to a concrete DSL numeric class.

    :param dtype_: The class name of a dtype in ``ALL_DTYPES`` (for example
        ``"Float16"``), or one of those classes itself, which is returned
        unchanged.
    :type dtype_: Union[str, Type[Numeric]]
    :return: The matching numeric class
    :rtype: Type[Numeric]
    :raises TypeError: If ``dtype_`` is neither a known dtype name nor one of
        the ``ALL_DTYPES`` classes
    """
    if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__):
        return __STR_TO_DTYPE__[dtype_]
    # Accept an already-resolved dtype class so callers may pass either form.
    # Identity comparison avoids relying on any metaclass-defined equality.
    if any(dtype_ is known for known in ALL_DTYPES):
        return dtype_
    raise TypeError(f"can't interpret {dtype_} as data type")
|
|
|
|
|
|
##############################################################
|
|
# Tensor
|
|
##############################################################
|
|
|
|
|
|
class TensorMeta(DslType):
    """Metaclass for tensor types parameterized by element type and shape.

    The ``element_type`` and ``shape`` class keywords are stored on the new
    class as ``_element_type`` and ``_shape``. Both default to ``typing.Any``.

    NOTE(review): this metaclass defines no ``__getitem__``. Confirm where the
    ``Tensor[...]`` subscription syntax shown below is implemented.

    Examples::

        Tensor[Int32, (3,)]
        Tensor[Float32, (3, 4)]
        T = TypeVar("T")
        Tensor[T, (3, 4, 5)]
    """

    # Class-level defaults; each created class overrides them in __new__.
    _element_type = Any
    _shape = Any

    def __new__(cls, name, bases, attrs, element_type=Any, shape=Any):
        new_cls = super().__new__(cls, name, bases, attrs)
        new_cls._element_type = element_type
        new_cls._shape = shape
        return new_cls
|
|
|
|
|
|
# Type variable shared by the generic marker classes below.
TY = TypeVar("TY")


class Constexpr(Generic[TY]):
    """Marker: the value is passed and computed by the Python interpreter."""
|
|
|
|
|
|
class align:
    """Alignment marker used to parameterize ``Pointer`` types.

    :param value: The alignment; must be a positive power of two
    :raises DSLRuntimeError: If ``value`` is not a positive power of two
    """

    def __init__(self, value: int):
        not_positive = value <= 0
        # A power of two has exactly one bit set, so clearing its lowest set
        # bit with value & (value - 1) leaves zero.
        if not_positive or (value & (value - 1)) != 0:
            raise DSLRuntimeError("expects align be power of 2 as positive value")
        self._value = value

    def __str__(self):
        return "align({})".format(self._value)
|
|
|
|
|
|
class PointerMeta(DslType):
    """Metaclass for ``Pointer`` types parameterized by value type and alignment.

    Two pointer classes compare and hash equal when their value types are
    equal and their alignments have the same numeric value. Repeated
    subscriptions with the same arguments are therefore interchangeable even
    though each call builds a fresh class.
    """

    def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)):
        # The MLIR type is supplied as a callable: an unranked memref of the
        # value type, with the attribute parsed from "0" as its memory space.
        new_cls = super().__new__(
            cls,
            name,
            bases,
            attrs,
            mlir_type=lambda: getattr(ir, "UnrankedMemRefType").get(
                value_type.mlir_type, getattr(ir, "Attribute").parse("0")
            ),
        )
        new_cls._value_type = value_type
        new_cls._align = align_
        return new_cls

    def __eq__(cls, other):
        if not isinstance(other, PointerMeta):
            return False
        # align defines no __eq__, so compare the alignment values directly.
        return (
            cls._value_type == other._value_type
            and cls._align._value == other._align._value
        )

    def __hash__(cls):
        # Must agree with __eq__: hash the alignment value, not the align object.
        return hash((cls._value_type, cls._align._value))

    def __getitem__(cls, params) -> Type["Pointer"]:
        """Build ``Pointer[value_type, align(n)]``.

        ``Pointer[value_type]`` is accepted as shorthand for
        ``Pointer[value_type, align(1)]``, the same default alignment that
        ``__new__`` uses.

        :raises DSLRuntimeError: If the alignment is not an ``align`` instance
        """
        if isinstance(params, type):
            # Single class given: previously this failed while unpacking.
            params = (params, align(1))
        value_type, align_ = params

        if not isinstance(align_, align):
            raise DSLRuntimeError(f"expects align but got {align_}")

        # Create new class with proper name and parameters
        new_cls = type(
            f"Pointer[{value_type.__name__}, {align_}]",
            (Pointer,),
            {},
            value_type=value_type,
            align_=align_,  # Pass alignment to __new__
        )
        return new_cls

    def __str__(cls):
        return f"ptr<{cls._value_type}, {cls._align}>"
|
|
|
|
|
|
class Pointer(metaclass=PointerMeta):
    """
    A pointer to a memory location.

    The pointee type and alignment are supplied by subscripting the class. The
    alignment must be an ``align`` instance; keyword syntax such as
    ``Pointer[Int32, align=8]`` is not valid Python inside a subscript.

    Examples::

        def foo(a: Pointer[Int32, align(8)]):
            ...
    """

    def __init__(self, value):
        # NOTE(review): the kind of object held here (MLIR value, raw address,
        # ...) is not visible from this class; confirm with callers.
        self.value = value

    def __str__(self):
        return f"{self.value} : {type(self)}"
|
|
|
|
|
|
class IRConst(Generic[TY]):
    """Marker: the value is passed as an MLIR constant (``arith.constant``).

    :param ty: The type of the constant; stored as ``self.ty``
    """

    def __init__(self, ty: TY):
        self.ty = ty
|
|
|
|
|
|
class IRValue(Generic[TY]):
    """Marker: the value is passed as a dynamic MLIR value.

    :param ty: The type of the value; stored as ``self.ty``
    """

    def __init__(self, ty: TY):
        self.ty = ty
|
|
|
|
|
|
class IRVariadic:
    """Bundle of a variadic number of dynamic operands passed to a function."""

    def __init__(self, operands):
        """Store ``operands``, which must all be dynamic values."""
        self.operands = operands

    def block_arg_types(self):
        """List the ``.type`` of every operand, in order (one per block argument)."""
        return list(map(operator.attrgetter("type"), self.operands))

    def set_func_args(self, block_args):
        """Hook invoked after entering the function.

        ``block_args`` are the block arguments matching the passed operands.
        The base class ignores them. Subclasses may override this to expose
        convenience getters.
        """
        return None

    def __len__(self):
        """Number of variadic operands."""
        return len(self.operands)
|
|
|
|
|
|
class FuncArgWithAttr(IRValue):
    """
    Function-op argument (an ``IRValue``) that also carries an attribute.

    :param ty: The argument type, forwarded to ``IRValue``
    :param attr_name: Attribute name; must not be None
    :param attr_ty: Attribute type; may be None only when ``attr_value`` is given
    :param attr_value: Attribute value, optional
    """

    def __init__(self, ty, attr_name, attr_ty, attr_value=None):
        super().__init__(ty)
        has_name = attr_name is not None
        has_ty_or_value = attr_ty is not None or attr_value is not None
        assert (
            has_name and has_ty_or_value
        ), "Invalid attr_name and/or attr_ty and/or attr_value for FuncArgWithAttr"
        self.attr_name, self.attr_ty, self.attr_value = attr_name, attr_ty, attr_value
|
|
|
|
|
|
|
|
def implicitDowncastNumericType(value):
    """Unwrap a DSL ``Numeric`` to its MLIR value; return anything else unchanged."""
    return value.ir_value() if isinstance(value, Numeric) else value
|
|
|
|
|
|
# Public API of this module. Each name appears once: "Float8E4M3" was
# previously listed twice.
__all__ = [
    "DslType",
    "Numeric",
    "NumericMeta",
    "IntegerMeta",
    "FloatMeta",
    "Boolean",
    "Integer",
    "Int16",
    "Int32",
    "Int64",
    "Int128",
    "Int8",
    "Uint8",
    "Uint16",
    "Uint32",
    "Uint64",
    "Uint128",
    "Float",
    "Float16",
    "BFloat16",
    "TFloat32",
    "Float32",
    "Float64",
    "Float8E5M2",
    "Float8E4M3",
    "Float8E4M3FN",
    "Float8E4M3B11FNUZ",
    "Float8E8M0FNU",
    "Float4E2M1FN",
    "Float6E2M3FN",
    "Float6E3M2FN",
    "as_numeric",
    "align",
    "Pointer",
    "dtype",
    "Constexpr",
    "IRConst",
    "IRValue",
    "IRVariadic",
    "implicitDowncastNumericType",
]
|