[Misc] Add indirection layer for custom ops (#3913)

This commit is contained in:
Kunshang Ji
2024-04-11 03:26:07 +00:00
committed by GitHub
parent e42df7227d
commit e9da5a40c6
14 changed files with 224 additions and 32 deletions

View File

@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm._C import cache_ops, ops
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@ -237,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8(key_cache, dequantized_key_cache)
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8(value_cache, dequantized_value_cache)
ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)

View File

@ -4,7 +4,7 @@ from typing import Tuple
import pytest
import torch
from vllm._C import cache_ops
from vllm import _custom_ops as ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@ -80,7 +80,7 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
ops.copy_blocks(key_caches, value_caches, block_mapping)
# Run the reference implementation.
for src, dsts in block_mapping.items():
@ -145,9 +145,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, cloned_key_cache)
ops.convert_fp8(key_cache, cloned_key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, cloned_value_cache)
ops.convert_fp8(value_cache, cloned_value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
@ -156,14 +156,14 @@ def test_reshape_and_cache(
kv_scale = 1.0
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, kv_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, result_key_cache)
ops.convert_fp8(key_cache, result_key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, result_value_cache)
ops.convert_fp8(value_cache, result_value_cache)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@ -251,9 +251,8 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
for src, dst in block_mapping.items():
assert torch.allclose(src_key_caches_clone[src].cpu(),
@ -291,9 +290,9 @@ def test_fp8_conversion(
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
cache_ops.convert_fp8(cache, cache_fp8)
ops.convert_fp8(cache, cache_fp8)
converted_cache = torch.empty_like(cache)
cache_ops.convert_fp8(cache_fp8, converted_cache)
ops.convert_fp8(cache_fp8, converted_cache)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)