# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
|
|
This module provides MLIR Arith Dialect helper functions
|
|
"""
|
|
|
|
import array
|
|
import numpy as np
|
|
|
|
from ..common import *
|
|
from ..._mlir import ir # type: ignore
|
|
from ..._mlir.extras import types as T # type: ignore
|
|
from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore
|
|
|
|
from .lru_cache_ir import lru_cache_ir
|
|
|
|
# =============================================================================
|
|
# Arith Dialect Helper functions
|
|
# =============================================================================
|
|
|
|
|
|
def recast_type(src_type, res_elem_type) -> ir.Type:
    """Rebuild ``src_type`` with its element type replaced by ``res_elem_type``.

    Shaped containers (vector / ranked tensor / unranked tensor / memref)
    keep their shape (and scalability / strides where present); a scalar
    ``src_type`` simply becomes ``res_elem_type`` itself.
    """
    if isinstance(src_type, T.VectorType):
        if src_type.scalable:
            # Preserve scalable-vector metadata alongside the shape.
            return T.vector(
                *src_type.shape,
                res_elem_type,
                scalable=src_type.scalable,
                scalable_dims=src_type.scalable_dims,
            )
        return T.vector(*src_type.shape, res_elem_type)
    if isinstance(src_type, T.RankedTensorType):
        return T.RankedTensorType.get(
            element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
        )
    if isinstance(src_type, T.UnrankedTensorType):
        return T.UnrankedTensorType.get(element_type=res_elem_type)
    if isinstance(src_type, T.MemRefType):
        return T.MemRefType.get(
            element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
        )
    # Scalar: nothing to recast, the element type is the whole type.
    return res_elem_type
|
|
|
|
|
|
def is_scalar(ty) -> bool:
    """Return True when ``ty`` is not a shaped (vector/tensor/memref) type."""
    shaped = (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType)
    return not isinstance(ty, shaped)
|
|
|
|
|
|
def element_type(ty) -> ir.Type:
    """Return the element type of a shaped type, or ``ty`` itself for scalars."""
    return ty if is_scalar(ty) else ty.element_type
|
|
|
|
|
|
def is_narrow_precision(ty) -> bool:
    """Return True for narrow (sub-byte / FP8-class) float types: f8/f6/f4 variants."""
    return ty in (
        T.f8E8M0FNU(),
        T.f8E4M3FN(),
        T.f8E4M3(),
        T.f8E5M2(),
        T.f8E4M3B11FNUZ(),
        T.f4E2M1FN(),
        T.f6E3M2FN(),
        T.f6E2M3FN(),
    )
|
|
|
|
|
|
def is_float_type(ty) -> bool:
    """Return True when ``ty`` is any float type, including narrow and non-IEEE ones."""
    if arith._is_float_type(ty):
        return True
    # TODO-upstream: prediction is not correct. Patch here and fix in upstream later
    # (upstream's predicate misses the narrow f8/f6/f4 types as well as bf16/tf32).
    return is_narrow_precision(ty) or ty in (T.bf16(), T.tf32())
|
|
|
|
|
|
def truncf_to_narrow(res_ty, src, loc, ip):
    """Truncate ``src`` to a narrow-precision float type via the NVGPU cvt op.

    E8M0 (exponent-only) rounds toward positive infinity (RP); every other
    narrow target uses round-to-nearest (RN).
    """
    is_e8m0 = element_type(res_ty) == T.f8E8M0FNU()
    rnd = nvgpu.RoundingMode.RP if is_e8m0 else nvgpu.RoundingMode.RN
    return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip)
|
|
|
|
|
|
def extf_from_narrow(res_ty, src, loc, ip):
    """Extend a narrow-precision float to ``res_ty`` through a 16-bit staging type."""
    src_elem_ty = element_type(src.type)

    # When source type is E8M0, temporary element type has to be bf16
    staging_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16()
    staging_ty = recast_type(src.type, staging_elem_ty)

    # narrow -> bf16/f16 -> target type
    widened = nvgpu.cvt_fpext(staging_ty, src, loc=loc, ip=ip)
    return arith.extf(res_ty, widened, loc=loc, ip=ip)
|
|
|
|
|
|
def bitcast(src, res_elem_type, *, loc=None, ip=None):
    """Elementwise bitcast of ``src`` to the same-shaped type with ``res_elem_type`` elements."""
    return arith.bitcast(recast_type(src.type, res_elem_type), src, loc=loc, ip=ip)
|
|
|
|
|
|
def cvtf(src, res_elem_type, *, loc=None, ip=None):
    """Convert float ``src`` to the same-shaped value with ``res_elem_type`` elements.

    Handles tf32 via an i32/f32 detour, bf16<->f16 via f32, and narrow
    (f8/f6/f4) types via the NVGPU cvt helpers. Returns ``src`` unchanged
    when the element types already match.
    """
    src_elem_type = element_type(src.type)

    if res_elem_type == src_elem_type:
        return src

    res_type = recast_type(src.type, res_elem_type)

    # Treat TF32 as F32 and use i32 as intermediate data
    # TODO-upstream: update arith to support tf32 <-> f32 conversion
    if src_elem_type == T.tf32():
        # tf32 -> i32
        tmp_type = recast_type(src.type, T.i32())
        src = builtin.unrealized_conversion_cast([tmp_type], [src], loc=loc, ip=ip)
        # i32 -> f32
        src = bitcast(src, T.f32(), loc=loc, ip=ip)
        # f32 -> X with `cvtf` recursively
        return cvtf(src, res_elem_type, loc=loc, ip=ip)

    if res_elem_type == T.tf32():
        # X -> f32 with `cvtf` recursively
        tmp = cvtf(src, T.f32(), loc=loc, ip=ip)
        # f32 -> i32
        tmp = bitcast(tmp, T.i32(), loc=loc, ip=ip)
        # i32 -> tf32
        return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip)

    if res_elem_type.width > src_elem_type.width:
        # Widening: narrow sources need the staged NVGPU path.
        if is_narrow_precision(src_elem_type):
            return extf_from_narrow(res_type, src, loc, ip)
        else:
            return arith.extf(res_type, src, loc=loc, ip=ip)
    else:
        tmp_mlir_type = recast_type(src.type, T.f32())

        # f16 -- extf -> f32 -- truncf -> bf16 (and the reverse direction)
        # TODO-upstream: update arith to support bf16 <-> f16 conversion?
        if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or (
            src_elem_type == T.bf16() and res_elem_type == T.f16()
        ):
            tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip)
            return arith.truncf(res_type, tmp, loc=loc, ip=ip)

        # f16, f32, ... -> {f8, f6, f4}
        elif is_narrow_precision(res_elem_type):
            return truncf_to_narrow(res_type, src, loc, ip)
        else:
            return arith.truncf(res_type, src, loc=loc, ip=ip)
|
|
|
|
|
|
def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
    """Convert float ``src`` to an integer type; signed/unsigned per ``signed``."""
    res_type = recast_type(src.type, res_elem_type)

    # TODO-upstream: update arith to support this kind of conversion
    # (tf32/bf16 sources are widened to f32 before the fp->int op).
    if element_type(src.type) in (T.tf32(), T.bf16()):
        src = cvtf(src, T.f32(), loc=loc, ip=ip)

    convert = arith.fptosi if signed else arith.fptoui
    return convert(res_type, src, loc=loc, ip=ip)
|
|
|
|
|
|
def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
    """Convert integer ``src`` to a float type; tf32/bf16 targets go through f32."""
    final_type = recast_type(src.type, res_elem_type)

    # TODO-upstream: update arith to support this kind of conversion
    # (arith cannot target tf32/bf16 directly, so produce f32 then cvtf).
    needs_post_cvt = res_elem_type in (T.tf32(), T.bf16())
    direct_type = recast_type(src.type, T.f32()) if needs_post_cvt else final_type

    # i1 sources go through uitofp: arith reads `1` in i1 as `-1` when signed.
    if signed and element_type(src.type).width > 1:
        res = arith.sitofp(direct_type, src, loc=loc, ip=ip)
    else:
        res = arith.uitofp(direct_type, src, loc=loc, ip=ip)

    if not needs_post_cvt:
        return res

    return cvtf(res, element_type(final_type), loc=loc, ip=ip)
|
|
|
|
|
|
def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
    """Integer-to-integer conversion between DSL numeric element types.

    ``a`` carries a DSL-level ``signed`` flag; ``dst_elem_type`` is a DSL
    numeric type (it exposes ``signed`` and ``mlir_type``), not a raw MLIR
    type. MLIR integer types themselves are signless, so signedness only
    selects which extension op is emitted.
    """
    src_signed = a.signed
    dst_signed = dst_elem_type.signed
    src_width = element_type(a.type).width
    dst_width = dst_elem_type.width

    dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type)

    if dst_width == src_width:
        # Same width: bit pattern is unchanged regardless of signedness.
        return a
    elif src_signed and not dst_signed:
        # Signed -> Unsigned
        # NOTE(review): widening uses extui (zero-extension), so a negative
        # source does not sign-extend — confirm this matches the intended
        # signed->unsigned semantics.
        if dst_width > src_width:
            return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
        else:
            return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
    elif src_signed == dst_signed:
        # Same signedness
        if dst_width > src_width:
            # i1 is always zero-extended (arith reads `1` in i1 as `-1` if signed).
            if src_signed and src_width > 1:
                return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
            else:
                return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
        else:
            return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
    else:
        # Unsigned -> Signed
        if dst_width > src_width:
            return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
        else:
            # For truncation from unsigned to signed, we need to handle overflow
            # First truncate to the target width
            trunc = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
            # Then reinterpret as signed
            # NOTE(review): `trunc` already has dst_mlir_type, so this bitcast
            # is a same-type no-op, and dst_signed is always True on this
            # branch — confirm whether the bitcast is intentional.
            if dst_signed:
                return arith.bitcast(dst_mlir_type, trunc, loc=loc, ip=ip)
            return trunc
|
|
|
|
|
|
# =============================================================================
|
|
# Arith Ops Emitter Helpers
|
|
# - assuming type of lhs and rhs match each other
|
|
# - op name matches python module operator
|
|
# =============================================================================
|
|
|
|
|
|
def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None):
    """
    This function provides simplified interface to upstream op builder:
    because casts are element-wise ops that cannot change shape,

        arith.truncf(T.vector(shape, new_type), src)

    is simplified as

        arith.truncf(new_type, src)

    ``src`` may be a raw ir.Value or a DSL Numeric wrapper (anything with
    ``mlir_type`` on its class and an ``ir_value()`` method). ``is_signed``
    selects signed vs unsigned int conversions.
    """
    if isinstance(src, ir.Value):
        src_ty = src.type
    else:
        # DSL wrapper: unwrap to the underlying ir.Value.
        src_ty = type(src).mlir_type
        src = src.ir_value()

    src_elem_ty = element_type(src_ty)

    if src_elem_ty == res_elem_ty:
        # Same element type: nothing to do.
        return src
    elif is_float_type(src_elem_ty) and is_float_type(res_elem_ty):
        # float-to-float
        return cvtf(src, res_elem_ty, loc=loc, ip=ip)
    elif arith._is_integer_like_type(src_elem_ty) and arith._is_integer_like_type(
        res_elem_ty
    ):
        # int-to-int: trunc when narrowing, sign/zero-extend when widening.
        # NOTE(review): equal widths fall into trunci here; presumably
        # unreachable since equal element types returned above — confirm.
        if src_elem_ty.width >= res_elem_ty.width:
            cast_op = arith.trunci
        else:
            if is_signed:
                cast_op = arith.extsi
            else:
                cast_op = arith.extui

        res_ty = recast_type(src_ty, res_elem_ty)
        return cast_op(res_ty, src, loc=loc, ip=ip)
    elif is_float_type(src_elem_ty) and arith._is_integer_like_type(res_elem_ty):
        # float-to-int
        return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip)
    elif arith._is_integer_like_type(src_elem_ty) and is_float_type(res_elem_ty):
        # int-to-float
        return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip)
    else:
        raise DSLRuntimeError(
            f"cast from {src_elem_ty} to {res_elem_ty} is not supported"
        )
|
|
|
|
|
|
@lru_cache_ir()
def const(value, ty=None, *, loc=None, ip=None):
    """
    Generates dynamic expression for constant values.

    ``value`` may be a DSL Numeric, a Python scalar (bool/int/float), a
    numpy ndarray (splat/dense vector constants), or an already-dynamic
    expression (returned as-is). ``ty`` may be omitted (inferred from the
    Python type), a DSL NumericMeta, or a raw ir.Type.
    """
    from ..typing import Numeric, NumericMeta
    from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type

    if isinstance(value, Numeric):
        # Unwrap DSL numeric wrapper to its payload (Python scalar or ir.Value).
        value = value.value

    # Early return
    # NOTE(review): `value.type.isinstance(value.type)` is tautologically
    # true, so every dynamic expression early-returns here regardless of
    # `ty` — presumably this was meant to compare against `ty`; confirm.
    if is_dynamic_expression(value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        return value

    # Assume type
    if ty is None:
        # Infer from the Python value. bool is checked before int because
        # bool is a subclass of int.
        if isinstance(value, float):
            ty = T.f32()
        elif isinstance(value, bool):
            ty = T.bool()
        elif isinstance(value, int):
            ty = T.i32()
        elif isinstance(value, np.ndarray):
            ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype))
            # Flatten into a typed array buffer for the dense constant.
            value = array.array(value.dtype.kind, value.flatten().tolist())
        else:
            raise DSLNotImplemented(f"{type(value)} is not supported")
    elif isinstance(ty, NumericMeta):
        ty = ty.mlir_type
    elif isinstance(ty, ir.Type):
        if ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty):
            # Shaped constant: splat the scalar across the whole shape.
            elem_ty = ty.element_type
            if isinstance(elem_ty, ir.IntegerType):
                attr = ir.IntegerAttr.get(elem_ty, value)
            else:
                attr = ir.FloatAttr.get(elem_ty, value)
            value = ir.DenseElementsAttr.get_splat(ty, attr)
        elif arith._is_float_type(ty) and isinstance(value, (bool, int)):
            # Coerce Python scalar to match the requested element kind.
            value = float(value)
        elif arith._is_integer_like_type(ty) and isinstance(value, float):
            value = int(value)
        else:
            raise DSLNotImplemented(f"type {ty} is not supported")

    return arith.constant(ty, value, loc=loc, ip=ip)
|
|
|
|
|
|
def _dispatch_to_rhs_r_op(op):
    """Decorator that dispatches to the right-hand-side's reverse operation.

    If the other operand is not an ArithValue or is a subclass (more specific)
    of ArithValue, this allows proper method resolution for binary operations.

    :param op: the binary dunder method being wrapped.
    :return: a wrapper that returns ``NotImplemented`` for operands that are
        neither ArithValue nor Python scalars, so Python falls back to the
        other operand's reflected method (e.g. ``other.__rmul__``).
    """
    from functools import wraps

    # functools.wraps preserves the wrapped dunder's __name__/__doc__ for
    # introspection and error messages; the bare wrapper lost them.
    @wraps(op)
    def wrapper(self, other, **kwargs):
        if not isinstance(other, ArithValue):
            if not isinstance(other, (int, float, bool)):
                # allows to call other.__rmul__
                return NotImplemented

        return op(self, other, **kwargs)

    return wrapper
|
|
|
|
|
|
def _binary_op(op):
|
|
"""
|
|
Decorator to check if the 'other' argument is an ArithValue.
|
|
If not, returns NotImplemented.
|
|
"""
|
|
|
|
def wrapper(self, other, **kwargs):
|
|
# When reach this point, `self` must be cast to base `ArithValue` type
|
|
if isinstance(other, (int, float, bool)):
|
|
other = const(other, self.type).with_signedness(self.signed)
|
|
|
|
# Call the original function
|
|
# If sub-class doesn't implement overloaded arithmetic, cast to base class
|
|
return op(self, other, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
|
|
# Operator overloading
#
# Registering ArithValue as the value caster for each of these typeids makes
# every ir.Value of those types surface as an ArithValue, so Python operators
# on IR values lower to arith/math dialect ops.
@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid)
@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid)
@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid)
@ir.register_value_caster(ir.Float8E5M2Type.static_typeid)
@ir.register_value_caster(ir.Float8E4M3Type.static_typeid)
@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid)
@ir.register_value_caster(ir.BF16Type.static_typeid)
@ir.register_value_caster(ir.F16Type.static_typeid)
@ir.register_value_caster(ir.FloatTF32Type.static_typeid)
@ir.register_value_caster(ir.F32Type.static_typeid)
@ir.register_value_caster(ir.F64Type.static_typeid)
@ir.register_value_caster(ir.IntegerType.static_typeid)
@ir.register_value_caster(ir.VectorType.static_typeid)
@ir.register_value_caster(ir.RankedTensorType.static_typeid)
class ArithValue(ir.Value):
    """Overloads operators for MLIR's Arith dialects binary operations."""

    def __init__(self, v, signed: Union[bool, None] = None):
        # NOTE(review): the int branch reads `self.type` before
        # super().__init__ runs — presumably `v` is always an ir.Value when
        # arriving via the value-caster path; confirm the int path is used.
        if isinstance(v, int):
            v = arith.constant(self.type, v)
        super().__init__(v)

        elem_ty = element_type(self.type)
        # Cached dispatch flag: float ops (addf/mulf/...) vs int ops.
        self.is_float = arith._is_float_type(elem_ty)
        # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL
        self.signed = signed and elem_ty.width > 1

    def with_signedness(self, signed: Union[bool, None]):
        # Re-wrap the same underlying value with an explicit signedness.
        return type(self)(self, signed)

    def __neg__(self, *, loc=None, ip=None):
        if self.type == T.bool():
            raise TypeError(
                "Negation, the operator `-` is not supported for boolean type"
            )

        if self.is_float:
            return arith.negf(self, loc=loc, ip=ip)
        else:
            # Integers have no negf: emit `0 - self` instead.
            c0 = arith.constant(self.type, 0, loc=loc, ip=ip)
            return arith.subi(c0, self, loc=loc, ip=ip)

    @_binary_op
    def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue":
        # Dispatch on the float-ness of both operands; mixed int**float is
        # computed in f32.
        if self.is_float and other.is_float:
            return math.powf(self, other, loc=loc, ip=ip)
        elif self.is_float and not other.is_float:
            return math.fpowi(self, other, loc=loc, ip=ip)
        elif not self.is_float and other.is_float:
            lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
            rhs = cvtf(other, T.f32(), loc=loc, ip=ip)
            return math.powf(lhs, rhs, loc=loc, ip=ip)
        elif not self.is_float and not other.is_float:
            return math.ipowi(self, other, loc=loc, ip=ip)
        else:
            raise DSLNotImplemented(f"Unsupported '{self} ** {other}'")

    @_binary_op
    def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue":
        # `other` was materialized as an ArithValue by @_binary_op.
        return other.__pow__(self, loc=loc, ip=ip)

    # arith operators

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __add__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.addf(self, other, loc=loc, ip=ip)
        else:
            return arith.addi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.subf(self, other, loc=loc, ip=ip)
        else:
            return arith.subi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.mulf(self, other, loc=loc, ip=ip)
        else:
            return arith.muli(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.divf(self, other, loc=loc, ip=ip)
        else:
            # Python true division of ints yields float: convert both sides
            # to f32 and divide there.
            lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
            rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip)
            return arith.divf(lhs, rhs, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            # Python floor-division semantics: divide then floor.
            q = arith.divf(self, other, loc=loc, ip=ip)
            return math.floor(q, loc=loc, ip=ip)
        elif self.signed:
            return arith.floordivsi(self, other, loc=loc, ip=ip)
        else:
            # Unsigned division already truncates toward zero == floor.
            return arith.divui(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
        # NOTE(review): remsi/remf truncate toward zero, while Python's `%`
        # follows the divisor's sign for negatives — confirm intended.
        if self.is_float:
            return arith.remf(self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.remsi(self, other, loc=loc, ip=ip)
        else:
            return arith.remui(self, other, loc=loc, ip=ip)

    @_binary_op
    def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__add__(self, loc=loc, ip=ip)

    @_binary_op
    def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__sub__(self, loc=loc, ip=ip)

    @_binary_op
    def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__mul__(self, loc=loc, ip=ip)

    @_binary_op
    def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__truediv__(self, loc=loc, ip=ip)

    @_binary_op
    def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__floordiv__(self, loc=loc, ip=ip)

    @_binary_op
    def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__mod__(self, loc=loc, ip=ip)

    # Comparison operators (comparison doesn't have right-hand-side variants)
    # Float comparisons are ordered (O*) except __ne__; int comparisons pick
    # the signed/unsigned predicate from self.signed.
    @_dispatch_to_rhs_r_op
    @_binary_op
    def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            # In Python, bool(float("nan")) is True, so use unordered comparison here
            return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)

    # Unary operators
    def __invert__(self, *, loc=None, ip=None) -> "ArithValue":
        # Bitwise NOT as xor with all-ones.
        # NOTE(review): loc/ip are accepted but not forwarded here — confirm.
        return arith.xori(self, arith.constant(self.type, -1))

    # Bitwise operations
    @_dispatch_to_rhs_r_op
    @_binary_op
    def __and__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.andi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __or__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.ori(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.xori(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        # Arithmetic shift for signed values, logical shift otherwise.
        if self.signed:
            return arith.shrsi(self, other, loc=loc, ip=ip)
        else:
            return arith.shrui(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.shli(self, other, loc=loc, ip=ip)

    @_binary_op
    def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.andi(other, self, loc=loc, ip=ip)

    @_binary_op
    def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.ori(other, self, loc=loc, ip=ip)

    @_binary_op
    def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.xori(other, self, loc=loc, ip=ip)

    @_binary_op
    def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__rshift__(self, loc=loc, ip=ip)

    @_binary_op
    def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__lshift__(self, loc=loc, ip=ip)

    def __hash__(self):
        # __eq__ is overloaded to build IR, so restore identity-based hashing
        # from ir.Value explicitly.
        return super().__hash__()

    def __str__(self):
        return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)

    def __repr__(self):
        return self.__str__()
|
|
|
|
|
|
def _min(lhs, rhs, *, loc=None, ip=None):
    """
    Unified interface for building an arith min.

    Operands are assumed to share one type; a static Python operand is
    materialized as a constant of the dynamic operand's type.
    """
    from ..dsl import is_dynamic_expression

    lhs_dynamic = is_dynamic_expression(lhs)
    rhs_dynamic = is_dynamic_expression(rhs)

    # Both static: fold at trace time with Python's builtin.
    if not lhs_dynamic and not rhs_dynamic:
        return min(lhs, rhs)

    if not lhs_dynamic:
        lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
    elif not rhs_dynamic:
        rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

    if not arith._is_integer_like_type(lhs.type):
        return arith.minimumf(lhs, rhs, loc=loc, ip=ip)
    min_op = arith.minsi if lhs.signed else arith.minui
    return min_op(lhs, rhs, loc=loc, ip=ip)
|
|
|
|
|
|
def _max(lhs, rhs, *, loc=None, ip=None):
    """
    Unified interface for building an arith max.

    Operands are assumed to share one type; a static Python operand is
    materialized as a constant of the dynamic operand's type.
    """
    from ..dsl import is_dynamic_expression

    lhs_dynamic = is_dynamic_expression(lhs)
    rhs_dynamic = is_dynamic_expression(rhs)

    # Both static: fold at trace time with Python's builtin.
    if not lhs_dynamic and not rhs_dynamic:
        return max(lhs, rhs)

    if not lhs_dynamic:
        lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
    elif not rhs_dynamic:
        rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

    if not arith._is_integer_like_type(lhs.type):
        return arith.maximumf(lhs, rhs, loc=loc, ip=ip)
    max_op = arith.maxsi if lhs.signed else arith.maxui
    return max_op(lhs, rhs, loc=loc, ip=ip)
|