149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
import functools
|
|
import time
|
|
from typing import Tuple
|
|
|
|
import chex
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
# Sentinel slot id marking a padded position in `slot_mapping`. The per-sequence
# writer (`_write_seq_to_kv_cache`) stops writing as soon as it reads this value.
_PAD_SLOT_ID = -1
|
|
|
|
|
|
@jax.jit
def write_to_kv_cache1(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    """Scatter every token's key/value into the caches in one indexed update.

    Args:
        key: New keys, one vector per (batch, position, head).
        value: New values, laid out like `key`.
        k_cache: Key cache indexed as [head, slot, feature].
        v_cache: Value cache indexed as [head, slot, feature].
        slot_mapping: Destination slot along the cache's second axis for each
            (batch, position) token.

    Returns:
        The updated `(k_cache, v_cache)` pair.
    """
    *_, num_heads, head_size = key.shape
    # Flatten (batch, seq) into a single token axis to line up with the
    # token-major layout produced below.
    flat_slots = slot_mapping.reshape(-1)

    def _heads_first(tokens: jax.Array) -> jax.Array:
        # [batch, seq, heads, size] -> [tokens, heads, size] -> [heads, tokens, size]
        return tokens.reshape(-1, num_heads, head_size).swapaxes(0, 1)

    updated_k = k_cache.at[:, flat_slots, :].set(_heads_first(key))
    updated_v = v_cache.at[:, flat_slots, :].set(_heads_first(value))
    return updated_k, updated_v
|
|
|
|
|
|
@functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
    """Write keys/values into the caches one batch row at a time.

    Runs a `jax.lax.while_loop` over the leading axis of `slot_mapping` and
    hands each row to `_write_seq_to_kv_cache`. The `k_cache` and `v_cache`
    arguments are donated to the jitted computation.

    Args:
        key: New keys for every (batch, position, head).
        value: New values, laid out like `key`.
        k_cache: Key cache to update.
        v_cache: Value cache to update.
        slot_mapping: Destination slot for each (batch, position) token.

    Returns:
        The updated `(k_cache, v_cache)` pair.
    """
    num_seqs = slot_mapping.shape[0]

    def _has_remaining(state: _IteratorState):
        return state.idx < num_seqs

    def _write_one_sequence(state: _IteratorState):
        seq = state.idx
        new_k, new_v = _write_seq_to_kv_cache(
            key[seq],
            value[seq],
            state.k_cache,
            state.v_cache,
            slot_mapping[seq],
        )
        # Build the next carry explicitly rather than mutating the old one.
        return _IteratorState(idx=seq + 1, k_cache=new_k, v_cache=new_v)

    final_state = jax.lax.while_loop(
        _has_remaining,
        _write_one_sequence,
        _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache),
    )
    return final_state.k_cache, final_state.v_cache
|
|
|
|
|
|
@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache(
    key: jax.Array,  # [seq_len, num_heads, head_size]
    value: jax.Array,  # [seq_len, num_heads, head_size]
    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
    slot_mapping: jax.Array,  # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
    """Write one sequence's keys/values into the caches, token by token.

    Tokens are processed in order with a `jax.lax.while_loop`. The loop ends
    at the end of the sequence or at the first entry of `slot_mapping` equal
    to `_PAD_SLOT_ID`, whichever comes first; tokens after that first padded
    entry are not written.

    Args:
        key: Keys for the sequence.
        value: Values for the sequence.
        k_cache: Key cache to update.
        v_cache: Value cache to update.
        slot_mapping: Destination slot for each token.

    Returns:
        The updated `(k_cache, v_cache)` pair.
    """
    num_tokens = slot_mapping.shape[0]
    num_heads, _, head_size = k_cache.shape
    # Give each token a singleton slot axis so a single token's slice has the
    # same rank as the cache it is written into.
    per_token_shape = (num_tokens, num_heads, 1, head_size)
    key = key.reshape(per_token_shape)
    value = value.reshape(per_token_shape)

    def _should_continue(state: _IteratorState):
        in_range = state.idx < num_tokens
        not_padding = slot_mapping[state.idx] != _PAD_SLOT_ID
        return jnp.logical_and(in_range, not_padding)

    def _write_token(state: _IteratorState):
        tok = state.idx
        # Start offsets: all heads, the token's slot, full head vector.
        start = (0, slot_mapping[tok], 0)
        return _IteratorState(
            idx=tok + 1,
            k_cache=jax.lax.dynamic_update_slice(state.k_cache, key[tok], start),
            v_cache=jax.lax.dynamic_update_slice(state.v_cache, value[tok], start),
        )

    final_state = jax.lax.while_loop(
        _should_continue,
        _write_token,
        _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache),
    )
    return final_state.k_cache, final_state.v_cache
|
|
|
|
|
|
@chex.dataclass
class _IteratorState:
    """Loop carry passed through the `jax.lax.while_loop` calls in this file."""

    # Position of the next element to process; the loops here start it at 0
    # and advance it by one per iteration.
    idx: jnp.int32
    # The cache-writing functions in this file annotate both caches with this
    # 3-D layout and update them with three start indices.
    k_cache: jnp.ndarray  # [num_heads, num_blocks * block_size, head_size]
    v_cache: jnp.ndarray  # [num_heads, num_blocks * block_size, head_size]
|
|
|
|
|
|
def benchmark_write_to_kv_cache(
    batch_size: int,
    seq_len: int,
    num_kv_heads: int,
    head_size: int,
    num_blocks: int,
    block_size: int,
    version: int = 1,
):
    """Time one of the KV-cache write implementations and print the result.

    Random bfloat16 inputs are generated, the selected function is called once
    so compilation is excluded from the measurement, and then it is called 100
    times. The mean wall-clock time per call is printed in milliseconds.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        num_kv_heads: Number of key/value heads.
        head_size: Size of each head vector.
        num_blocks: Number of cache blocks.
        block_size: Number of slots per cache block.
        version: 1 selects `write_to_kv_cache1`, 2 selects `write_to_kv_cache2`.

    Raises:
        ValueError: If `version` is neither 1 nor 2.
    """
    if version == 1:
        f = write_to_kv_cache1
    elif version == 2:
        f = write_to_kv_cache2
    else:
        raise ValueError(f"Invalid version: {version}")

    # Use an independent subkey per tensor: reusing a single PRNG key would
    # make `key`/`value` (and `k_cache`/`v_cache`) bit-identical.
    key_rng, value_rng, k_cache_rng, v_cache_rng, slot_rng = jax.random.split(
        jax.random.PRNGKey(0), 5
    )
    num_slots = num_blocks * block_size
    token_shape = (batch_size, seq_len, num_kv_heads, head_size)
    cache_shape = (num_kv_heads, num_slots, head_size)

    key = jax.random.normal(key_rng, token_shape, dtype=jnp.bfloat16)
    value = jax.random.normal(value_rng, token_shape, dtype=jnp.bfloat16)
    k_cache = jax.random.normal(k_cache_rng, cache_shape, dtype=jnp.bfloat16)
    v_cache = jax.random.normal(v_cache_rng, cache_shape, dtype=jnp.bfloat16)
    slot_mapping = jax.random.randint(
        slot_rng, (batch_size, seq_len), 0, num_slots, dtype=jnp.int32
    )

    # Warm-up call so JIT compilation is not part of the timed region.
    k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
    k_cache.block_until_ready()
    v_cache.block_until_ready()

    num_iters = 100
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(num_iters):
        # Feed the outputs back in: version 2 donates its cache arguments, so
        # the previous buffers must not be reused after the call.
        k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
    # Wait for both outputs so all dispatched work is inside the timed region.
    k_cache.block_until_ready()
    v_cache.block_until_ready()
    end = time.perf_counter()
    print(f"Time taken: {(end - start) * 1000 / num_iters:.2f} ms")
|
|
|
|
|
|
if __name__ == "__main__":
    # Sweep the cache size while holding every other dimension fixed.
    for num_blocks in (16, 256, 512, 1024, 2048, 8192, 16384):
        print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
        benchmark_write_to_kv_cache(
            batch_size=16,
            seq_len=256,
            num_kv_heads=16,
            head_size=256,
            num_blocks=num_blocks,
            block_size=16,
            version=1,
        )
|