Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2025-09-18 12:37:29 -07:00
parent c1d83f2bae
commit 9050087250
6 changed files with 49 additions and 33 deletions

View File

@@ -4,11 +4,10 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
def get_kv_cache_spec(
@@ -48,7 +47,7 @@ def init_attn_backend(
device: torch.device,
):
attn_backends: dict[str, AttentionBackend] = {}
attn_metadata_builders: dict[str, AttentionMetadataBuilder] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
@@ -56,15 +55,16 @@ def init_attn_backend(
any_layer_name = next(iter(layer_names))
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
)
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builders[layer_name] = attn_metadata_builder
attn_metadata_builders.append(attn_metadata_builder)
return attn_backends, attn_metadata_builders
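
In effect, this hunk turns attn_metadata_builders from a per-layer dict into one builder per KV-cache group, while backends stay keyed per layer. A minimal standalone sketch of that shape, using hypothetical stub classes rather than the real vLLM types:

# Stub stand-ins for illustration only; not the real vLLM classes.
class StubBuilder:
    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        self.layer_names = layer_names

class StubBackend:
    def get_builder_cls(self):
        return StubBuilder

def init_attn_backend_sketch(kv_cache_groups, layer_to_backend, vllm_config, device):
    attn_backends = {}            # layer_name -> backend (unchanged)
    attn_metadata_builders = []   # one builder per KV-cache group (was keyed per layer)
    for group in kv_cache_groups:
        layer_names = group["layer_names"]
        backend = layer_to_backend[next(iter(layer_names))]
        builder = backend.get_builder_cls()(
            group["kv_cache_spec"], layer_names, vllm_config, device)
        for name in layer_names:
            attn_backends[name] = backend
        attn_metadata_builders.append(builder)
    return attn_backends, attn_metadata_builders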
@@ -98,7 +98,7 @@ def _reshape_kv_cache(
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes)
@@ -110,7 +110,7 @@ def _reshape_kv_cache(
dtype = kv_cache_spec.dtype
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
@@ -129,5 +129,6 @@ def init_kv_cache(
device: torch.device,
):
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
attn_backends)
return kv_caches
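
The reshape path above permutes the logical KV-cache shape into the backend's preferred stride order and then permutes the view back with the inverse order. A self-contained sketch of that permute-and-invert trick, with made-up shapes:

import math
import torch

# Toy numbers only; the real shapes come from the backend and kv_cache_spec.
logical_shape = (2, 4, 16, 8)        # e.g. (num_blocks, block_size, num_heads, head_size)
stride_order = (1, 0, 2, 3)          # hypothetical backend-preferred physical order
physical_shape = tuple(logical_shape[i] for i in stride_order)
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

raw_tensor = torch.empty(math.prod(physical_shape))
kv_cache = raw_tensor.view(physical_shape).permute(inv_order)
assert tuple(kv_cache.shape) == logical_shape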

View File

@@ -54,6 +54,7 @@ class InputBatch:
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
# [num_reqs]
is_chunked_prefilling: np.ndarray

View File

@@ -10,20 +10,21 @@ import torch
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, is_pin_memory_available)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.attn_utils import get_kv_cache_spec, init_attn_backend, init_kv_cache
from vllm.v1.worker.utils import bind_kv_cache
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
init_attn_backend, init_kv_cache)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import DeviceMemoryProfiler, GiB_bytes
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
@@ -123,7 +124,11 @@ class GPUModelRunner:
self.device,
)
kv_caches = init_kv_cache(self.kv_cache_config, self.attn_backends, self.device)
kv_caches = init_kv_cache(
self.kv_cache_config,
self.attn_backends,
self.device,
)
self.kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
@@ -134,12 +139,13 @@ class GPUModelRunner:
def _dummy_run(self, num_tokens: int, *args, **kwargs) -> None:
return None, None
def _dummy_sampler_run(self, hidden_states: torch.Tensor, *args, **kwargs) -> None:
def _dummy_sampler_run(self, hidden_states: torch.Tensor, *args,
**kwargs) -> None:
return None
def update_states(self, scheduler_output: SchedulerOutput) -> None:
for req_id in scheduler_output.preempted_req_ids:
self.req_states.remove_request(req_id)
# for req_id in scheduler_output.preempted_req_ids:
# self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
@@ -207,6 +213,9 @@ class GPUModelRunner:
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
# TODO(woosuk): Support CUDA graphs.
num_tokens_after_padding = num_tokens
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
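
For now the padded token count simply equals num_tokens, per the TODO. If CUDA graphs are later supported here, the count would presumably be rounded up to a captured batch size; a purely hypothetical sketch of that rounding, not part of this commit:

def pad_num_tokens(num_tokens: int, captured_sizes=(1, 2, 4, 8, 16, 32, 64, 128)) -> int:
    # Round up to the smallest captured CUDA-graph size; beyond the largest
    # graph, fall back to eager execution with no padding.
    for size in captured_sizes:
        if num_tokens <= size:
            return size
    return num_tokens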
@@ -251,8 +260,8 @@ class GPUModelRunner:
num_computed_tokens_np = self.req_states.num_computed_tokens[
idx_mapping_np]
num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
num_tokens = self.req_states.num_tokens[idx_mapping_np]
is_chunked_prefilling = seq_lens_np < num_tokens
is_chunked_prefilling = (seq_lens_np
< self.req_states.num_tokens[idx_mapping_np])
# Slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
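
The rewritten check above marks a request as still chunked-prefilling when the sequence length visible after this step is smaller than the request's total token count. A standalone NumPy illustration with made-up numbers:

import numpy as np

seq_lens = np.array([128, 512, 64], dtype=np.int32)      # tokens visible after this step
total_tokens = np.array([512, 512, 64], dtype=np.int32)  # full token count per request
is_chunked_prefilling = seq_lens < total_tokens           # -> [ True, False, False]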
@@ -285,12 +294,12 @@ class GPUModelRunner:
)
attn_metadata_builder = self.attn_metadata_builders[i]
attn_metadata = attn_metadata_builder.build(
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = attn_metadata
attn_metadata[layer_name] = metadata
return InputBatch(
req_ids=req_ids,
@@ -299,9 +308,10 @@ class GPUModelRunner:
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
is_chunked_prefilling=is_chunked_prefilling,
input_ids=input_ids,
positions=positions,
input_ids=input_ids.gpu,
positions=positions.gpu,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
)
@@ -333,7 +343,8 @@ class GPUModelRunner:
return None
num_prompt_tokens_scheduled = ...
if not np.any((num_prompt_tokens_scheduled > 0) & needs_prompt_logprobs):
if not np.any((num_prompt_tokens_scheduled > 0)
& needs_prompt_logprobs):
# The request already computed prompt logprobs.
return None

View File

@@ -123,7 +123,7 @@ def _apply_temp_kernel(
if temp < EPSILON:
# Greedy sampling. Don't apply temperature.
# NOTE(woosuk): In this case, we assume that its logprobs are not used.
temp = tl.ones([1], dtype=tl.float32)
temp = 1.0
offset = tl.arange(0, BLOCK_SIZE)
block = block_idx * BLOCK_SIZE + offset
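
The kernel now assigns a plain scalar 1.0 instead of building a one-element tl.ones tensor for the greedy case. A plain-NumPy analogue of the behavior (not the Triton kernel itself): dividing by 1.0 leaves the logits untouched, which is all greedy argmax needs.

import numpy as np

EPSILON = 1e-5

def apply_temperature(logits: np.ndarray, temp: float) -> np.ndarray:
    if temp < EPSILON:
        temp = 1.0  # greedy sampling: skip temperature scaling
    return logits / temp

logits = np.array([2.0, -1.0, 0.5], dtype=np.float32)
assert np.argmax(apply_temperature(logits, 0.0)) == np.argmax(logits)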

View File

@@ -100,8 +100,8 @@ class RequestState:
top_k = self.vocab_size
self.top_k[req_idx] = top_k
if sampling_params.num_logprobs is not None:
num_logprobs = sampling_params.num_logprobs
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
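
The field read here switches from sampling_params.num_logprobs to sampling_params.logprobs, keeping -1 as the "not requested" sentinel. A tiny illustration of that convention (hypothetical helper, not vLLM code):

from typing import Optional

def encode_num_logprobs(requested: Optional[int]) -> int:
    # -1 means the request did not ask for logprobs.
    return requested if requested is not None else -1

assert encode_num_logprobs(5) == 5
assert encode_num_logprobs(None) == -1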

View File

@@ -335,7 +335,9 @@ class Worker(WorkerBase):
self.model_runner._dummy_run(size,
skip_eplb=True,
remove_lora=False)
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
if self.model_runner.lora_config is not None:
self.model_runner.maybe_remove_all_loras(
self.model_runner.lora_config)
# Warmup and tune the kernels used during model execution before
# cuda graph capture.
@@ -429,6 +431,9 @@ class Worker(WorkerBase):
self,
scheduler_output: "SchedulerOutput",
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
if len(get_pp_group().ranks) == 1:
return self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -447,8 +452,6 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
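
The worker changes in this file add an early exit when the pipeline-parallel group has only one rank, since there are no intermediate tensors to receive or forward in that case. A simplified, hypothetical sketch of the control flow (names are stand-ins, not the real Worker API):

def execute_model_sketch(pp_ranks, model_runner, scheduler_output):
    if len(pp_ranks) == 1:
        # Single pipeline stage: run the model directly, no stage-to-stage I/O.
        return model_runner.execute_model(scheduler_output)
    # Multi-stage path (omitted): receive intermediate tensors from the previous
    # stage, run this stage, then either return the final output on the last rank
    # or forward the IntermediateTensors to the next stage.
    raise NotImplementedError("multi-rank path not sketched here")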