[Kernel] Have rotary embeddings support tensors (#18046)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-05-14 18:43:55 -04:00
committed by GitHub
parent 749f792553
commit d93c976a0d
4 changed files with 59 additions and 31 deletions

View File

@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
return (batch_size, seq_len, num_heads * head_size)
# For testing sliced tensors
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size + 64)
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size)
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
TENSORS_SHAPES_FN = [
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@ -79,6 +87,10 @@ def test_rotary_embedding(
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)

View File

@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contingous", [True, False])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len, use_key):
seq_len, use_key, head_stride_is_contingous):
batch_size = 1
base = 10000
num_heads = 7
@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=device)
head_stride = head_size + (64 if head_stride_is_contingous else 0)
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
num_heads,
head_stride,
dtype=torch.float32,
device=device)
key = torch.randn_like(query) if use_key else None
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,
device=device,
dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets)
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
if head_stride_is_contingous:
rotary_embedding_opcheck(
rot, positions, query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None)