Compare commits
13 Commits
dbo-cudagr
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
| afcb616e89 | |||
| 2b81d5fd2f | |||
| dd72729634 | |||
| 8393419f4a | |||
| 019b2fb6ca | |||
| 17d6532086 | |||
| b0974809c4 | |||
| 91735e9c1c | |||
| 865b0bfafd | |||
| 2c81fbbb3c | |||
| 47dcf0940f | |||
| 9d6f0372e5 | |||
| e263eccfae |
@ -6,7 +6,6 @@ are compiled and graph captured separately.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
@ -14,11 +13,12 @@ from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
# This import automatically registers torch ops for testing (like silly.attention)
|
||||
import tests.compile.testing_ops
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
@ -26,27 +26,6 @@ HIDDEN_SIZE = 1024
|
||||
RANDOM_SEED = 0
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
|
||||
@ -277,9 +256,5 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# Generally don't expect outputs with and without inductor
|
||||
# to be bitwise equivalent
|
||||
assert torch.allclose(outputs[0], outputs[1])
|
||||
|
||||
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
|
||||
assert torch.equal(outputs[0], outputs[2])
|
||||
|
||||
@ -7,7 +7,6 @@ can exactly calculate the expected output and side effects.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@ -15,34 +14,10 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
global global_counter
|
||||
global_counter += 1
|
||||
print(f"{global_counter=}")
|
||||
out.copy_(q)
|
||||
out[0] += 1
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
# This import also automatically registers torch ops for testing (like silly.attention)
|
||||
from tests.compile.testing_ops import (
|
||||
get_global_counter, reset_global_counter
|
||||
)
|
||||
|
||||
|
||||
@ -58,9 +33,8 @@ class SillyModel(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Overall effect:
|
||||
x += 1
|
||||
x[0] += 2
|
||||
Overall effect with unified attention implementation:
|
||||
input [0., 0.] -> final output [19., 19.]
|
||||
global_counter += 2
|
||||
"""
|
||||
x = x + 1
|
||||
@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
model(torch.randn(1).cuda())
|
||||
|
||||
input = torch.zeros(2).cuda()
|
||||
global global_counter
|
||||
global_counter = 0
|
||||
reset_global_counter()
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([19., 19.]))
|
||||
|
||||
@ -14,38 +14,15 @@ from typing import Any, Optional
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
# This import automatically registers torch ops for testing (like silly.attention)
|
||||
import tests.compile.testing_ops
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -2,44 +2,20 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
# This import automatically registers torch ops for testing (like silly.attention)
|
||||
import tests.compile.testing_ops
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
|
||||
62
tests/compile/testing_ops.py
Normal file
62
tests/compile/testing_ops.py
Normal file
@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared PyTorch custom operations for compilation tests.
|
||||
|
||||
Centralizes custom operation definitions to avoid duplicate registrations.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# Shared library for all compilation test operations
|
||||
# Using "silly" namespace to match existing test expectations
|
||||
silly_lib = Library("silly", "FRAGMENT")
|
||||
|
||||
|
||||
# Global counter that counts the number of times attention is invoked
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def get_global_counter():
|
||||
"""Get the current global counter value"""
|
||||
return _global_counter
|
||||
|
||||
|
||||
def reset_global_counter():
|
||||
"""Reset the global counter to 0"""
|
||||
global _global_counter
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Unified attention implementation that depends on all inputs and affects the output.
|
||||
Always increments a global counter that tests can use or ignore.
|
||||
"""
|
||||
global _global_counter
|
||||
|
||||
# Always increment the global counter
|
||||
_global_counter += 1
|
||||
|
||||
# Unified implementation that depends on all inputs
|
||||
out.copy_(q + k + v)
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""Fake implementation for testing"""
|
||||
return
|
||||
|
||||
|
||||
# Register the unified attention operation
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
Reference in New Issue
Block a user