Compare commits


1 Commit

Author      SHA1          Message                             Date
            bfff9bcd1d    [V1] TPU - Remove self.kv_caches    2025-03-05 20:42:05 +00:00
4 changed files with 90 additions and 44 deletions

View File

@@ -10,10 +10,14 @@ prompts = [
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams()  #temperature=0.8, top_p=0.95)
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+# llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_seqs=16,
+          max_model_len=128,
+          enforce_eager=True)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -21,4 +25,4 @@ outputs = llm.generate(prompts, sampling_params)
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@@ -12,6 +12,7 @@ import torch.distributed as dist
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -33,14 +34,16 @@ class DPMetadata:

 @dataclass
 class ForwardContext:
-    # copy from vllm_config.compilation_config.static_forward_context
+    # Copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
+    # TODO: Extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
-    # TODO: remove after making all virtual_engines share the same kv cache
+    # TODO: Remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    # set dynamically for each forward pass
+    # Set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    # Whether this is a profile run (before KV cache init)
+    is_profile_run: bool = False


 _forward_context: Optional[ForwardContext] = None
@@ -58,7 +61,8 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: int = 0):
+                        num_tokens: int = 0,
+                        is_profile_run: bool = False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -93,12 +97,15 @@ def set_forward_context(attn_metadata: Any,
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
         no_compile_layers=vllm_config.compilation_config.
         static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata)
+        dp_metadata=dp_metadata,
+        is_profile_run=is_profile_run)

     try:
         yield
     finally:
@@ -111,10 +118,17 @@ def set_forward_context(attn_metadata: Any,
         else:
             # for v1 attention backends
             batchsize = attn_metadata.num_input_tokens
         # we use synchronous scheduling right now,
         # adding a sync point here should not affect
         # scheduling of the next batch
-        torch.cuda.synchronize()
+        if current_platform.is_tpu():
+            import torch_xla.core.xla_model as xm
+            xm.mark_step()
+            xm.wait_device_ops()
+        else:
+            torch.cuda.synchronize()
         now = time.perf_counter()
         # time measurement is in milliseconds
         batchsize_forward_time[batchsize].append(
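
To make the new flag concrete: below is a minimal sketch of how a caller is expected to drive it. The run_profile_pass helper and its arguments are invented for illustration; only set_forward_context, get_forward_context, and is_profile_run come from the diff above.

    from vllm.forward_context import (get_forward_context,
                                      set_forward_context)

    def run_profile_pass(model, input_ids, positions, vllm_config,
                         attn_metadata):
        # Tag this pass as a profile run; the flag lives on the
        # ForwardContext dataclass for the duration of the with-block.
        with set_forward_context(attn_metadata,
                                 vllm_config,
                                 virtual_engine=0,
                                 is_profile_run=True):
            # Code deep inside the model (e.g. ModelWrapperV1.forward in
            # the TPU runner below) reads the flag back without any
            # signature change.
            assert get_forward_context().is_profile_run
            return model(input_ids=input_ids, positions=positions)

This is the mechanism that lets the TPU runner drop kv_caches from the model's forward signature: profiling is now detected from the context instead of from an empty cache tensor.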

View File

@@ -30,7 +30,6 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
@@ -104,9 +103,6 @@ class TPUModelRunner:
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size

-        # Lazy initialization
-        # self.model: nn.Module  # Set after load_model
-        self.kv_caches: list[torch.Tensor] = []

         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -582,7 +578,6 @@ class TPUModelRunner:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
         hidden_states = hidden_states[:total_num_scheduled_tokens]
@@ -680,8 +675,8 @@ class TPUModelRunner:
     def _dummy_run(
         self,
-        kv_caches,
         num_tokens: int,
+        is_profile_run: bool,
     ) -> None:
         if self.is_multimodal_model:
             input_ids = None
@@ -728,15 +723,28 @@ class TPUModelRunner:
         torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
         torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 0,
+                                 is_profile_run=is_profile_run):
             assert self.model is not None
             self.model(
                 input_ids=input_ids,
                 positions=position_ids,
-                kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )

+    # This is used before KV cache init
+    def profile_run(self, num_tokens: int) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=True)
+
+    # This is used after KV cache init
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
+
     def capture_model(self) -> None:
         """Compile the model."""
@@ -745,7 +753,7 @@ class TPUModelRunner:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self._dummy_run(self.kv_caches, num_tokens)
+            self.dummy_run(num_tokens)
             logger.info(" -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
@@ -769,6 +777,7 @@ class TPUModelRunner:
         kv_caches: dict[str, torch.Tensor] = {}
+        kv_cache_shape_prev = None

         for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % layer_spec.page_size_bytes == 0
@@ -779,6 +788,12 @@ class TPUModelRunner:
                     layer_spec.head_size)
                 dtype = layer_spec.dtype

+                # Ensure all "kv_cache_shape" are the same across the model
+                if kv_cache_shape_prev is None:
+                    kv_cache_shape_prev = kv_cache_shape
+                else:
+                    assert kv_cache_shape == kv_cache_shape_prev
+
                 tpu_k_cache = torch.zeros(kv_cache_shape,
                                           dtype=dtype,
                                           device=self.device)
@@ -788,10 +803,16 @@ class TPUModelRunner:
             else:
                 raise NotImplementedError

-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+        # ModelWrapperV1 needs to know the KV cache shape
+        self.model.set_kv_cache_shape(kv_cache_shape_prev)
+
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]


 class ModelWrapperV1(nn.Module):
@@ -799,12 +820,15 @@ class ModelWrapperV1(nn.Module):
     def __init__(self, model: nn.Module):
         super().__init__()
         self.model = model
+        self.kv_cache_shape = None
+
+    def set_kv_cache_shape(self, kv_cache_shape):
+        self.kv_cache_shape = kv_cache_shape

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
@@ -817,16 +841,20 @@ class ModelWrapperV1(nn.Module):
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
+        forward_context = get_forward_context()
+        attn_metadata = forward_context.attn_metadata

         # index_copy_(slot_mapping) only works when the inserted dimension
         # is 0. However, the KV cache in the Pallas backend has the shape
         # [num_kv_heads, num_blocks, block_size, head_size]. To make it
         # work, we need to flatten the first three dimensions and modify
         # the slot_mapping accordingly.
-        # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-        num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
+        #
+        # Note: We skip this step during the first profiling run (before
+        # KV cache init).
+        if not forward_context.is_profile_run:
+            assert self.kv_cache_shape  # Ensure it is initialized
+            num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape
             slot_mapping = attn_metadata.slot_mapping
             slot_mapping = slot_mapping.flatten()
             head_indicies = torch.arange(0,
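
The index_copy_ comment above is easier to follow with concrete numbers. Below is a standalone sketch of the flattening arithmetic it describes, using made-up toy shapes rather than vLLM's real ones:

    import torch

    num_kv_heads, num_blocks, block_size, head_size = 2, 4, 8, 16
    kv_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

    # index_copy_ only scatters along dim 0, so collapse the first three
    # dims into one axis of length num_kv_heads * num_blocks * block_size.
    flat_cache = kv_cache.view(-1, head_size)  # shares storage with kv_cache

    # A slot is block_idx * block_size + block_offset. To write one token
    # into every head, offset the slot by each head's stride along the
    # flattened axis.
    slot = torch.tensor([5])  # one token at block 0, offset 5
    head_offsets = torch.arange(num_kv_heads) * num_blocks * block_size
    flat_indices = (slot.view(-1, 1) + head_offsets.view(1, -1)).flatten()

    new_kv = torch.ones(num_kv_heads, head_size)  # the token's K (or V) rows
    flat_cache.index_copy_(0, flat_indices, new_kv)
    assert kv_cache[:, 0, 5].eq(1).all()  # both heads now hold the token

After this change, the shape feeding that arithmetic comes from self.kv_cache_shape, set once at KV cache initialization, rather than from peeking at kv_caches[0][0]; that is exactly what allows the kv_caches parameter to disappear from forward.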

View File

@@ -21,7 +21,6 @@ from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner

 logger = init_logger(__name__)
@@ -128,18 +127,19 @@ class TPUWorker:
         else:
             raise NotImplementedError

-        runner_kv_caches: list[torch.Tensor] = []
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            runner_kv_caches)
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]

-        self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+        self.model_runner.profile_run(
+            num_tokens=self.scheduler_config.max_num_batched_tokens)

         # Synchronize before measuring the memory usage.
+        xm.mark_step()
         xm.wait_device_ops()

         # Get the maximum amount of memory used by the model weights and
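
One last note on the comment that appears in both binding loops, "Use list because of v0 PP virtual engine": attention layers look up their cache by virtual engine index, so the v0 pipeline-parallel path (one cache per virtual engine) and this v1 path (a single-element list) can share the same lookup. A simplified sketch of the consuming side, adapted from vLLM's attention layer and not part of this diff:

    # Inside an attention layer's forward pass (simplified):
    forward_context = get_forward_context()
    # v1 binds [kv_cache] and virtual_engine defaults to 0, so this picks
    # the single entry; v0 PP selects the cache for its virtual engine.
    kv_cache = self.kv_cache[forward_context.virtual_engine]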