764 lines
23 KiB
Python
764 lines
23 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin
|
|
import dataclasses
|
|
import itertools as it
|
|
from types import SimpleNamespace
|
|
|
|
from ..base_dsl.typing import as_numeric, Numeric, Constexpr
|
|
from ..base_dsl._mlir_helpers.arith import ArithValue
|
|
from ..base_dsl.common import DSLBaseError
|
|
from .._mlir import ir
|
|
|
|
# =============================================================================
|
|
# Tree Utils
|
|
# =============================================================================
|
|
|
|
|
|
class DSLTreeFlattenError(DSLBaseError):
    """
    Error signalling that a value could not be flattened into a pytree.

    Attributes:
        type_str: Type description supplied by the raiser. Within this module
            it is the fully qualified class name of the offending object.
    """

    def __init__(self, msg: str, type_str: str):
        # Only the message is forwarded to the base error; the type string is
        # kept on the instance for callers that want to report it.
        super().__init__(msg)
        self.type_str = type_str
|
|
|
|
|
|
def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """
    Split an iterable of 2-item pairs into two parallel lists.

    Args:
        pairs: Iterable yielding two-element items; it is consumed exactly once

    Returns:
        tuple: (firsts, seconds), two lists of equal length holding the first
        and second component of every pair, in iteration order
    """
    firsts: list[Any] = []
    seconds: list[Any] = []
    for first, second in pairs:
        firsts.append(first)
        seconds.append(second)
    return firsts, seconds
|
|
|
|
|
|
def get_fully_qualified_class_name(x: Any) -> str:
    """
    Build the dotted name of the class of ``x``.

    Args:
        x: Any object

    Returns:
        str: ``'<module>.<qualname>'`` of ``x.__class__``

    Example:
        >>> get_fully_qualified_class_name([1, 2, 3])
        'builtins.list'
    """
    klass = x.__class__
    module_name = klass.__module__
    qualified_name = klass.__qualname__
    return f"{module_name}.{qualified_name}"
|
|
|
|
|
|
def is_frozen_dataclass(obj_or_cls: Any) -> bool:
    """
    Tell whether a class, or the class of an instance, is a frozen dataclass.

    Args:
        obj_or_cls: A dataclass class or an instance of one; any other object
            is accepted and simply yields False

    Returns:
        bool: True only for dataclasses declared with ``frozen=True``

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass(frozen=True)
        ... class Point:
        ...     x: int
        ...     y: int
        >>> is_frozen_dataclass(Point)
        True
        >>> is_frozen_dataclass(Point(1, 2))
        True
    """
    # Normalize to the class so instances and classes share one code path.
    if isinstance(obj_or_cls, type):
        cls = obj_or_cls
    else:
        cls = obj_or_cls.__class__

    if not dataclasses.is_dataclass(cls):
        return False

    # The decorator arguments are stored on the class as __dataclass_params__.
    params = getattr(cls, "__dataclass_params__", None)
    if params is None:
        return False
    return params.frozen
|
|
|
|
|
|
def is_dynamic_expression(x: Any) -> bool:
    """
    Report whether ``x`` exposes the DynamicExpression protocol.

    The check is purely attribute based: both ``__extract_mlir_values__`` and
    ``__new_from_mlir_values__`` must be present on the object.

    Args:
        x: Any object to check

    Returns:
        bool: True when both protocol attributes exist, False otherwise
    """
    can_extract = hasattr(x, "__extract_mlir_values__")
    return can_extract and hasattr(x, "__new_from_mlir_values__")
|
|
|
|
|
|
def is_constexpr_field(field: dataclasses.Field) -> bool:
    """
    Decide whether a dataclass field is annotated as ``Constexpr``.

    Both the bare ``Constexpr`` annotation and a subscripted form whose
    ``typing.get_origin`` is ``Constexpr`` are recognized.

    Args:
        field: A field descriptor as returned by ``dataclasses.fields``

    Returns:
        bool: True for a Constexpr-annotated field, False otherwise
    """
    annotation = field.type
    return annotation is Constexpr or get_origin(annotation) is Constexpr
|
|
|
|
|
|
# =============================================================================
|
|
# PyTreeDef
|
|
# =============================================================================
|
|
|
|
class NodeType(NamedTuple):
    """
    Handlers describing how one container type takes part in a pytree.

    Attributes:
        name: Human-readable name of the node type
        to_iterable: Callable that turns an instance into its iterable form
        from_iterable: Callable that rebuilds an instance from that form
    """

    # Display name of the node type.
    name: str
    # Converter from an instance to its iterable form.
    to_iterable: Callable
    # Converter from the iterable form back to an instance.
    from_iterable: Callable
|
|
|
|
|
|
class PyTreeDef(NamedTuple):
    """
    Structural description of one interior node of a pytree.

    Attributes:
        node_type: The NodeType (handlers) used for this node
        node_metadata: SimpleNamespace with per-node data produced when the
            node was converted to iterable form
        child_treedefs: Tuple with one tree definition per child
    """

    # Handlers for the container found at this position.
    node_type: NodeType
    # Auxiliary data captured while flattening this node.
    node_metadata: SimpleNamespace
    # Definitions of the children, in flattening order.
    child_treedefs: tuple["PyTreeDef", ...]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Leaf:
    """
    Immutable description of a leaf position in a pytree.

    Attributes:
        is_numeric: True when the leaf holds a `Numeric` value
        is_none: True when the leaf stands for ``None``
        node_metadata: Optional SimpleNamespace with extra data for this leaf
        ir_type_str: Optional string form of the leaf's IR type
    """

    # Flag for leaves holding a Numeric value.
    is_numeric: bool = False
    # Flag for leaves standing in for None.
    is_none: bool = False
    # Extra per-leaf data; None when nothing was recorded.
    node_metadata: SimpleNamespace = None
    # Textual IR type; None when unknown.
    ir_type_str: str = None
|
|
|
|
|
|
# =============================================================================
|
|
# Default to_iterable and from_iterable
|
|
# =============================================================================
|
|
|
|
|
|
def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[str]]:
    """
    Split the fields of a dataclass instance into dynamic and constexpr parts.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (fields, members, constexpr_fields) where ``fields`` lists the
        names of the non-Constexpr fields in declaration order, ``members``
        holds their current values in the same order, and
        ``constexpr_fields`` lists the names of the Constexpr-annotated
        fields. All three lists are empty for a dataclass without fields.

    Raises:
        DSLTreeFlattenError: If the instance carries an attribute that is not
            a declared field, or if a Constexpr-annotated field currently
            holds a dynamic expression
    """
    declared = dataclasses.fields(x)
    declared_names = {field.name for field in declared}

    # Attributes attached outside the dataclass definition would be lost on
    # reconstruction, so they are rejected. Slotted dataclasses have no
    # __dict__ and therefore cannot carry such attributes at all.
    for attr in getattr(x, "__dict__", {}):
        if attr not in declared_names:
            raise DSLTreeFlattenError(
                f"`{x}` has extra field `{attr}`",
                type_str=get_fully_qualified_class_name(x),
            )

    fields: list[str] = []
    members: list[Any] = []
    constexpr_fields: list[str] = []
    for field in declared:
        value = getattr(x, field.name)
        if is_constexpr_field(field):
            # Constexpr fields are recorded by name only and must not hold a
            # dynamic expression.
            if is_dynamic_expression(value):
                raise DSLTreeFlattenError(
                    f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`",
                    type_str=get_fully_qualified_class_name(x),
                )
            constexpr_fields.append(field.name)
        else:
            fields.append(field.name)
            members.append(value)

    # Always a 3-tuple: the caller unpacks three values even when the
    # dataclass declares no fields.
    return fields, members, constexpr_fields
|
|
|
|
|
|
def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Produce the iterable form of a dataclass instance for tree flattening.

    The field split is delegated to ``extract_dataclass_members``; only the
    values of the non-Constexpr fields become children.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (metadata, members) where metadata records the fully qualified
        class name (``type_str``), the non-Constexpr field names (``fields``),
        the Constexpr field names (``constexpr_fields``) and the instance
        itself (``original_obj``), and members is the list of child values
    """
    dynamic_names, dynamic_values, constexpr_names = extract_dataclass_members(x)
    node_metadata = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x),
        fields=dynamic_names,
        constexpr_fields=constexpr_names,
        original_obj=x,
    )
    return node_metadata, dynamic_values
|
|
|
|
|
|
def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Build a copy of a dataclass instance with new values for some fields.

    Args:
        instance: The dataclass instance used as the template
        fields: Names of the fields to overwrite
        values: Iterable of new values, paired positionally with ``fields``
        constexpr_fields: Names of fields whose value is carried over from
            ``instance``

    Returns:
        ``instance`` itself when ``fields`` is empty; otherwise a new object
        created by ``dataclasses.replace``
    """
    if not fields:
        return instance

    overrides = {name: value for name, value in zip(fields, values)}
    # Constexpr fields keep whatever the template instance currently holds.
    overrides.update((name, getattr(instance, name)) for name in constexpr_fields)
    return dataclasses.replace(instance, **overrides)
|
|
|
|
def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dataclass instance from its iterable form.

    Args:
        metadata: Metadata produced when the dataclass was flattened; its
            ``original_obj``, ``fields`` and ``constexpr_fields`` are read
        children: Iterable of new values for the non-Constexpr fields

    Returns:
        The reconstructed dataclass instance

    Note:
        ``metadata.original_obj`` is overwritten with the returned instance.
    """
    rebuilt = set_dataclass_attributes(
        metadata.original_obj,
        metadata.fields,
        children,
        metadata.constexpr_fields,
    )
    # The metadata now points at the freshly built instance.
    metadata.original_obj = rebuilt
    return rebuilt
|
|
|
|
|
|
def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Produce the iterable form of a dynamic expression.

    The children are whatever the object's ``__extract_mlir_values__`` method
    returns.

    Args:
        x: A dynamic expression object

    Returns:
        tuple: (metadata, mlir_values) where metadata flags the node as a
        dynamic expression and keeps a reference to ``x``
    """
    node_metadata = SimpleNamespace(is_dynamic_expression=1, original_obj=x)
    mlir_values = x.__extract_mlir_values__()
    return node_metadata, mlir_values
|
|
|
|
|
|
def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dynamic expression from its iterable form.

    Args:
        metadata: Metadata whose ``original_obj`` is the object that was
            flattened
        children: Iterable of MLIR values to rebuild from

    Returns:
        Whatever ``original_obj.__new_from_mlir_values__`` returns for the
        list of children
    """
    template = metadata.original_obj
    # The protocol method receives a concrete list, so lazy iterables are
    # materialized first.
    mlir_values = list(children)
    return template.__new_from_mlir_values__(mlir_values)
|
|
|
|
|
|
def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Produce the iterable form of a dict or SimpleNamespace.

    Args:
        x: A dict, or a SimpleNamespace whose attributes are treated as items

    Returns:
        tuple: (metadata, values) where metadata records the fully qualified
        class name (``type_str``), the object itself (``original_obj``) and
        its keys (``fields``), and values lists the corresponding values
    """
    # A SimpleNamespace keeps its items in the instance __dict__.
    mapping = x.__dict__ if isinstance(x, SimpleNamespace) else x
    keys = list(mapping.keys())
    values = list(mapping.values())

    node_metadata = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys
    )
    return node_metadata, values
|
|
|
|
|
|
def default_dict_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dict or SimpleNamespace from its iterable form.

    Args:
        metadata: Metadata whose ``original_obj`` is the container that was
            flattened and whose ``fields`` are its keys
        children: Iterable of new values, paired positionally with the keys

    Returns:
        ``metadata.original_obj``, updated in place with the new values
    """
    target = metadata.original_obj
    updates = zip(metadata.fields, children)

    if isinstance(target, SimpleNamespace):
        # Namespace entries are attributes, not items.
        for name, value in updates:
            setattr(target, name, value)
    else:
        for key, value in updates:
            target[key] = value
    return target
|
|
|
|
|
|
# =============================================================================
|
|
# Register pytree nodes
|
|
# =============================================================================
|
|
|
|
# Registry of pytree handlers, keyed by the exact Python type they serve.
_node_types: dict[type, NodeType] = {}


def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register (or overwrite) the pytree handlers for a type.

    Args:
        ty: The type to register
        to_iter: Callable converting instances of ``ty`` to iterable form
        from_iter: Callable rebuilding instances of ``ty`` from iterable form

    Returns:
        NodeType: The newly created entry, named ``str(ty)``, which is also
        stored in the module registry under ``ty``
    """
    node_type = NodeType(name=str(ty), to_iterable=to_iter, from_iterable=from_iter)
    _node_types[ty] = node_type
    return node_type
|
|
|
|
|
|
def register_default_node_types() -> None:
    """
    Register the built-in container types for pytree operations.

    Registered, in this order: ``tuple``, ``list``, ``dict`` and
    ``SimpleNamespace``. Sequences record their length as metadata; dict and
    SimpleNamespace share the default dict handlers.
    """

    def tuple_to_iterable(seq):
        return SimpleNamespace(length=len(seq)), list(seq)

    def tuple_from_iterable(_, xs):
        return tuple(xs)

    def list_to_iterable(seq):
        return SimpleNamespace(length=len(seq)), list(seq)

    def list_from_iterable(_, xs):
        return list(xs)

    register_pytree_node(tuple, tuple_to_iterable, tuple_from_iterable)
    register_pytree_node(list, list_to_iterable, list_from_iterable)
    for mapping_type in (dict, SimpleNamespace):
        register_pytree_node(
            mapping_type, default_dict_to_iterable, default_dict_from_iterable
        )


# Populate the registry with the defaults when the module is imported.
register_default_node_types()
|
|
|
|
|
|
# =============================================================================
|
|
# tree_flatten and tree_unflatten
|
|
# =============================================================================
|
|
|
|
"""
|
|
Behavior of tree_flatten and tree_unflatten, for example:
|
|
|
|
```python
|
|
a = (1, 2, 3)
|
|
b = MyClass(a=1, b =[1,2,3])
|
|
```
|
|
|
|
yields the following tree:
|
|
|
|
```python
|
|
tree_a = PyTreeDef(type = 'tuple',
|
|
metadata = {length = 3},
|
|
children = [
|
|
Leaf(type = int),
|
|
Leaf(type = int),
|
|
Leaf(type = int),
|
|
],
|
|
)
|
|
flattened_a = [1, 2, 3]
|
|
tree_b = PyTreeDef(type = 'MyClass',
                   metadata = {fields = ['a','b']},
                   children = [
                       Leaf(type=`int`),
                       PyTreeDef(type = `list`,
                                 metadata = {length = 3},
                                 children = [
                                     Leaf(type=`int`),
                                     Leaf(type=`int`),
                                     Leaf(type=`int`),
                                 ],
                       ),
                   ],
)
|
|
flattened_b = [1, 1, 2, 3]
|
|
```
|
|
|
|
Pass the flattened values and the PyTreeDef to tree_unflatten to reconstruct the original structure:
|
|
|
|
``` python
|
|
unflattened_a = tree_unflatten(tree_a, flattened_a)
|
|
unflattened_b = tree_unflatten(tree_b, flattened_b)
|
|
```
|
|
|
|
yields the following structure:
|
|
|
|
``` python
|
|
unflattened_a = (1, 2, 3)
|
|
unflattened_b = MyClass(a=1, b =[1,2,3])
|
|
```
|
|
|
|
unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b.
|
|
|
|
"""
|
|
|
|
|
|
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into a list of leaves plus a tree definition.

    The traversal itself is done by ``_tree_flatten``; this wrapper only
    materializes the (possibly lazy) stream of leaves into a list.

    Args:
        x: The nested structure to flatten

    Returns:
        tuple: (flat_values, treedef) where flat_values is the list of leaf
        values in traversal order and treedef describes the structure they
        were taken from

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> len(flat_values)
        4
    """
    leaves, treedef = _tree_flatten(x)
    return [*leaves], treedef
|
|
|
|
|
|
def get_registered_node_types_or_insert(x: Any) -> NodeType | None:
    """
    Look up the pytree handlers for an object, registering them on demand.

    The registry is keyed by ``type(x)``. When no entry exists yet, one is
    created based on the object's characteristics:
    - Objects implementing the DynamicExpression protocol get the dynamic
      expression handlers
    - Dataclass instances get the default dataclass handlers

    Args:
        x: The object to get or register a node type for

    Returns:
        NodeType or None: The registered node type, or None if the type
        cannot be registered
    """
    node_type = _node_types.get(type(x))
    if node_type:
        return node_type

    # The protocol check comes first so that a dataclass implementing
    # DynamicExpression is handled through its own protocol methods.
    if is_dynamic_expression(x):
        return register_pytree_node(
            type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable
        )

    # dataclasses.is_dataclass() is also true for dataclass *classes*. For a
    # class object, type(x) is its metaclass (usually `type`), so registering
    # it would attach the dataclass handlers to every class object seen
    # afterwards. Only instances are registered.
    if dataclasses.is_dataclass(x) and not isinstance(x, type):
        return register_pytree_node(
            type(x), default_dataclass_to_iterable, default_dataclass_from_iterable
        )

    return None
|
|
|
|
|
|
def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace = None,
    ir_type_str: str = None,
) -> Leaf:
    """
    Build the Leaf describing a single value.

    Args:
        x: The value to create a leaf for
        is_numeric: Whether this is a numeric value
        is_none: Whether this represents None
        node_metadata: Optional metadata
        ir_type_str: Optional IR type string; when empty or None it falls
            back to ``str(x.type)`` if ``x`` has a ``type`` attribute

    Returns:
        Leaf: The created leaf node
    """
    type_repr = ir_type_str
    if not type_repr:
        # No explicit type given: derive it from the value when possible.
        type_repr = str(x.type) if hasattr(x, "type") else None

    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        ir_type_str=type_repr,
    )
|
|
|
|
|
|
def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]:
    """
    Recursive worker behind ``tree_flatten``.

    Values are dispatched in this order: ``None``, ``ArithValue``,
    ``ir.Value``, ``Numeric``, registered (or auto-registrable) pytree node
    types, and finally anything ``as_numeric`` can convert.

    Args:
        x: The object to flatten

    Returns:
        tuple: (flattened_values, treedef) where flattened_values is an
        iterable of leaf values (lazy for container nodes) and treedef is a
        Leaf or PyTreeDef describing ``x``

    Raises:
        DSLTreeFlattenError: If the object type is not supported
    """

    def flatten_dynamic_leaf() -> tuple[Iterable[Any], Leaf]:
        # Leaf that remembers the original object so it can be rebuilt through
        # __new_from_mlir_values__ when unflattening.
        values = x.__extract_mlir_values__()
        leaf = create_leaf_for_value(
            x,
            node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
            ir_type_str=str(values[0].type),
        )
        return values, leaf

    if x is None:
        return [], create_leaf_for_value(x, is_none=True)

    # NOTE(review): ArithValue is tested before ir.Value, presumably because
    # an ArithValue would also satisfy the ir.Value check - confirm.
    if isinstance(x, ArithValue):
        if is_dynamic_expression(x):
            return flatten_dynamic_leaf()
        return [x], create_leaf_for_value(x, is_numeric=True)

    if isinstance(x, ir.Value):
        return [x], create_leaf_for_value(x)

    if isinstance(x, Numeric):
        return flatten_dynamic_leaf()

    node_type = get_registered_node_types_or_insert(x)
    if node_type:
        node_metadata, children = node_type.to_iterable(x)
        # Children are flattened eagerly; only the concatenation of their
        # leaf streams stays lazy.
        leaf_groups, child_trees = unzip2(map(_tree_flatten, children))
        return (
            it.chain.from_iterable(leaf_groups),
            PyTreeDef(node_type, node_metadata, tuple(child_trees)),
        )

    # Last resort: try to convert to numeric
    try:
        nval = as_numeric(x).ir_value()
        return [nval], create_leaf_for_value(nval, is_numeric=True)
    except Exception:
        raise DSLTreeFlattenError(
            "Flatten Error", get_fully_qualified_class_name(x)
        )
|
|
|
|
|
|
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """
    Rebuild a nested structure from flat values and a tree definition.

    This is the inverse of ``tree_flatten``: values are consumed from ``xs``
    in order while ``treedef`` is walked by ``_tree_unflatten``.

    Args:
        treedef: The tree structure definition from tree_flatten
        xs: List of flat values to reconstruct from

    Returns:
        The reconstructed nested structure

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> rebuilt = tree_unflatten(treedef, flat_values)
    """
    # One shared iterator lets every nested call continue where the previous
    # one stopped.
    remaining = iter(xs)
    return _tree_unflatten(treedef, remaining)
|
|
|
|
|
|
def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any:
    """
    Recursive worker behind ``tree_unflatten``.

    Args:
        treedef: The tree structure definition (a Leaf or a PyTreeDef)
        xs: Iterator of flat values; each non-None leaf takes the next value

    Returns:
        The reconstructed object, or None when ``treedef`` is neither a Leaf
        nor a PyTreeDef
    """
    if isinstance(treedef, Leaf):
        # A None leaf consumes nothing from the value stream.
        if treedef.is_none is True:
            return None

        metadata = treedef.node_metadata
        if metadata and metadata.is_dynamic_expression:
            # Rebuild through the protocol of the object recorded at flatten time.
            return metadata.original_obj.__new_from_mlir_values__([next(xs)])

        if treedef.is_numeric is True:
            return as_numeric(next(xs))

        return next(xs)

    if isinstance(treedef, PyTreeDef):
        # Lazy on purpose: the node's from_iterable drives consumption of xs.
        children = (_tree_unflatten(child, xs) for child in treedef.child_treedefs)
        return treedef.node_type.from_iterable(treedef.node_metadata, children)

    return None
|
|
|
|
|
|
def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool:
    """
    Recursively compare two tree definitions for structural equality.

    Two leaves are equal when their ``is_none`` flag and ``ir_type_str``
    agree. Two PyTreeDef nodes are equal when node type, ``fields`` and
    ``constexpr_fields`` metadata (missing entries count as empty lists) and
    all children agree. A leaf never equals a PyTreeDef.

    Args:
        lhs: Left tree definition (PyTreeDef or Leaf)
        rhs: Right tree definition (PyTreeDef or Leaf)

    Returns:
        bool: True if the trees are structurally equal, False otherwise
    """
    if isinstance(lhs, Leaf) and isinstance(rhs, Leaf):
        return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str

    if not (isinstance(lhs, PyTreeDef) and isinstance(rhs, PyTreeDef)):
        return False

    lhs_meta, rhs_meta = lhs.node_metadata, rhs.node_metadata
    lhs_fields = getattr(lhs_meta, "fields", [])
    rhs_fields = getattr(rhs_meta, "fields", [])
    lhs_constexpr = getattr(lhs_meta, "constexpr_fields", [])
    rhs_constexpr = getattr(rhs_meta, "constexpr_fields", [])

    if not (lhs.node_type == rhs.node_type):
        return False
    if not (lhs_fields == rhs_fields):
        return False
    if not (lhs_constexpr == rhs_constexpr):
        return False
    if not (len(lhs.child_treedefs) == len(rhs.child_treedefs)):
        return False
    return all(
        _check_tree_equal(left, right)
        for left, right in zip(lhs.child_treedefs, rhs.child_treedefs)
    )
|
|
|
|
|
|
def check_tree_equal(lhs: "PyTreeDef", rhs: "PyTreeDef") -> int:
    """
    Locate the first top-level child at which two tree definitions differ.

    Only the direct children of ``lhs`` and ``rhs`` are paired up; each pair
    is compared recursively with ``_check_tree_equal``. The comparison is
    structural: leaves are compared by their ``is_none`` flag and IR type
    string, never by value.

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if all children are equal

    Raises:
        AssertionError: If the two definitions have a different number of
            children

    Example:
        >>> treedef1 = tree_flatten([1, [2, 3]])[1]
        >>> treedef2 = tree_flatten([1, (2, 3)])[1]
        >>> check_tree_equal(treedef1, treedef2)
        1  # The second child is a list on one side and a tuple on the other
    """
    lhs_children = lhs.child_treedefs
    rhs_children = rhs.child_treedefs

    # Raised explicitly instead of via `assert`, so the guard is not stripped
    # under `python -O`; zip() below would otherwise silently truncate.
    if len(lhs_children) != len(rhs_children):
        raise AssertionError(
            f"tree definitions have different numbers of children: "
            f"{len(lhs_children)} != {len(rhs_children)}"
        )

    for index, (left, right) in enumerate(zip(lhs_children, rhs_children)):
        if not _check_tree_equal(left, right):
            return index
    return -1
|