Compare commits

...

13 Commits

Author SHA1 Message Date
afcb616e89 Refactor test_decorator.py to use shared testing_ops module
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-21 22:59:21 +00:00
2b81d5fd2f Merge branch 'main' into copilot/fix-c6914add-1b66-46d0-9948-c2e7b6f2259f
# Conflicts:
#	tests/compile/piecewise/test_multiple_graphs.py
2025-08-21 22:50:17 +00:00
dd72729634 Rename test_operations.py to testing_ops.py and update all imports
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-21 22:45:35 +00:00
8393419f4a shorten comment
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-08-21 18:40:24 -04:00
019b2fb6ca add whitespace
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-08-21 18:38:50 -04:00
17d6532086 add whitespace back
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-08-21 18:38:39 -04:00
b0974809c4 add back whitespace
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-08-21 18:37:26 -04:00
91735e9c1c Address PR feedback: simplify comments, remove extra assertion, and improve docstrings
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-21 22:26:09 +00:00
865b0bfafd Simplify operation implementation: remove mode switching, always use global counter
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-20 23:33:41 +00:00
2c81fbbb3c Refactor duplicate torch operation registrations to use shared module
Instead of changing library names (not scalable), create a shared test_operations.py module that:
- Provides a single "silly" library for all compilation tests
- Registers a unified attention operation that can handle both standard and counting modes
- Eliminates duplicate registration errors when running all tests together
- Maintains backward compatibility with existing test behavior

Addresses feedback to make the solution more scalable and maintainable.

Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-20 14:12:37 +00:00
47dcf0940f Complete fix for duplicate torch operations - all ops use unique library names
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-20 13:34:27 +00:00
9d6f0372e5 Fix duplicate torch operation registrations in tests/compile
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
2025-08-20 13:31:41 +00:00
e263eccfae Initial plan 2025-08-20 13:20:01 +00:00
5 changed files with 77 additions and 114 deletions

View File

@@ -6,7 +6,6 @@ are compiled and graph captured separately.
"""
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
@@ -14,11 +13,12 @@ from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
# This import automatically registers torch ops for testing (like silly.attention)
import tests.compile.testing_ops
BATCH_SIZE = 32
MLP_SIZE = 128
@@ -26,27 +26,6 @@ HIDDEN_SIZE = 1024
RANDOM_SEED = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
@support_torch_compile
class ParentModel(nn.Module):
@@ -277,9 +256,5 @@ def test_multi_graph_piecewise_compile_outputs_equal():
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
assert torch.allclose(outputs[0], outputs[1])
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
assert torch.equal(outputs[0], outputs[2])

View File

@@ -7,7 +7,6 @@ can exactly calculate the expected output and side effects.
import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
@@ -15,34 +14,10 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
global_counter = 0
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
global_counter += 1
print(f"{global_counter=}")
out.copy_(q)
out[0] += 1
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
# This import also automatically registers torch ops for testing (like silly.attention)
from tests.compile.testing_ops import (
get_global_counter, reset_global_counter
)
@@ -58,9 +33,8 @@ class SillyModel(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overall effect:
x += 1
x[0] += 2
Overall effect with unified attention implementation:
input [0., 0.] -> final output [19., 19.]
global_counter += 2
"""
x = x + 1
@@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
global global_counter
global_counter = 0
reset_global_counter()
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19., 19.]))

View File

@@ -14,38 +14,15 @@ from typing import Any, Optional
import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
# This import automatically registers torch ops for testing (like silly.attention)
import tests.compile.testing_ops
@dataclass

View File

@@ -2,44 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from torch.library import Library
# This import automatically registers torch ops for testing (like silly.attention)
import tests.compile.testing_ops
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
CUDAGraphMode, VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
BATCH_SIZE = 32
MLP_SIZE = 128
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
cudagraph_runtime_mode: CUDAGraphMode):

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Shared PyTorch custom operations for compilation tests.
Centralizes custom operation definitions to avoid duplicate registrations.
"""
import torch
from torch.library import Library
from vllm.utils import direct_register_custom_op
# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
silly_lib = Library("silly", "FRAGMENT")
# Global counter that counts the number of times attention is invoked
_global_counter = 0
def get_global_counter():
"""Get the current global counter value"""
return _global_counter
def reset_global_counter():
"""Reset the global counter to 0"""
global _global_counter
_global_counter = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
"""
Unified attention implementation that depends on all inputs and affects the output.
Always increments a global counter that tests can use or ignore.
"""
global _global_counter
# Always increment the global counter
_global_counter += 1
# Unified implementation that depends on all inputs
out.copy_(q + k + v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
"""Fake implementation for testing"""
return
# Register the unified attention operation
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
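For context, below is a minimal sketch (not part of the diff) of how a test could exercise the shared module once it is importable. The tensor sizes and the testing_ops alias are illustrative; it assumes vLLM is installed, the repository's tests/ package is on the import path, and a CUDA device is available, since the existing tests invoke the op on .cuda() tensors.

# Illustrative only -- not taken from the PR.
import torch

import tests.compile.testing_ops as testing_ops  # import registers silly.attention

testing_ops.reset_global_counter()

q = torch.ones(4, device="cuda")
k = torch.ones(4, device="cuda")
v = torch.ones(4, device="cuda")
out = torch.empty(4, device="cuda")

# The op mutates `out` in place to q + k + v and bumps the shared counter.
torch.ops.silly.attention(q, k, v, out)

assert torch.allclose(out, torch.full((4,), 3.0, device="cuda"))
assert testing_ops.get_global_counter() == 1

Because registration happens once at import time in this single module, any number of test files can import tests.compile.testing_ops without triggering the duplicate-registration errors this PR removes.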