Compare commits: main...woosuk/mod (115 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| 1c5c866559 | |||
| 5c8049d990 | |||
| 5666a25efb | |||
| 09e4b2f6eb | |||
| 110770170f | |||
| 866eef50ca | |||
| ad2cf805ad | |||
| 704def253c | |||
| 42f99150c1 | |||
| 17c2c106b1 | |||
| 72f0a71939 | |||
| fe5472dc03 | |||
| bc73f674bb | |||
| 631b5b47c1 | |||
| 42ffdd9179 | |||
| 8aee6e97e6 | |||
| 913b8e9569 | |||
| 158a46888e | |||
| 98ef239486 | |||
| a66aa37f40 | |||
| 6f038fc4fb | |||
| 010e39ec7d | |||
| 396bbe67d3 | |||
| c7f3e84b34 | |||
| a8e7071924 | |||
| 4be2c66e37 | |||
| d30c0d50a6 | |||
| 9c75d896a8 | |||
| 37478c18cf | |||
| 33672774f5 | |||
| 0d3de9e082 | |||
| b405d78c07 | |||
| 8af87986aa | |||
| af65838d1f | |||
| 52ca2f517a | |||
| 8deedfa42b | |||
| b9c74487d2 | |||
| 31619ff412 | |||
| d2be62378b | |||
| 86dade710d | |||
| efda08481b | |||
| 82da219ff9 | |||
| 323a05b3c5 | |||
| a98eff0762 | |||
| 67d8c0c21b | |||
| 2bb2cb13f4 | |||
| e171e5bb67 | |||
| 8407fa02ed | |||
| 82e591f7eb | |||
| 330058f9b8 | |||
| aabfaa08cf | |||
| bc6463ac97 | |||
| a4962833f9 | |||
| 3f50030cc8 | |||
| cbdb47dc01 | |||
| 92f337faeb | |||
| 9050087250 | |||
| c1d83f2bae | |||
| 91510260b2 | |||
| c320a33c59 | |||
| 83d11373a4 | |||
| dfc84b11a9 | |||
| 9f2becd3e6 | |||
| e107680d8a | |||
| f1981db101 | |||
| 69b17891a3 | |||
| 67852c1036 | |||
| 8b3c13c485 | |||
| 9a6fcca030 | |||
| 633f9f006d | |||
| eb3742c72a | |||
| e47bb9970b | |||
| 5c133fc860 | |||
| caf963f2e9 | |||
| 9314a83b56 | |||
| 7a50a54390 | |||
| 787e59629c | |||
| 5f95309a6d | |||
| 286eeb91e8 | |||
| 6283995a6c | |||
| 0c56069c7e | |||
| 8e6cb9aa4a | |||
| ead95fe5dc | |||
| 23eae07ea5 | |||
| b16e2d9602 | |||
| 4c2a337e67 | |||
| cc340e26af | |||
| 01bf16ede4 | |||
| af7b6c5dd4 | |||
| 62d23b3006 | |||
| ba1a58f51b | |||
| 22771e5d83 | |||
| c11d1e6781 | |||
| e696f78e05 | |||
| efcb786d52 | |||
| 9ee9d0e274 | |||
| 405578121c | |||
| 19c0dfc469 | |||
| e451045a66 | |||
| efba25e21a | |||
| b21393cd98 | |||
| d6d719fb24 | |||
| e570b0a4de | |||
| a851aaa0fc | |||
| b1d52734f7 | |||
| 65f93694be | |||
| 7b4b72e551 | |||
| da9cd26c78 | |||
| a1e3745150 | |||
| 48bca9a109 | |||
| 64c8cced18 | |||
| 79e5eb3643 | |||
| c472982746 | |||
| 699bd7928e | |||
| 33a3a26ca5 |
@@ -16,7 +16,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 def main():
     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m")
+    llm = LLM(model="facebook/opt-125m", compilation_config={"level": 0, "cudagraph_mode": "full_decode_only"})
     # Generate texts from the prompts.
     # The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
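
For context, here is a minimal end-to-end sketch of how the modified example would be run, assuming the `compilation_config` keys shown above are accepted as-is by `LLM()`; the prompt list and the output loop are illustrative, not part of this diff.

```python
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]  # illustrative prompt
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Disable torch.compile (level 0) and capture CUDA graphs only for
# pure-decode batches, mirroring the config used in the diff above.
llm = LLM(
    model="facebook/opt-125m",
    compilation_config={"level": 0, "cudagraph_mode": "full_decode_only"},
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)
```
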
@@ -10,6 +10,7 @@ torchaudio==2.9.0
 # These must be updated alongside torch
 torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 # Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
-xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
+# xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.4.1
+apache-tvm-ffi==0.1.0b15
@@ -46,6 +46,8 @@ def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
 
 
 def kernel_warmup(worker: "Worker"):
+    return
+
     # Deep GEMM warmup
     do_deep_gemm_warmup = (
         envs.VLLM_USE_DEEP_GEMM
@@ -412,6 +412,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )
         return self._workspace_buffer
 
+    def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
+        self._workspace_buffer = workspace_buffer
+
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
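
Below is a self-contained sketch (not from the diff; the dummy builder class and buffer size are hypothetical) of the sharing pattern this setter enables: the first FlashInfer builder allocates a workspace buffer, and later builders reuse it via `set_workspace_buffer()` instead of allocating their own, which is the same pattern used by `init_attn_backend` in the new `attn_utils.py` further down.

```python
import torch


class _DummyBuilder:
    # Stand-in for FlashInferMetadataBuilder; only the workspace logic is shown.
    def __init__(self):
        self._workspace_buffer = None

    def _get_workspace_buffer(self) -> torch.Tensor:
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(1024, dtype=torch.uint8)
        return self._workspace_buffer

    def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
        self._workspace_buffer = workspace_buffer


builders = [_DummyBuilder() for _ in range(3)]
shared = None
for b in builders:
    if shared is None:
        shared = b._get_workspace_buffer()  # first builder allocates
    else:
        b.set_workspace_buffer(shared)      # later builders share it

assert all(b._get_workspace_buffer() is shared for b in builders)
```
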
@@ -34,6 +34,7 @@ else:
 class NewRequestData:
     req_id: str
     prompt_token_ids: list[int] | None
+    prefill_token_ids: list[int] | None
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: SamplingParams | None
     pooling_params: PoolingParams | None
@@ -51,6 +52,7 @@ class NewRequestData:
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
+            prefill_token_ids=request._all_token_ids,
             mm_features=request.mm_features,
             sampling_params=request.sampling_params,
             pooling_params=request.pooling_params,
@@ -173,6 +175,7 @@ class SchedulerOutput:
     # This can be used for cascade attention.
     num_common_prefix_blocks: list[int]
 
+    preempted_req_ids: set[str]
     # Request IDs that are finished in between the previous and the current
     # steps. This is used to notify the workers about the finished requests
     # so that they can free the cached states for those requests.
@@ -606,6 +606,9 @@ class Scheduler(SchedulerInterface):
         )
 
         # Construct the scheduler output.
+        scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
+        scheduled_resumed_reqs = []
+
         new_reqs_data = [
             NewRequestData.from_request(
                 req, req_to_new_blocks[req.request_id].get_block_ids()
@@ -635,6 +638,7 @@
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
             scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
+            preempted_req_ids={req.request_id for req in preempted_reqs},
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
@@ -720,14 +724,6 @@
                 req.num_computed_tokens : req.num_computed_tokens + num_tokens
             ]
             new_token_ids.append(token_ids)
-            scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
-            if idx >= num_running_reqs:
-                assert not scheduled_in_prev_step
-                resumed_req_ids.add(req_id)
-            if not scheduled_in_prev_step:
-                all_token_ids[req_id] = req.all_token_ids[
-                    : req.num_computed_tokens + num_tokens
-                ]
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
@@ -933,7 +929,8 @@
         # to avoid expensive operations inside the loop.
         stopped_running_reqs: set[Request] = set()
         stopped_preempted_reqs: set[Request] = set()
-        for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
+        for req_index, req_id in enumerate(model_runner_output.req_ids):
+            num_tokens_scheduled = num_scheduled_tokens[req_id]
             assert num_tokens_scheduled > 0
             if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
                 # Skip requests that were recovered from KV load failure
@@ -3,6 +3,7 @@
 
 import copy
 from dataclasses import dataclass, fields
+from functools import cached_property
 from math import prod
 
 import torch
@@ -395,3 +396,10 @@ class KVCacheConfig:
     see `_get_kv_cache_config_uniform_page_size` for more details.
     """
     kv_cache_groups: list[KVCacheGroupSpec]
+
+    @cached_property
+    def block_sizes(self) -> list[int]:
+        return [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in self.kv_cache_groups
+        ]
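
As an aside, the new property follows the standard `functools.cached_property` pattern; a minimal self-contained sketch with simplified stand-in classes (not the real vLLM types) shows the behavior: the list of per-group block sizes is computed once from `kv_cache_groups` and then cached on the config object.

```python
from dataclasses import dataclass
from functools import cached_property


@dataclass(frozen=True)
class _Spec:
    block_size: int


@dataclass(frozen=True)
class _Group:
    kv_cache_spec: _Spec


class _KVCacheConfig:
    def __init__(self, kv_cache_groups: list[_Group]):
        self.kv_cache_groups = kv_cache_groups

    @cached_property
    def block_sizes(self) -> list[int]:
        # One entry per KV cache group; evaluated once, then cached.
        return [g.kv_cache_spec.block_size for g in self.kv_cache_groups]


cfg = _KVCacheConfig([_Group(_Spec(16)), _Group(_Spec(32))])
print(cfg.block_sizes)  # [16, 32]
```
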
vllm/v1/worker/gpu/__init__.py (new file, 0 lines)

vllm/v1/worker/gpu/async_utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager

import numpy as np
import torch

from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    ModelRunnerOutput,
    SamplerOutput,
)


class AsyncOutput(AsyncModelRunnerOutput):
    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampler_output: SamplerOutput,
        num_sampled_tokens: np.ndarray,
        copy_stream: torch.cuda.Stream,
    ):
        self.model_runner_output = model_runner_output
        self.sampler_output = sampler_output
        self.num_sampled_tokens = num_sampled_tokens
        self.copy_stream = copy_stream
        self.copy_event = torch.cuda.Event()

        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(self.copy_stream):
            self.copy_stream.wait_stream(default_stream)

            # NOTE(woosuk): We should keep the CPU tensors unfreed, until the copy completes.
            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            if sampler_output.logprobs_tensors is not None:
                self.logprobs_tensors = (
                    sampler_output.logprobs_tensors.to_cpu_nonblocking()
                )
            else:
                self.logprobs_tensors = None
            self.prompt_logprobs_dict = {}
            if self.model_runner_output.prompt_logprobs_dict:
                for k, v in self.model_runner_output.prompt_logprobs_dict.items():
                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
            self.copy_event.record(self.copy_stream)

    def get_output(self) -> ModelRunnerOutput:
        self.copy_event.synchronize()

        # NOTE(woosuk): The following code ensures compatibility with OSS vLLM.
        # Going forward, we should keep the data structures as NumPy arrays
        # rather than Python lists.
        sampled_token_ids_np = self.sampled_token_ids.numpy()
        sampled_token_ids = sampled_token_ids_np.tolist()
        for i, tokens in enumerate(sampled_token_ids):
            del tokens[self.num_sampled_tokens[i] :]
        self.model_runner_output.sampled_token_ids = sampled_token_ids

        if self.logprobs_tensors is not None:
            self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
        self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
        return self.model_runner_output


@contextmanager
def async_barrier(event: torch.cuda.Event | None):
    if event is not None:
        event.synchronize()
    try:
        yield
    finally:
        if event is not None:
            event.record()
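
The core trick in `AsyncOutput` is to launch the device-to-host copies on a dedicated stream, record an event, and defer synchronization until `get_output()` actually needs the data. A standalone sketch of that pattern, with illustrative tensor shapes and guarded to CUDA-capable machines, follows.

```python
import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    copy_event = torch.cuda.Event()
    sampled = torch.randint(0, 32_000, (8, 1), device="cuda")  # illustrative

    with torch.cuda.stream(copy_stream):
        # Order the copy after all work already queued on the default stream.
        copy_stream.wait_stream(torch.cuda.current_stream())
        # Keep a reference to the CPU tensor until the event fires.
        sampled_cpu = sampled.to("cpu", non_blocking=True)
        copy_event.record(copy_stream)

    # ... the next model step can be launched here ...

    copy_event.synchronize()  # the equivalent of AsyncOutput.get_output()
    print(sampled_cpu.tolist())
```
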
vllm/v1/worker/gpu/attn_utils.py (new file, 200 lines)
@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def get_kv_cache_spec(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: torch.dtype,
|
||||
) -> dict[str, KVCacheSpec]:
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
assert attn_module.attn_type == AttentionType.DECODER
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
)
|
||||
return kv_cache_spec
|
||||
|
||||
|
||||
def init_attn_backend(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
attn_backends: dict[str, AttentionBackend] = {}
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
|
||||
flashinfer_workspace: torch.Tensor | None = None
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
layer_names = kv_cache_group_spec.layer_names
|
||||
any_layer_name = next(iter(layer_names))
|
||||
|
||||
attn_backend = attn_layers[any_layer_name].get_attn_backend()
|
||||
for layer_name in layer_names:
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
attn_metadata_builder = attn_backend.get_builder_cls()(
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
)
|
||||
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
||||
|
||||
if "FLASHINFER" in attn_backend.get_name():
|
||||
if flashinfer_workspace is None:
|
||||
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
|
||||
else:
|
||||
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
|
||||
return attn_backends, attn_metadata_builders
|
||||
|
||||
|
||||
def _allocate_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
||||
"Some layers are not correctly initialized"
|
||||
)
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
|
||||
def _reshape_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
|
||||
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
|
||||
raw_tensor = raw_tensor.view(dtype)
|
||||
raw_tensor = raw_tensor.view(kv_cache_shape)
|
||||
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
|
||||
return kv_caches
|
||||
|
||||
|
||||
def init_kv_cache(
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
forward_context: dict[str, Any],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_computed_tokens_cpu: torch.Tensor,
|
||||
block_tables: tuple[torch.Tensor, ...],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
|
||||
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
|
||||
seq_lens_gpu = seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = seq_lens.cpu[:num_reqs]
|
||||
max_seq_len = int(seq_lens.np[:num_reqs].max())
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens_gpu,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
max_seq_len=max_seq_len,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
return attn_metadata
|
||||
vllm/v1/worker/gpu/block_table.py (new file, 312 lines)
@ -0,0 +1,312 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
class BlockTables:
|
||||
def __init__(
|
||||
self,
|
||||
block_sizes: list[int],
|
||||
max_num_reqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[torch.Tensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size)
|
||||
block_table = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
max_num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.block_tables.append(block_table)
|
||||
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
|
||||
|
||||
# Block tables used for model's forward pass.
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.input_block_tables: list[torch.Tensor] = [
|
||||
torch.zeros_like(block_table) for block_table in self.block_tables
|
||||
]
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
|
||||
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.block_sizes_tensor = torch.tensor(
|
||||
self.block_sizes, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.num_blocks = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.slot_mappings = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Misc buffers.
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
|
||||
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
|
||||
self.cu_num_new_blocks = self._make_buffer(
|
||||
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
|
||||
)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
|
||||
ptrs_tensor_cpu = torch.tensor(
|
||||
[t.data_ptr() for t in x],
|
||||
dtype=torch.uint64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
def append_block_ids(
|
||||
self,
|
||||
# [num_reqs]
|
||||
req_indices: list[int],
|
||||
# [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks: list[list[int]],
|
||||
# [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids: list[list[int]],
|
||||
# [num_reqs]
|
||||
overwrite: list[bool],
|
||||
) -> None:
|
||||
num_reqs = len(req_indices)
|
||||
self.req_indices.np[:num_reqs] = req_indices
|
||||
self.overwrite.np[:num_reqs] = overwrite
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
|
||||
|
||||
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
|
||||
# no clear upper bound to the number of new blocks in a single step.
|
||||
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
|
||||
# guarantee that the buffer is not freed before the copy is completed.
|
||||
self.new_block_ids_cpu = torch.empty(
|
||||
self.num_kv_cache_groups,
|
||||
max(len(x) for x in new_block_ids),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
new_block_ids_np = self.new_block_ids_cpu.numpy()
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
|
||||
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
self.cu_num_new_blocks.copy_to_gpu(),
|
||||
self.cu_num_new_blocks.gpu.stride(0),
|
||||
new_block_ids_gpu,
|
||||
new_block_ids_gpu.stride(0),
|
||||
self.overwrite.copy_to_gpu(num_reqs),
|
||||
self.block_table_strides,
|
||||
self.block_table_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
|
||||
def gather_block_tables(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
idx_mapping,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
query_start_loc,
|
||||
positions,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _append_block_ids_kernel(
|
||||
# Inputs
|
||||
req_indices, # [num_reqs]
|
||||
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks_stride,
|
||||
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids_stride,
|
||||
overwrite, # [num_reqs]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
# Outputs
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
# Constants
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(req_indices + batch_idx)
|
||||
do_overwrite = tl.load(overwrite + batch_idx)
|
||||
|
||||
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
|
||||
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
|
||||
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
|
||||
num_new_blocks = end_idx - start_idx
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
if do_overwrite:
|
||||
dst_start_idx = 0
|
||||
else:
|
||||
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
|
||||
dst_end_idx = dst_start_idx + num_new_blocks
|
||||
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
|
||||
|
||||
# Destination
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
row_ptr = block_table_ptr + req_idx * block_table_stride
|
||||
|
||||
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
|
||||
for i in range(0, num_new_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(
|
||||
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
|
||||
)
|
||||
tl.store(
|
||||
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gather_block_tables_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_block_table_ptrs, # [num_kv_cache_groups]
|
||||
dst_block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
|
||||
src_row_ptr = src_block_table_ptr + req_idx * stride
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
for i in tl.range(0, num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
|
||||
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
cu_num_tokens, # [num_reqs + 1]
|
||||
pos, # [num_tokens]
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
page_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
req_idx = tl.program_id(1)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if req_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
page_size = tl.load(page_sizes + group_id)
|
||||
|
||||
start_idx = tl.load(cu_num_tokens + req_idx)
|
||||
end_idx = tl.load(cu_num_tokens + req_idx + 1)
|
||||
for i in range(start_idx, end_idx, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // page_size
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_idx * block_table_stride + block_indices
|
||||
)
|
||||
slot_ids = block_numbers * page_size + positions % page_size
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _load_ptr(ptr_to_ptr, elem_dtype):
|
||||
ptr = tl.load(ptr_to_ptr)
|
||||
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
|
||||
return tl.multiple_of(ptr, 16)
|
||||
vllm/v1/worker/gpu/cudagraph_utils.py (new file, 175 lines)
@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
|
||||
self.padded_sizes = self._init_padded_sizes()
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
|
||||
def _init_padded_sizes(self) -> dict[int, int]:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
# CUDA graphs are disabled.
|
||||
return {}
|
||||
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
|
||||
raise NotImplementedError("Piecewise CUDA graphs are not supported")
|
||||
if self.compilation_config.level != 0:
|
||||
raise NotImplementedError("Dynamo is not used. Compilation level must be 0")
|
||||
|
||||
padded_sizes: dict[int, int] = {}
|
||||
assert len(self.cudagraph_sizes) > 0
|
||||
for i in range(1, self.cudagraph_sizes[-1] + 1):
|
||||
for x in self.cudagraph_sizes:
|
||||
if i <= x:
|
||||
padded_sizes[i] = x
|
||||
break
|
||||
return padded_sizes
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
return len(self.padded_sizes) > 0
|
||||
|
||||
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
|
||||
if max(scheduler_output.num_scheduled_tokens.values()) > 1:
|
||||
# Prefill is included.
|
||||
return None
|
||||
return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
batch_size: int,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert batch_size not in self.graphs
|
||||
|
||||
# Prepare dummy inputs.
|
||||
input_ids = input_buffers.input_ids.gpu[:batch_size]
|
||||
positions = input_buffers.positions.gpu[:batch_size]
|
||||
|
||||
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
|
||||
input_buffers.query_start_loc.np[batch_size:] = batch_size
|
||||
input_buffers.query_start_loc.copy_to_gpu()
|
||||
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
|
||||
input_buffers.seq_lens.np[batch_size:] = 0
|
||||
input_buffers.seq_lens.copy_to_gpu()
|
||||
|
||||
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :batch_size]
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
num_reqs=batch_size,
|
||||
num_tokens=batch_size,
|
||||
query_start_loc=input_buffers.query_start_loc,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=None, # FIXME
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
# Warm up.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=batch_size,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
)
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture the graph.
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=batch_size,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
)
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
self.graphs[batch_size] = graph
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert self.needs_capture()
|
||||
# Capture larger graphs first.
|
||||
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
|
||||
if is_global_first_rank():
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
|
||||
|
||||
with freeze_gc(), graph_capture(device=self.device):
|
||||
for batch_size in sizes_to_capture:
|
||||
self.capture_graph(
|
||||
batch_size,
|
||||
model,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_metadata_builders,
|
||||
kv_cache_config,
|
||||
)
|
||||
|
||||
def run(self, batch_size: int) -> torch.Tensor:
|
||||
assert batch_size in self.graphs
|
||||
self.graphs[batch_size].replay()
|
||||
return self.hidden_states[:batch_size]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_gc():
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
gc.unfreeze()
|
||||
vllm/v1/worker/gpu/dist_utils.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.distributed import tensor_model_parallel_all_gather
from vllm.v1.outputs import SamplerOutput


def evenly_split(
    n: int,
    tp_size: int,
    tp_rank: int,
) -> tuple[int, int]:
    q = n // tp_size
    r = n % tp_size
    start = q * tp_rank + min(tp_rank, r)
    end = start + q + (1 if tp_rank < r else 0)
    return start, end


def pad_and_all_gather(
    x: torch.Tensor,
    padded_size: int,
) -> torch.Tensor:
    n = x.shape[0]
    if n != padded_size:
        padded_x = torch.empty(
            (padded_size, *x.shape[1:]),
            dtype=x.dtype,
            device=x.device,
        )
        padded_x[:n] = x
    else:
        padded_x = x

    x = tensor_model_parallel_all_gather(padded_x)
    return x


def all_gather_sampler_output(
    sampler_output: SamplerOutput,
    num_reqs: int,
    tp_size: int,
) -> SamplerOutput:
    n = (num_reqs + tp_size - 1) // tp_size
    sampler_output.sampled_token_ids = pad_and_all_gather(
        sampler_output.sampled_token_ids, n)[:num_reqs]

    # TODO(woosuk): 3 small all-gathers, could be merged into one.
    logprobs_tensors = sampler_output.logprobs_tensors
    if logprobs_tensors is not None:
        logprobs_tensors.logprob_token_ids = pad_and_all_gather(
            logprobs_tensors.logprob_token_ids, n)[:num_reqs]
        logprobs_tensors.logprobs = pad_and_all_gather(
            logprobs_tensors.logprobs, n)[:num_reqs]
        logprobs_tensors.selected_token_ranks = pad_and_all_gather(
            logprobs_tensors.selected_token_ranks, n)[:num_reqs]
    return sampler_output
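
For reference, `evenly_split` hands the remainder to the lowest-ranked workers, so per-rank shard sizes differ by at most one. A quick standalone check with illustrative values:

```python
def evenly_split(n: int, tp_size: int, tp_rank: int) -> tuple[int, int]:
    # Same arithmetic as dist_utils.evenly_split above.
    q, r = divmod(n, tp_size)
    start = q * tp_rank + min(tp_rank, r)
    end = start + q + (1 if tp_rank < r else 0)
    return start, end


# 10 requests over 4 TP ranks -> (0, 3), (3, 6), (6, 8), (8, 10)
print([evenly_split(10, 4, rank) for rank in range(4)])
```
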
vllm/v1/worker/gpu/input_batch.py (new file, 257 lines)
@ -0,0 +1,257 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numba
|
||||
import numba.types as types
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
|
||||
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
num_reqs: int
|
||||
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
|
||||
# [num_reqs]
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
num_scheduled_tokens: np.ndarray
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
num_tokens_after_padding: int
|
||||
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_np: np.ndarray
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
|
||||
# [num_reqs]
|
||||
logits_indices: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
|
||||
input_buffers.query_start_loc.np[0] = 0
|
||||
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
|
||||
num_scheduled_tokens
|
||||
)
|
||||
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
|
||||
# seq_len equals to query_len
|
||||
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
|
||||
input_buffers.seq_lens.np[num_reqs:] = 0
|
||||
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
|
||||
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
|
||||
|
||||
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
|
||||
positions = input_buffers.positions.copy_to_gpu(num_tokens)
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=None,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: With the type annotations, this function is pre-compiled
|
||||
# before the first call.
|
||||
@numba.jit(
|
||||
[
|
||||
types.none(
|
||||
types.int32[:], # idx_mapping
|
||||
types.int32[:, :], # token_ids
|
||||
types.int32[:], # num_computed_tokens
|
||||
types.int32[:], # num_scheduled_tokens
|
||||
types.int32[:], # input_ids
|
||||
types.int64[:], # positions
|
||||
types.int32[:], # query_start_loc
|
||||
types.int32[:], # seq_lens
|
||||
)
|
||||
],
|
||||
nopython=True,
|
||||
cache=True,
|
||||
)
|
||||
def _prepare_inputs(
|
||||
idx_mapping: np.ndarray, # batch_idx -> req_idx
|
||||
token_ids: np.ndarray, # [N, max_model_len]
|
||||
num_computed_tokens: np.ndarray, # [N]
|
||||
num_scheduled_tokens: np.ndarray, # [B]
|
||||
input_ids: np.ndarray, # [num_input_tokens]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
query_start_loc: np.ndarray, # [B + 1]
|
||||
seq_lens: np.ndarray, # [B]
|
||||
) -> None:
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
query_start_loc[0] = 0
|
||||
|
||||
cu_num_tokens = 0
|
||||
for i in range(num_reqs):
|
||||
req_idx = idx_mapping[i]
|
||||
query_len = num_scheduled_tokens[i]
|
||||
start = num_computed_tokens[req_idx]
|
||||
end = start + query_len
|
||||
seq_lens[i] = end
|
||||
|
||||
start_idx = cu_num_tokens
|
||||
end_idx = start_idx + query_len
|
||||
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
|
||||
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
|
||||
|
||||
cu_num_tokens = end_idx
|
||||
query_start_loc[i + 1] = cu_num_tokens
|
||||
|
||||
# Pad the inputs for CUDA graphs.
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||||
# like FlashAttention requires that
|
||||
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
|
||||
# Fill unused with 0 for full cuda graph mode.
|
||||
seq_lens[num_reqs:].fill(0)
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
idx_mapping: np.ndarray,
|
||||
prefill_token_ids: np.ndarray,
|
||||
num_computed_tokens: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
input_ids: CpuGpuBuffer,
|
||||
positions: CpuGpuBuffer,
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
_prepare_inputs(
|
||||
idx_mapping,
|
||||
prefill_token_ids,
|
||||
num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
input_ids.np,
|
||||
positions.np,
|
||||
query_start_loc.np,
|
||||
seq_lens.np,
|
||||
)
|
||||
input_ids.copy_to_gpu(num_tokens)
|
||||
positions.copy_to_gpu(num_tokens)
|
||||
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
|
||||
# tensors from CPU to GPU, because they may include paddings needed
|
||||
# for full CUDA graph mode.
|
||||
query_start_loc.copy_to_gpu()
|
||||
seq_lens.copy_to_gpu()
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _combine_last_token_ids_kernel(
|
||||
input_ids_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_token_ids_ptr,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
prefill_len_ptr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if seq_len <= prefill_len:
|
||||
# Handling prefill tokens.
|
||||
return
|
||||
|
||||
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
|
||||
end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
tl.store(input_ids_ptr + end - 1, last_token_id)
|
||||
|
||||
|
||||
def combine_last_token_ids(
|
||||
input_ids: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
last_token_ids: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
_combine_last_token_ids_kernel[(num_reqs,)](
|
||||
input_ids,
|
||||
idx_mapping,
|
||||
last_token_ids,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
prefill_len,
|
||||
)
|
||||
return input_ids
|
||||
vllm/v1/worker/gpu/model_runner.py (new file, 688 lines)
@ -0,0 +1,688 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.sample.sampler import SamplerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
init_kv_cache,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
combine_last_token_ids,
|
||||
prepare_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self.kv_cache_dtype = self.dtype
|
||||
if self.cache_config.cache_dtype != "auto":
|
||||
# Quantized KV cache.
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype
|
||||
]
|
||||
self.is_pooling_model = False
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.hidden_size = self.model_config.get_hidden_size()
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
self.input_prep_event = torch.cuda.Event()
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate",)
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
with DeviceMemoryProfiler() as m:
|
||||
model_loader = get_model_loader(self.vllm_config.load_config)
|
||||
logger.info("Loading model from scratch...")
|
||||
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.model_config,
|
||||
)
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(
|
||||
self.model,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
time_after_load = time.perf_counter()
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info(
|
||||
"Model loading took %.4f GiB and %.6f seconds",
|
||||
m.consumed_memory / GiB_bytes,
|
||||
time_after_load - time_before_load,
|
||||
)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_kv_cache_spec(self):
|
||||
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
block_sizes = kv_cache_config.block_sizes
|
||||
|
||||
self.block_tables = BlockTables(
|
||||
block_sizes=block_sizes,
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
|
||||
self.kv_cache_config,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
|
||||
self.kv_caches: list[torch.Tensor] = []
|
||||
init_kv_cache(
|
||||
self.kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_cache_config,
|
||||
self.attn_backends,
|
||||
self.device,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
input_batch: InputBatch | None = None,
|
||||
skip_attn: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if input_batch is None:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
if not skip_attn:
|
||||
block_tables = self.block_tables.gather_block_tables(
|
||||
input_batch.idx_mapping
|
||||
)
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
input_batch.query_start_loc,
|
||||
input_batch.positions,
|
||||
)
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc=self.input_buffers.query_start_loc,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=None,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
input_batch.attn_metadata = attn_metadata
|
||||
|
||||
with self.maybe_dummy_run_with_lora(
|
||||
self.lora_config, input_batch.num_scheduled_tokens
|
||||
):
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_sampler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = hidden_states.shape[0]
|
||||
sampling_metadata = SamplingMetadata.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
device=self.device,
|
||||
)
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
self.sampler.sample(logits, sampling_metadata)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=self.max_num_reqs,
|
||||
num_tokens=self.max_num_tokens,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
self.max_num_tokens,
|
||||
input_batch=input_batch,
|
||||
skip_attn=True,
|
||||
)
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
pass
|
||||
|
||||
    @torch.inference_mode()
    def capture_model(self) -> int:
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            self.cudagraph_manager.capture(
                model=self.model,
                input_buffers=self.input_buffers,
                block_tables=self.block_tables,
                attn_metadata_builders=self.attn_metadata_builders,
                kv_cache_config=self.kv_cache_config,
            )

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        # For FlashInfer, we would like to execute a dummy prefill run
        # to trigger JIT compilation.
        if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
            self._dummy_run(self.max_num_tokens, skip_attn=False)
            torch.cuda.synchronize()

    def update_states(self, scheduler_output: SchedulerOutput) -> None:
        for req_id in scheduler_output.preempted_req_ids:
            self.req_states.remove_request(req_id)
        for req_id in scheduler_output.finished_req_ids:
            self.req_states.remove_request(req_id)

        # TODO(woosuk): Change SchedulerOutput.
        req_indices: list[int] = []
        cu_num_new_blocks = tuple(
            [0] for _ in range(self.block_tables.num_kv_cache_groups)
        )
        new_block_ids = tuple([] for _ in range(self.block_tables.num_kv_cache_groups))
        overwrite: list[bool] = []

        # Add new requests.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=len(new_req_data.prompt_token_ids),
                prefill_token_ids=new_req_data.prefill_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                sampling_params=new_req_data.sampling_params,
                lora_request=new_req_data.lora_request,
            )

            req_index = self.req_states.req_id_to_index[req_id]
            req_indices.append(req_index)
            for i, block_ids in enumerate(new_req_data.block_ids):
                x = cu_num_new_blocks[i][-1]
                cu_num_new_blocks[i].append(x + len(block_ids))
                new_block_ids[i].extend(block_ids)
            overwrite.append(True)

        # Add new blocks for the existing requests.
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            req_index = self.req_states.req_id_to_index[req_id]

            req_new_block_ids = cached_reqs.new_block_ids[i]
            if req_new_block_ids is not None:
                req_indices.append(req_index)
                for group_id, block_ids in enumerate(req_new_block_ids):
                    x = cu_num_new_blocks[group_id][-1]
                    cu_num_new_blocks[group_id].append(x + len(block_ids))
                    new_block_ids[group_id].extend(block_ids)
                overwrite.append(False)

        if req_indices:
            self.block_tables.append_block_ids(
                req_indices=req_indices,
                cu_num_new_blocks=cu_num_new_blocks,
                new_block_ids=new_block_ids,
                overwrite=overwrite,
            )

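For reference, the `cu_num_new_blocks` / `new_block_ids` pair built above is a CSR-style encoding per KV-cache group: a running prefix sum of block counts plus one flat list of block IDs. A minimal standalone sketch with made-up block IDs:

# Hypothetical example: two requests append 2 and 3 new blocks in one KV-cache group.
cu_num_new_blocks = [0]
new_block_ids: list[int] = []
for block_ids in ([10, 11], [7, 8, 9]):
    cu_num_new_blocks.append(cu_num_new_blocks[-1] + len(block_ids))
    new_block_ids.extend(block_ids)
# cu_num_new_blocks == [0, 2, 5]; request i owns new_block_ids[cu[i]:cu[i + 1]].
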
    def prepare_inputs(
        self,
        scheduler_output: SchedulerOutput,
        use_cudagraph: bool,
        padded_num_tokens: int | None,
    ) -> InputBatch:
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
        num_reqs = len(scheduler_output.num_scheduled_tokens)

        # Decode first, then prefill.
        # batch_idx -> req_id
        req_ids = sorted(
            scheduler_output.num_scheduled_tokens,
            key=scheduler_output.num_scheduled_tokens.get,
        )
        num_scheduled_tokens = np.array(
            [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
        )
        if use_cudagraph:
            assert padded_num_tokens is not None
            num_tokens_after_padding = padded_num_tokens
        else:
            num_tokens_after_padding = num_tokens

        idx_mapping_list = [
            self.req_states.req_id_to_index[req_id] for req_id in req_ids
        ]
        idx_mapping = self.input_buffers.idx_mapping
        idx_mapping.np[:num_reqs] = idx_mapping_list
        idx_mapping_np = idx_mapping.np[:num_reqs]
        idx_mapping = idx_mapping.copy_to_gpu(num_reqs)

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

        prepare_inputs(
            idx_mapping_np,
            self.req_states.prefill_token_ids,
            self.req_states.num_computed_tokens,
            num_scheduled_tokens,
            self.input_buffers.input_ids,
            self.input_buffers.positions,
            self.input_buffers.query_start_loc,
            self.input_buffers.seq_lens,
            num_tokens,
        )

        query_start_loc = self.input_buffers.query_start_loc
        query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
        query_start_loc_np = query_start_loc.np[: num_reqs + 1]
        seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
        seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]

        # Some input token ids are directly read from the last sampled tokens.
        combine_last_token_ids(
            self.input_buffers.input_ids.gpu,
            idx_mapping,
            self.req_states.last_sampled_tokens,
            query_start_loc_gpu,
            seq_lens_gpu,
            self.req_states.prefill_len.copy_to_gpu(),
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
        )

        num_computed_tokens_cpu = torch.from_numpy(
            self.req_states.num_computed_tokens[idx_mapping_np]
        )

        # Logits indices to sample next token from.
        logits_indices = query_start_loc_gpu[1:] - 1

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=self.attn_metadata_builders,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc=self.input_buffers.query_start_loc,
            seq_lens=self.input_buffers.seq_lens,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
        )

        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
        positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
            query_start_loc=query_start_loc_gpu,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens_gpu,
            seq_lens_np=seq_lens_np,
            input_ids=input_ids,
            positions=positions,
            attn_metadata=attn_metadata,
            logits_indices=logits_indices,
        )

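A small worked example of the layout produced above (illustrative numbers only): with three requests scheduled for 1, 1, and 3 tokens, `query_start_loc` is the prefix sum over the per-request token counts and `logits_indices` picks the last scheduled token of each request.

import numpy as np

num_scheduled_tokens = np.array([1, 1, 3], dtype=np.int32)  # decodes first, then a prefill
query_start_loc = np.concatenate(([0], np.cumsum(num_scheduled_tokens)))  # [0, 1, 2, 5]
logits_indices = query_start_loc[1:] - 1  # [0, 1, 4]: last token of each request
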
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        logits = self.model.compute_logits(sample_hidden_states)
        sampler_output = self.sampler.sample(logits, sampling_metadata)
        return sampler_output

    def compute_prompt_logprobs(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
    ) -> dict[str, LogprobsTensors]:
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # No request asks for prompt logprobs.
            return {}

        num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
        prompt_lens = self.req_states.prompt_len[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
        includes_prompt = num_computed_tokens < prompt_lens - 1
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have been computed before preemption. Skip.
        resumed_after_prompt = (
            prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
        )
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}

        # Just to be safe, clone the input ids.
        n = input_batch.num_tokens
        # Shift the input ids by one.
        token_ids = torch.empty_like(input_batch.input_ids[:n])
        token_ids[: n - 1] = input_batch.input_ids[1:n]
        # To avoid out-of-bound access, set the last token id to 0.
        token_ids[n - 1] = 0

        # Handle chunked prompts.
        seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
        is_prompt_chunked = seq_lens < prompt_lens
        prefill_token_ids = self.req_states.prefill_token_ids
        query_start_loc = self.input_buffers.query_start_loc.np
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue
            if not is_prompt_chunked[i]:
                continue
            # The prompt is chunked. Get the next prompt token.
            req_idx = input_batch.idx_mapping_np[i]
            next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
            idx = int(query_start_loc[i + 1] - 1)
            # Set the next prompt token.
            # NOTE(woosuk): This triggers a GPU operation.
            token_ids[idx] = next_prompt_token

        # NOTE(woosuk): We mask out logprobs for negative tokens.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
            torch.relu(token_ids),
            hidden_states[:n],
            self.model.compute_logits,
        )
        prompt_logprobs[:, 0].masked_fill_(token_ids < 0, 0)

        prompt_token_ids = token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            start_idx = query_start_loc[i]
            end_idx = query_start_loc[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )

            req_extra_data = self.req_states.extra_data[req_id]
            prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
            if is_prompt_chunked[i]:
                # The prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue

            if prompt_logprobs_list:
                # Merge the in-progress logprobs.
                prompt_logprobs_list.append(logprobs)
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [x.logprob_token_ids for x in prompt_logprobs_list]
                    ),
                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
                    selected_token_ranks=torch.cat(
                        [x.selected_token_ranks for x in prompt_logprobs_list]
                    ),
                )
                prompt_logprobs_list.clear()

            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict

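The shift-by-one above aligns each prompt position's hidden state with the token it has to predict; a minimal illustration with hypothetical token IDs:

import torch

input_ids = torch.tensor([101, 7, 42, 13])  # prompt chunk fed to the model
targets = torch.empty_like(input_ids)
targets[:-1] = input_ids[1:]  # the hidden state at position i scores token i + 1
targets[-1] = 0  # placeholder for the last position, replaced or masked as above
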
    def postprocess(
        self,
        sampler_output: SamplerOutput,
        sampling_metadata: SamplingMetadata,
        prompt_logprobs_dict: dict[str, LogprobsTensors],
        input_batch: InputBatch,
    ) -> AsyncOutput | ModelRunnerOutput:
        # Store the last sampled token ids.
        self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
            sampler_output.sampled_token_ids
        )
        # Get the number of sampled tokens.
        # 0 if chunked-prefilling, 1 if not.
        idx_mapping_np = input_batch.idx_mapping_np
        is_chunked_prefilling = (
            input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
        )
        num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
        # Increment the number of tokens.
        self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
        # Increment the number of computed tokens.
        self.req_states.num_computed_tokens[idx_mapping_np] += (
            input_batch.num_scheduled_tokens
        )

        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            sampled_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            kv_connector_output=None,
            num_nans_in_logits=None,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled_tokens,
            copy_stream=self.output_copy_stream,
        )
        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

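A tiny numpy illustration of the bookkeeping above (hypothetical lengths): a request still in the middle of a chunked prefill contributes no sampled token this step, while a request that has covered all of its known tokens contributes exactly one.

import numpy as np

seq_lens_np = np.array([8, 5], dtype=np.int32)  # tokens covered after this step
num_tokens = np.array([8, 12], dtype=np.int32)  # known tokens (prompt + generated)
is_chunked_prefilling = seq_lens_np < num_tokens  # [False, True]
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)  # [1, 0]
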
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: Any | None = None,
    ) -> AsyncOutput | ModelRunnerOutput:
        assert intermediate_tensors is None

        with async_barrier(
            self.input_prep_event if self.use_async_scheduling else None
        ):
            self.update_states(scheduler_output)
            if scheduler_output.total_num_scheduled_tokens == 0:
                return EMPTY_MODEL_RUNNER_OUTPUT

            padded_num_tokens = self.cudagraph_manager.get_cudagraph_size(
                scheduler_output
            )
            use_cudagraph = padded_num_tokens is not None
            input_batch = self.prepare_inputs(
                scheduler_output,
                use_cudagraph,
                padded_num_tokens,
            )
            pos = input_batch.positions[input_batch.logits_indices]
            idx_mapping_np = input_batch.idx_mapping_np
            sampling_metadata = self.req_states.make_sampling_metadata(
                idx_mapping_np, pos
            )

            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.req_states.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*lora_inputs)

        # Run model.
        if use_cudagraph:
            # Run CUDA graph.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            hidden_states = self.cudagraph_manager.run(padded_num_tokens)
        else:
            with set_forward_context(
                input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
            ):
                # Run PyTorch model in eager mode.
                hidden_states = self.model(
                    input_ids=input_batch.input_ids,
                    positions=input_batch.positions,
                )

        sampler_output = self.sample(hidden_states, input_batch, sampling_metadata)
        prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
        output = self.postprocess(
            sampler_output,
            sampling_metadata,
            prompt_logprobs_dict,
            input_batch,
        )
        return output
vllm/v1/worker/gpu/sampler.py (new file, 302 lines)
@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch
import triton
import triton.language as tl

from vllm.config.model import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p


class Sampler:
    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode

    def sample_token(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        is_greedy = sampling_metadata.temperature == 0
        temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
        logits = logits / temp.view(-1, 1)
        logits = apply_top_k_top_p(
            logits, sampling_metadata.top_k, sampling_metadata.top_p
        )

        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)

        sampled = gumbel_sample(
            probs,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
        )
        sampled = sampled.to(torch.int64)
        return sampled, logits if return_logits else None

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        if sampling_metadata.max_num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                sampled, logits = self.sample_token(
                    logits, sampling_metadata, return_logits=True
                )
            else:
                assert self.logprobs_mode == "raw_logprobs"
                sampled, _ = self.sample_token(
                    logits, sampling_metadata, return_logits=False
                )

            logprobs_tensors = compute_topk_logprobs(
                logits,
                sampling_metadata.max_num_logprobs,
                sampled,
            )
        else:
            sampled, _ = self.sample_token(
                logits, sampling_metadata, return_logits=False
            )
            logprobs_tensors = None

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output


@triton.jit
def _gumbel_sample_kernel(
    probs_ptr,
    probs_stride,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    temp = tl.load(temp_ptr + req_idx)

    if temp == 0.0:
        # Greedy sampling. Don't apply gumbel noise.
        return

    seed = tl.load(seeds_ptr + req_idx)
    pos = tl.load(pos_ptr + req_idx)
    gumbel_seed = tl.randint(seed, pos)

    block_id = tl.program_id(1)
    r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    q = tl.rand(gumbel_seed, r_offset)

    # NOTE(woosuk): This logic makes sure q is not 0.
    RMAX = 0.9999999403953552
    RMAX_LOG = -5.960464477539063e-08
    q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
    q = -1.0 * q

    p = tl.load(
        probs_ptr + req_idx * probs_stride + r_offset, mask=r_offset < vocab_size
    )
    p = p / q
    tl.store(
        probs_ptr + req_idx * probs_stride + r_offset, p, mask=r_offset < vocab_size
    )


def gumbel_sample(
    probs: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
) -> torch.Tensor:
    num_reqs, vocab_size = probs.shape
    BLOCK_SIZE = 8192
    # Launch one program per (request, vocab block) so that the kernel's
    # tl.program_id(1) covers the whole vocabulary, not just the first block.
    grid = (num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))
    _gumbel_sample_kernel[grid](
        probs,
        probs.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,  # type: ignore
    )
    sampled = probs.argmax(dim=-1)
    return sampled


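For reference, the kernel above implements an exponential-race (equivalently, Gumbel-max) draw: dividing each probability by i.i.d. Exp(1) noise and taking the argmax yields a sample from the categorical distribution. A plain-PyTorch sketch of the same idea, without the per-request seeding or the greedy shortcut:

import torch


def gumbel_sample_reference(probs: torch.Tensor) -> torch.Tensor:
    # probs: [num_reqs, vocab_size], each row summing to 1.
    noise = torch.empty_like(probs).exponential_()  # q ~ Exp(1), i.e. -log(U)
    return (probs / noise).argmax(dim=-1)
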
@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)

    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations
        # are in float32.
        logits = logits.to(tl.float32)
        e = tl.exp(logits - max_val)
        e = tl.where(block < vocab_size, e, 0.0)
        se += tl.sum(e)
    lse = tl.log(se)

    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)

    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


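The kernel above evaluates log-softmax only at the requested token IDs, using one pass for the row maximum and one for the sum of exponentials instead of materializing the full distribution. A chunked PyTorch sketch of the same two-pass scheme for a single row (reference only):

import torch


def topk_log_softmax_reference(
    logits_row: torch.Tensor, topk_ids: torch.Tensor, block: int = 1024
) -> torch.Tensor:
    # topk_ids: int64 indices into the row.
    max_val = torch.tensor(float("-inf"))
    for i in range(0, logits_row.numel(), block):
        max_val = torch.maximum(max_val, logits_row[i : i + block].max())
    sum_exp = torch.tensor(0.0)
    for i in range(0, logits_row.numel(), block):
        sum_exp += torch.exp(logits_row[i : i + block].float() - max_val).sum()
    return logits_row[topk_ids].float() - max_val - torch.log(sum_exp)
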
@triton.jit
def _ranks_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    token_ids_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    token_id = tl.load(token_ids_ptr + req_idx)
    x = tl.load(row_ptr + token_id)

    n = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        n += tl.sum((logits > x).to(tl.int32))
    tl.store(output_ptr + req_idx, n)


def compute_token_logprobs(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
) -> torch.Tensor:
    batch_size = logits.shape[0]
    vocab_size = logits.shape[1]
    token_ids = token_ids.to(torch.int64)
    num_logprobs = token_ids.shape[1]
    logprobs = torch.empty(
        batch_size,
        num_logprobs,
        dtype=torch.float32,
        device=logits.device,
    )
    _topk_log_softmax_kernel[(batch_size,)](
        logprobs,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        PADDED_TOPK=triton.next_power_of_2(num_logprobs),
    )
    return logprobs


def compute_topk_logprobs(
    logits: torch.Tensor,
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
    assert num_logprobs >= 0
    batch_size, vocab_size = logits.shape
    if num_logprobs == 0:
        logprob_token_ids = sampled_token_ids.unsqueeze(-1)
    else:
        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
        logprob_token_ids = torch.cat(
            (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
        )

    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
    # logprobs tensor. Instead, we only compute and return the logprobs of
    # the topk + 1 tokens.
    logprobs = compute_token_logprobs(logits, logprob_token_ids)
    token_ranks = torch.empty(
        batch_size,
        dtype=torch.int64,
        device=logits.device,
    )
    _ranks_kernel[(batch_size,)](
        token_ranks,
        logits,
        logits.stride(0),
        sampled_token_ids,
        vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
    )
    return LogprobsTensors(
        logprob_token_ids=logprob_token_ids,
        logprobs=logprobs,
        selected_token_ranks=token_ranks,
    )


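For checking these outputs, the same quantities can be computed with a full `log_softmax`, at the cost of materializing the [batch, vocab] logprobs tensor that the Triton path above deliberately avoids; a reference-only sketch:

import torch


def topk_logprobs_reference(
    logits: torch.Tensor, num_logprobs: int, sampled_token_ids: torch.Tensor
):
    sampled_token_ids = sampled_token_ids.to(torch.int64)
    full_logprobs = torch.log_softmax(logits.float(), dim=-1)
    ids = sampled_token_ids.unsqueeze(-1)
    if num_logprobs > 0:
        topk_ids = torch.topk(logits, num_logprobs, dim=-1).indices
        ids = torch.cat((ids, topk_ids), dim=1)
    logprobs = full_logprobs.gather(1, ids)
    # Rank = number of logits strictly greater than the sampled token's logit.
    sampled_logits = logits.gather(1, sampled_token_ids.unsqueeze(-1))
    ranks = (logits > sampled_logits).sum(dim=-1)
    return ids, logprobs, ranks
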
def compute_prompt_logprobs(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    # Since materializing the full prompt logits can take too much memory,
    # we compute them in chunks.
    CHUNK_SIZE = 1024
    logprobs = []
    ranks = []
    prompt_token_ids = prompt_token_ids.to(torch.int64)
    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
        end_idx = start_idx + CHUNK_SIZE
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
        prompt_logprobs = compute_topk_logprobs(
            prompt_logits,
            0,  # num_logprobs
            prompt_token_ids[start_idx:end_idx],
        )
        logprobs.append(prompt_logprobs.logprobs)
        ranks.append(prompt_logprobs.selected_token_ranks)

    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
    return logprobs, ranks
vllm/v1/worker/gpu/states.py (new file, 263 lines)
@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field

import numpy as np
import torch

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer

_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0


@dataclass
class SamplingMetadata:
    temperature: torch.Tensor

    top_p: torch.Tensor | None
    top_k: torch.Tensor | None

    seeds: torch.Tensor
    pos: torch.Tensor

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: int | None

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        device: torch.device,
    ) -> "SamplingMetadata":
        assert num_reqs > 0
        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
        temperature[0] = 0.5
        # TODO(woosuk): Use top-p and top-k for dummy sampler.
        # Currently, they are disabled because of memory usage.
        top_p = None
        top_k = None
        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
        max_num_logprobs = 20

        return cls(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            seeds=seeds,
            pos=pos,
            max_num_logprobs=max_num_logprobs,
        )


class RequestState:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.pin_memory = pin_memory

        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))
        self.extra_data: dict[str, ExtraData] = {}

        self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.prefill_token_ids = np.zeros(
            (self.max_num_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)

        # Last sampled tokens.
        self.last_sampled_tokens = torch.zeros(
            self.max_num_reqs,
            1,
            dtype=torch.int64,
            device=device,
        )

        # LoRA.
        self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.lora_ids.fill(NO_LORA_ID)

        # Sampling parameters.
        self.temperature = self._make_param(self.max_num_reqs, torch.float32)
        self.top_p = self._make_param(self.max_num_reqs, torch.float32)
        self.top_k = self._make_param(self.max_num_reqs, torch.int32)
        self.seeds = self._make_param(self.max_num_reqs, torch.int64)

        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs.fill(-1)
        self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)

    def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
        return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)

    def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
        return CpuGpuBuffer(
            size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
        )

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_len: int,
        prefill_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
        lora_request: LoRARequest | None,
    ) -> None:
        assert len(self.free_indices) > 0, "No free indices"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id
        self.extra_data[req_id] = ExtraData(lora_request)

        self.prompt_len[req_idx] = prompt_len
        prefill_len = len(prefill_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
        self.prefill_len.np[req_idx] = prefill_len
        self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
        self.num_tokens[req_idx] = prefill_len
        self.num_computed_tokens[req_idx] = num_computed_tokens

        if lora_request is not None:
            self.lora_ids[req_idx] = lora_request.lora_int_id
        else:
            self.lora_ids[req_idx] = NO_LORA_ID

        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k

        if sampling_params.seed is not None:
            seed = sampling_params.seed
        else:
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
        self.seeds.np[req_idx] = seed

        if sampling_params.logprobs is not None:
            num_logprobs = sampling_params.logprobs
        else:
            num_logprobs = -1
        self.num_logprobs[req_idx] = num_logprobs

        # For now, only support prompt logprobs for the prompt tokens.
        needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
        self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs

    def remove_request(self, req_id: str) -> None:
        self.extra_data.pop(req_id, None)
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
        pos: torch.Tensor,
    ) -> SamplingMetadata:
        temperature = self.temperature.np[idx_mapping]
        temperature = self.temperature.copy_np_to_gpu(temperature)

        top_p = self.top_p.np[idx_mapping]
        no_top_p = np.all(top_p == 1.0)
        top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None

        top_k = self.top_k.np[idx_mapping]
        no_top_k = np.all(top_k == self.vocab_size)
        top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None

        seeds = self.seeds.np[idx_mapping]
        seeds = self.seeds.copy_np_to_gpu(seeds)

        num_logprobs = self.num_logprobs[idx_mapping]
        max_num_logprobs = int(np.max(num_logprobs))
        if max_num_logprobs == -1:
            max_num_logprobs = None

        return SamplingMetadata(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            seeds=seeds,
            pos=pos,
            max_num_logprobs=max_num_logprobs,
        )

    def make_lora_inputs(
        self,
        req_ids: list[str],
        idx_mapping: np.ndarray,
        num_scheduled_tokens: np.ndarray,
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        lora_ids = self.lora_ids[idx_mapping]
        prompt_lora_mapping = tuple(lora_ids)
        token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set()
        for req_id in req_ids:
            lora_request = self.extra_data[req_id].lora_request
            if lora_request is not None:
                active_lora_requests.add(lora_request)
        return prompt_lora_mapping, token_lora_mapping, active_lora_requests


class Param:
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
    ):
        self.buffer = CpuGpuBuffer(
            size,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory,
        )
        self.np = np.zeros_like(self.buffer.np)

    def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
        n = x.shape[0]
        self.buffer.np[:n] = x
        return self.buffer.copy_to_gpu(n)


@dataclass
class ExtraData:
    lora_request: LoRARequest | None
    in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
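The `Param` wrapper above keeps a persistent numpy copy of every per-request value and stages only the gathered, batch-ordered slice through a CPU buffer before the host-to-device copy. A minimal self-contained sketch of that staging pattern (a toy stand-in that assumes nothing about vLLM's `CpuGpuBuffer` beyond the `.np` / `.copy_to_gpu()` usage seen here):

import torch


class StagingBufferSketch:
    def __init__(self, size: int, dtype: torch.dtype, device: torch.device):
        self.cpu = torch.zeros(size, dtype=dtype)
        self.np = self.cpu.numpy()  # shares memory with the CPU tensor
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]


# Usage mirrors Param.copy_np_to_gpu: gather per-request values in batch order,
# stage them in the CPU buffer, then copy the first n elements to the GPU:
#   buf.np[:n] = all_values_np[idx_mapping]; gpu_view = buf.copy_to_gpu(n)
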
@ -42,7 +42,9 @@ from vllm.v1.outputs import (
    ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase

@ -494,6 +496,8 @@ class Worker(WorkerBase):
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        return self.model_runner.execute_model(scheduler_output)

        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens