[Fix] [torch.compile] Improve UUID system for custom passes (#15249)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-03-23 21:54:07 -04:00
committed by GitHub
parent dccf535f8e
commit f622dbcf39
5 changed files with 132 additions and 91 deletions

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
import copy
import pytest
import torch
@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig
# dummy custom pass that doesn't inherit
def simple_callable(graph: torch.fx.Graph):
pass
callable_uuid = CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
@pytest.mark.parametrize(
"works, callable",
[
(False, simple_callable),
(True, callable_uuid),
(True, CallableInductorPass(simple_callable)),
],
)
def test_pass_manager(works: bool, callable):
# Should fail to add directly to the pass manager
def test_bad_callable():
config = CompilationConfig().pass_config
pass_manager = PostGradPassManager()
pass_manager.configure(config)
# Try to add the callable to the pass manager
if works:
pass_manager.add(callable)
pickle.dumps(pass_manager)
else:
with pytest.raises(AssertionError):
pass_manager.add(callable)
with pytest.raises(AssertionError):
pass_manager.add(simple_callable) # noqa, type wrong on purpose
# Pass that inherits from InductorPass
class ProperPass(InductorPass):
def __call__(self, graph: torch.fx.graph.Graph) -> None:
pass
@pytest.mark.parametrize(
"callable",
[
ProperPass(),
# Can also wrap callables in CallableInductorPass for compliance
CallableInductorPass(simple_callable),
CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
],
)
def test_pass_manager_uuid(callable):
config = CompilationConfig().pass_config
pass_manager = PostGradPassManager()
pass_manager.configure(config)
# Check that UUID is different if the same pass is added 2x
pass_manager.add(callable)
uuid1 = pass_manager.uuid()
pass_manager.add(callable)
uuid2 = pass_manager.uuid()
assert uuid1 != uuid2
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2 = PostGradPassManager()
pass_manager2.configure(config)
pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid()
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()