[Kernel] Fullgraph and opcheck tests (#8479)
This commit is contained in:
@ -2,12 +2,14 @@
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from numbers import Number
|
||||
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
|
||||
Union)
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch._prims_common import TensorLikeType
|
||||
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
||||
@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
output_under_test.view_as(ideal_output))
|
||||
|
||||
|
||||
# Copied/modified from torch._refs.__init__.py
|
||||
def fp8_allclose(
|
||||
a: TensorLikeType,
|
||||
b: TensorLikeType,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Reference implementation of torch.allclose
|
||||
"""
|
||||
torch._refs._check_close_args(name="torch.allclose",
|
||||
a=a,
|
||||
b=b,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
return bool(
|
||||
torch.all(
|
||||
torch.isclose(a.double(),
|
||||
b.double(),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=equal_nan)).item())
|
||||
|
||||
|
||||
# A special version of op check that has a restricted default set of test_utils
|
||||
# and a patched version of allclose that supports fp8 types.
|
||||
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
|
||||
torch._library.custom_ops.CustomOpDef],
|
||||
args: Tuple[Any, ...],
|
||||
@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
|
||||
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
|
||||
raise_exception: bool = True,
|
||||
cond: bool = True) -> Dict[str, str]:
|
||||
return torch.library.opcheck(
|
||||
op,
|
||||
args,
|
||||
kwargs,
|
||||
test_utils=test_utils,
|
||||
raise_exception=raise_exception) if cond else {}
|
||||
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
|
||||
return torch.library.opcheck(
|
||||
op,
|
||||
args,
|
||||
kwargs,
|
||||
test_utils=test_utils,
|
||||
raise_exception=raise_exception) if cond else {}
|
||||
|
||||
Reference in New Issue
Block a user