update
This commit is contained in:
@ -10,6 +10,7 @@ torchaudio==2.9.0
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
|
||||
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
|
||||
# xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.4.1
|
||||
apache-tvm-ffi==0.1.0b15
|
||||
|
||||
@ -34,6 +34,7 @@ else:
|
||||
class NewRequestData:
|
||||
req_id: str
|
||||
prompt_token_ids: list[int] | None
|
||||
prefill_token_ids: list[int] | None
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: SamplingParams | None
|
||||
pooling_params: PoolingParams | None
|
||||
@ -51,6 +52,7 @@ class NewRequestData:
|
||||
return cls(
|
||||
req_id=request.request_id,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prefill_token_ids=request._all_token_ids,
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
@ -173,6 +175,7 @@ class SchedulerOutput:
|
||||
# This can be used for cascade attention.
|
||||
num_common_prefix_blocks: list[int]
|
||||
|
||||
preempted_req_ids: set[str]
|
||||
# Request IDs that are finished in between the previous and the current
|
||||
# steps. This is used to notify the workers about the finished requests
|
||||
# so that they can free the cached states for those requests.
|
||||
|
||||
@ -606,6 +606,9 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
|
||||
# Construct the scheduler output.
|
||||
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
|
||||
scheduled_resumed_reqs = []
|
||||
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids()
|
||||
@ -635,6 +638,7 @@ class Scheduler(SchedulerInterface):
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
preempted_req_ids={req.request_id for req in preempted_reqs},
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
@ -720,14 +724,6 @@ class Scheduler(SchedulerInterface):
|
||||
req.num_computed_tokens : req.num_computed_tokens + num_tokens
|
||||
]
|
||||
new_token_ids.append(token_ids)
|
||||
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
|
||||
if idx >= num_running_reqs:
|
||||
assert not scheduled_in_prev_step
|
||||
resumed_req_ids.add(req_id)
|
||||
if not scheduled_in_prev_step:
|
||||
all_token_ids[req_id] = req.all_token_ids[
|
||||
: req.num_computed_tokens + num_tokens
|
||||
]
|
||||
new_block_ids.append(
|
||||
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
|
||||
)
|
||||
@ -902,7 +898,6 @@ class Scheduler(SchedulerInterface):
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> dict[int, EngineCoreOutputs]:
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
num_sampled_tokens = model_runner_output.num_sampled_tokens
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
@ -15,7 +15,6 @@ else:
|
||||
|
||||
|
||||
class LogprobsLists(NamedTuple):
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: np.ndarray
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
@ -135,13 +134,14 @@ class KVConnectorOutput:
|
||||
class ModelRunnerOutput:
|
||||
# [num_reqs]
|
||||
req_ids: list[str]
|
||||
# req_id -> index
|
||||
req_id_to_index: dict[str, int]
|
||||
|
||||
# num_reqs x num_generated_tokens
|
||||
# num_generated_tokens is the number of tokens
|
||||
# generated in the current step. It can be different for
|
||||
# each request due to speculative/jump decoding.
|
||||
sampled_token_ids: np.ndarray | None
|
||||
num_sampled_tokens: np.ndarray | None
|
||||
sampled_token_ids: list[list[int]]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
@ -186,8 +186,8 @@ class DraftTokenIds:
|
||||
|
||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=[],
|
||||
sampled_token_ids=None,
|
||||
num_sampled_tokens=None,
|
||||
req_id_to_index={},
|
||||
sampled_token_ids=[],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
|
||||
@ -1,21 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
|
||||
ModelRunnerOutput, SamplerOutput)
|
||||
from vllm.v1.outputs import (
|
||||
AsyncModelRunnerOutput,
|
||||
ModelRunnerOutput,
|
||||
SamplerOutput,
|
||||
)
|
||||
|
||||
|
||||
class AsyncOutput(AsyncModelRunnerOutput):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
sampler_output: SamplerOutput,
|
||||
num_sampled_tokens: np.ndarray,
|
||||
copy_stream: torch.cuda.Stream,
|
||||
):
|
||||
self.model_runner_output = model_runner_output
|
||||
self.sampler_output = sampler_output
|
||||
self.num_sampled_tokens = num_sampled_tokens
|
||||
self.copy_stream = copy_stream
|
||||
self.copy_event = torch.cuda.Event()
|
||||
|
||||
@ -23,26 +30,46 @@ class AsyncOutput(AsyncModelRunnerOutput):
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
self.copy_stream.wait_stream(default_stream)
|
||||
|
||||
# NOTE(woosuk): We should keep the CPU tensors unfreed, until the copy completes.
|
||||
self.sampled_token_ids = sampler_output.sampled_token_ids.to(
|
||||
"cpu", non_blocking=True)
|
||||
x = sampler_output.logprobs_tensors
|
||||
if x is not None:
|
||||
self.logprobs_tensors = LogprobsTensors(
|
||||
logprob_token_ids=x.logprob_token_ids.to(
|
||||
"cpu", non_blocking=True),
|
||||
logprobs=x.logprobs.to("cpu", non_blocking=True),
|
||||
selected_token_ranks=x.selected_token_ranks.to(
|
||||
"cpu", non_blocking=True),
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
if sampler_output.logprobs_tensors is not None:
|
||||
self.logprobs_tensors = (
|
||||
sampler_output.logprobs_tensors.to_cpu_nonblocking()
|
||||
)
|
||||
else:
|
||||
self.logprobs_tensors = None
|
||||
self.copy_event.record()
|
||||
self.prompt_logprobs_dict = {}
|
||||
if self.model_runner_output.prompt_logprobs_dict:
|
||||
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
|
||||
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
|
||||
self.copy_event.record(self.copy_stream)
|
||||
|
||||
def get_output(self) -> ModelRunnerOutput:
|
||||
self.copy_event.synchronize()
|
||||
self.model_runner_output.sampled_token_ids = (
|
||||
self.sampled_token_ids.numpy())
|
||||
|
||||
# NOTE(woosuk): The following code ensures compatibility with OSS vLLM.
|
||||
# Going forward, we should keep the data structures as NumPy arrays
|
||||
# rather than Python lists.
|
||||
sampled_token_ids_np = self.sampled_token_ids.numpy()
|
||||
sampled_token_ids = sampled_token_ids_np.tolist()
|
||||
for i, tokens in enumerate(sampled_token_ids):
|
||||
del tokens[self.num_sampled_tokens[i] :]
|
||||
self.model_runner_output.sampled_token_ids = sampled_token_ids
|
||||
|
||||
if self.logprobs_tensors is not None:
|
||||
self.model_runner_output.logprobs = (
|
||||
self.logprobs_tensors.tolists())
|
||||
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
|
||||
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
|
||||
return self.model_runner_output
|
||||
|
||||
|
||||
@contextmanager
|
||||
def async_barrier(event: torch.cuda.Event | None):
|
||||
if event is not None:
|
||||
event.synchronize()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if event is not None:
|
||||
event.record()
|
||||
|
||||
@ -7,9 +7,17 @@ import torch
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec, SlidingWindowSpec)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
@ -18,7 +26,6 @@ def get_kv_cache_spec(
|
||||
kv_cache_dtype: torch.dtype,
|
||||
) -> dict[str, KVCacheSpec]:
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
@ -31,7 +38,6 @@ def get_kv_cache_spec(
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
@ -39,7 +45,6 @@ def get_kv_cache_spec(
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
return kv_cache_spec
|
||||
|
||||
@ -52,6 +57,7 @@ def init_attn_backend(
|
||||
attn_backends: dict[str, AttentionBackend] = {}
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
|
||||
flashinfer_workspace: torch.Tensor | None = None
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
layer_names = kv_cache_group_spec.layer_names
|
||||
@ -67,7 +73,13 @@ def init_attn_backend(
|
||||
vllm_config,
|
||||
device,
|
||||
)
|
||||
attn_metadata_builders.append(attn_metadata_builder)
|
||||
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
||||
|
||||
if "FLASHINFER" in attn_backend.get_name():
|
||||
if flashinfer_workspace is None:
|
||||
flashinfer_workspace = attn_metadata_builder.get_workspace_buffer()
|
||||
else:
|
||||
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
|
||||
return attn_backends, attn_metadata_builders
|
||||
|
||||
|
||||
@ -77,9 +89,7 @@ def _allocate_kv_cache(
|
||||
):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=device)
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
@ -87,8 +97,9 @@ def _allocate_kv_cache(
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()
|
||||
), "Some layers are not correctly initialized"
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
||||
"Some layers are not correctly initialized"
|
||||
)
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
|
||||
@ -103,17 +114,19 @@ def _reshape_kv_cache(
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes)
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
|
||||
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
@ -132,8 +145,56 @@ def init_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
):
|
||||
) -> None:
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
|
||||
attn_backends)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_computed_tokens_cpu: torch.Tensor,
|
||||
block_tables: tuple[torch.Tensor, ...],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
|
||||
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
|
||||
seq_lens_gpu = seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = seq_lens.cpu[:num_reqs]
|
||||
max_seq_len = int(seq_lens.np[:num_reqs].max())
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens_gpu,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
max_seq_len=max_seq_len,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
return attn_metadata
|
||||
|
||||
@ -6,14 +6,13 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
class BlockTables:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_sizes: list[int],
|
||||
@ -50,44 +49,48 @@ class BlockTables:
|
||||
self.input_block_tables: list[torch.Tensor] = [
|
||||
torch.zeros_like(block_table) for block_table in self.block_tables
|
||||
]
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(
|
||||
self.input_block_tables)
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
|
||||
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.block_sizes_tensor = torch.tensor(self.block_sizes,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
device=self.device,
|
||||
)
|
||||
self.block_sizes_tensor = torch.tensor(
|
||||
self.block_sizes, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.num_blocks = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.slot_mappings = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Misc buffers.
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
|
||||
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
|
||||
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
|
||||
self.max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
self.cu_num_new_blocks = self._make_buffer(
|
||||
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
|
||||
)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
ptrs_tensor_cpu = torch.tensor(
|
||||
[t.data_ptr() for t in x],
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
def append_block_ids(
|
||||
@ -105,7 +108,7 @@ class BlockTables:
|
||||
self.req_indices.np[:num_reqs] = req_indices
|
||||
self.overwrite.np[:num_reqs] = overwrite
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
|
||||
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
|
||||
|
||||
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
|
||||
# no clear upper bound to the number of new blocks in a single step.
|
||||
@ -120,9 +123,8 @@ class BlockTables:
|
||||
)
|
||||
new_block_ids_np = self.new_block_ids_cpu.numpy()
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
|
||||
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
|
||||
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
@ -135,7 +137,7 @@ class BlockTables:
|
||||
self.block_table_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
|
||||
def gather_block_tables(
|
||||
@ -150,10 +152,9 @@ class BlockTables:
|
||||
self.block_table_strides,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return tuple(block_table[:num_reqs]
|
||||
for block_table in self.input_block_tables)
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
@ -174,7 +175,7 @@ class BlockTables:
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
@ -201,8 +202,7 @@ def _append_block_ids_kernel(
|
||||
req_idx = tl.load(req_indices + batch_idx)
|
||||
do_overwrite = tl.load(overwrite + batch_idx)
|
||||
|
||||
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
|
||||
group_id * cu_num_new_blocks_stride)
|
||||
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
|
||||
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
|
||||
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
|
||||
num_new_blocks = end_idx - start_idx
|
||||
@ -220,15 +220,15 @@ def _append_block_ids_kernel(
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
row_ptr = block_table_ptr + req_idx * block_table_stride
|
||||
|
||||
group_new_block_ids_ptr = (new_block_ids_ptr +
|
||||
group_id * new_block_ids_stride)
|
||||
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
|
||||
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
|
||||
for i in range(0, num_new_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
|
||||
mask=offset < num_new_blocks)
|
||||
tl.store(row_ptr + dst_start_idx + offset,
|
||||
block_ids,
|
||||
mask=offset < num_new_blocks)
|
||||
block_ids = tl.load(
|
||||
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
|
||||
)
|
||||
tl.store(
|
||||
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -282,11 +282,9 @@ def _compute_slot_mappings_kernel(
|
||||
|
||||
if req_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset,
|
||||
PAD_ID,
|
||||
mask=offset < max_num_tokens)
|
||||
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
@ -295,12 +293,13 @@ def _compute_slot_mappings_kernel(
|
||||
|
||||
start_idx = tl.load(cu_num_tokens + req_idx)
|
||||
end_idx = tl.load(cu_num_tokens + req_idx + 1)
|
||||
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
|
||||
for i in range(start_idx, end_idx, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // page_size
|
||||
block_numbers = tl.load(block_table_ptr +
|
||||
req_idx * block_table_stride + block_indices)
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_idx * block_table_stride + block_indices
|
||||
)
|
||||
slot_ids = block_numbers * page_size + positions % page_size
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
175
vllm/v1/worker/gpu/cudagraph_utils.py
Normal file
175
vllm/v1/worker/gpu/cudagraph_utils.py
Normal file
@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
|
||||
self.padded_sizes = self._init_padded_sizes()
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
|
||||
def _init_padded_sizes(self) -> dict[int, int]:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
# CUDA graphs are disabled.
|
||||
return {}
|
||||
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
|
||||
raise NotImplementedError("Piecewise CUDA graphs are not supported")
|
||||
if self.compilation_config.level != 0:
|
||||
raise NotImplementedError("Dynamo is not used. Compilation level must be 0")
|
||||
|
||||
padded_sizes: dict[int, int] = {}
|
||||
assert len(self.cudagraph_sizes) > 0
|
||||
for i in range(1, self.cudagraph_sizes[-1] + 1):
|
||||
for x in self.cudagraph_sizes:
|
||||
if i <= x:
|
||||
padded_sizes[i] = x
|
||||
break
|
||||
return padded_sizes
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
return len(self.padded_sizes) > 0
|
||||
|
||||
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
|
||||
if max(scheduler_output.num_scheduled_tokens.values()) > 1:
|
||||
# Prefill is included.
|
||||
return None
|
||||
return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
batch_size: int,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert batch_size not in self.graphs
|
||||
|
||||
# Prepare dummy inputs.
|
||||
input_ids = input_buffers.input_ids.gpu[:batch_size]
|
||||
positions = input_buffers.positions.gpu[:batch_size]
|
||||
|
||||
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
|
||||
input_buffers.query_start_loc.np[batch_size:] = batch_size
|
||||
input_buffers.query_start_loc.copy_to_gpu()
|
||||
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
|
||||
input_buffers.seq_lens.np[batch_size:] = 0
|
||||
input_buffers.seq_lens.copy_to_gpu()
|
||||
|
||||
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :batch_size]
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
num_reqs=batch_size,
|
||||
num_tokens=batch_size,
|
||||
query_start_loc=input_buffers.query_start_loc,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=None, # FIXME
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
# Warm up.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=batch_size,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
)
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture the graph.
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=batch_size,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
)
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
self.graphs[batch_size] = graph
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert self.needs_capture()
|
||||
# Capture larger graphs first.
|
||||
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
|
||||
if is_global_first_rank():
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
|
||||
|
||||
with freeze_gc(), graph_capture(device=self.device):
|
||||
for batch_size in sizes_to_capture:
|
||||
self.capture_graph(
|
||||
batch_size,
|
||||
model,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_metadata_builders,
|
||||
kv_cache_config,
|
||||
)
|
||||
|
||||
def run(self, batch_size: int) -> torch.Tensor:
|
||||
assert batch_size in self.graphs
|
||||
self.graphs[batch_size].replay()
|
||||
return self.hidden_states[:batch_size]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_gc():
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
gc.unfreeze()
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@ -16,11 +15,12 @@ from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
@ -32,20 +32,17 @@ class InputBuffers:
|
||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
|
||||
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
num_reqs: int
|
||||
@ -54,17 +51,23 @@ class InputBatch:
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
|
||||
# [num_reqs]
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
num_scheduled_tokens: np.ndarray
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
num_tokens_after_padding: int
|
||||
# [num_reqs]
|
||||
is_chunked_prefilling: np.ndarray
|
||||
|
||||
# [max_num_batched_tokens]
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_np: np.ndarray
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
# [max_num_batched_tokens]
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
|
||||
# layer_name -> Metadata
|
||||
@ -78,23 +81,34 @@ class InputBatch:
|
||||
cls,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.tensor(idx_mapping_np, device=device)
|
||||
num_scheduled_tokens = np.full(num_reqs,
|
||||
num_tokens // num_reqs,
|
||||
dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
|
||||
input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
|
||||
positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
|
||||
attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = torch.arange(num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
|
||||
input_buffers.query_start_loc.np[0] = 0
|
||||
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
|
||||
num_scheduled_tokens
|
||||
)
|
||||
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
|
||||
# seq_len equals to query_len
|
||||
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
|
||||
input_buffers.seq_lens.np[num_reqs:] = 0
|
||||
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
|
||||
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
|
||||
|
||||
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
|
||||
positions = input_buffers.positions.copy_to_gpu(num_tokens)
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@ -103,10 +117,13 @@ class InputBatch:
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
is_chunked_prefilling=is_chunked_prefilling,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
attn_metadata=None,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
@ -130,14 +147,14 @@ class InputBatch:
|
||||
cache=True,
|
||||
)
|
||||
def _prepare_inputs(
|
||||
idx_mapping: np.ndarray, # batch_idx -> req_idx
|
||||
token_ids: np.ndarray, # [N, max_model_len]
|
||||
num_computed_tokens: np.ndarray, # [N]
|
||||
num_scheduled_tokens: np.ndarray, # [B]
|
||||
input_ids: np.ndarray, # [num_input_tokens]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
query_start_loc: np.ndarray, # [B + 1]
|
||||
seq_lens: np.ndarray, # [B]
|
||||
idx_mapping: np.ndarray, # batch_idx -> req_idx
|
||||
token_ids: np.ndarray, # [N, max_model_len]
|
||||
num_computed_tokens: np.ndarray, # [N]
|
||||
num_scheduled_tokens: np.ndarray, # [B]
|
||||
input_ids: np.ndarray, # [num_input_tokens]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
query_start_loc: np.ndarray, # [B + 1]
|
||||
seq_lens: np.ndarray, # [B]
|
||||
) -> None:
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
query_start_loc[0] = 0
|
||||
@ -161,14 +178,14 @@ def _prepare_inputs(
|
||||
# Pad the inputs for CUDA graphs.
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||||
# like FlashAttention requires that
|
||||
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
|
||||
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
|
||||
# Fill unused with 0 for full cuda graph mode.
|
||||
seq_lens[num_reqs:].fill(0)
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
idx_mapping: np.ndarray,
|
||||
prompt_token_ids: np.ndarray,
|
||||
prefill_token_ids: np.ndarray,
|
||||
num_computed_tokens: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
input_ids: CpuGpuBuffer,
|
||||
@ -176,10 +193,10 @@ def prepare_inputs(
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_tokens: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
) -> None:
|
||||
_prepare_inputs(
|
||||
idx_mapping,
|
||||
prompt_token_ids,
|
||||
prefill_token_ids,
|
||||
num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
input_ids.np,
|
||||
@ -194,11 +211,7 @@ def prepare_inputs(
|
||||
# for full CUDA graph mode.
|
||||
query_start_loc.copy_to_gpu()
|
||||
seq_lens.copy_to_gpu()
|
||||
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
max_query_len = int(num_scheduled_tokens.max())
|
||||
max_seq_len = int(seq_lens.np[:num_reqs].max())
|
||||
return max_query_len, max_seq_len
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -208,21 +221,18 @@ def _combine_last_token_ids_kernel(
|
||||
last_token_ids_ptr,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
num_tokens_ptr,
|
||||
prefill_len_ptr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
num_tokens = tl.load(num_tokens_ptr + req_state_idx)
|
||||
if seq_len < num_tokens:
|
||||
# Chunked prefilling.
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if seq_len <= prefill_len:
|
||||
# Handling prefill tokens.
|
||||
return
|
||||
|
||||
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
|
||||
if last_token_id == -1:
|
||||
return
|
||||
|
||||
end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
tl.store(input_ids_ptr + end - 1, last_token_id)
|
||||
|
||||
@ -233,15 +243,15 @@ def combine_last_token_ids(
|
||||
last_token_ids: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
num_tokens: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
_combine_last_token_ids_kernel[(num_reqs, )](
|
||||
_combine_last_token_ids_kernel[(num_reqs,)](
|
||||
input_ids,
|
||||
idx_mapping,
|
||||
last_token_ids,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
num_tokens,
|
||||
prefill_len,
|
||||
)
|
||||
return input_ids
|
||||
|
||||
@ -3,41 +3,52 @@
|
||||
import gc
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.sample.sampler import SamplerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
|
||||
init_attn_backend, init_kv_cache)
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
init_kv_cache,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
|
||||
evenly_split)
|
||||
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
|
||||
combine_last_token_ids,
|
||||
prepare_inputs)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
combine_last_token_ids,
|
||||
prepare_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner:
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@ -61,17 +72,19 @@ class GPUModelRunner:
|
||||
if self.cache_config.cache_dtype != "auto":
|
||||
# Quantized KV cache.
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
self.cache_config.cache_dtype
|
||||
]
|
||||
self.is_pooling_model = False
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.hidden_size = self.model_config.get_hidden_size()
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
assert self.use_async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream()
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
self.input_prep_event = torch.cuda.Event()
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
@ -84,29 +97,46 @@ class GPUModelRunner:
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate", )
|
||||
return ("generate",)
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
with DeviceMemoryProfiler() as m:
|
||||
model_loader = get_model_loader(self.vllm_config.load_config)
|
||||
logger.info("Loading model from scratch...")
|
||||
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.model_config,
|
||||
)
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(
|
||||
self.model,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
time_after_load = time.perf_counter()
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
||||
m.consumed_memory / GiB_bytes,
|
||||
time_after_load - time_before_load)
|
||||
logger.info(
|
||||
"Model loading took %.4f GiB and %.6f seconds",
|
||||
m.consumed_memory / GiB_bytes,
|
||||
time_after_load - time_before_load,
|
||||
)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
@ -143,32 +173,60 @@ class GPUModelRunner:
|
||||
self.device,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
input_batch: Optional[InputBatch] = None,
|
||||
input_batch: InputBatch | None = None,
|
||||
skip_attn: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if input_batch is None:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=min(num_tokens, self.max_num_reqs),
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
if not skip_attn:
|
||||
block_tables = self.block_tables.gather_block_tables(
|
||||
input_batch.idx_mapping
|
||||
)
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
input_batch.query_start_loc,
|
||||
input_batch.positions,
|
||||
)
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc=self.input_buffers.query_start_loc,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=None,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
input_batch.attn_metadata = attn_metadata
|
||||
|
||||
with set_forward_context(
|
||||
with self.maybe_dummy_run_with_lora(
|
||||
self.lora_config, input_batch.num_scheduled_tokens
|
||||
):
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_sampler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -179,35 +237,80 @@ class GPUModelRunner:
|
||||
device=self.device,
|
||||
)
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
self.sampler(logits, sampling_metadata)
|
||||
self.sampler.sample(logits, sampling_metadata)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=self.max_num_reqs,
|
||||
num_tokens=self.max_num_tokens,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
self.max_num_tokens,
|
||||
input_batch=input_batch,
|
||||
skip_attn=True,
|
||||
)
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
pass
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self) -> int:
|
||||
if not self.cudagraph_manager.needs_capture():
|
||||
logger.warning(
|
||||
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
|
||||
"ensure `cudagraph_mode` was not manually set to `NONE`"
|
||||
)
|
||||
return 0
|
||||
|
||||
start_time = time.perf_counter()
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||
self.cudagraph_manager.capture(
|
||||
model=self.model,
|
||||
input_buffers=self.input_buffers,
|
||||
block_tables=self.block_tables,
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
elapsed_time = end_time - start_time
|
||||
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
|
||||
# This usually takes 5~20 seconds.
|
||||
logger.info(
|
||||
"Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||
elapsed_time,
|
||||
cuda_graph_size / (1 << 30),
|
||||
)
|
||||
return cuda_graph_size
|
||||
|
||||
def warmup_for_prefill(self) -> None:
|
||||
# For FlashInfer, we would like to execute a dummy prefill run to trigger JIT compilation.
|
||||
if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
|
||||
self._dummy_run(self.max_num_tokens, skip_attn=False)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def update_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
# for req_id in scheduler_output.preempted_req_ids:
|
||||
# self.req_states.remove_request(req_id)
|
||||
for req_id in scheduler_output.preempted_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
|
||||
# TODO(woosuk): Change SchedulerOutput.
|
||||
req_indices: list[int] = []
|
||||
cu_num_new_blocks = tuple(
|
||||
[0] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
new_block_ids = tuple(
|
||||
[] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
[0] for _ in range(self.block_tables.num_kv_cache_groups)
|
||||
)
|
||||
new_block_ids = tuple([] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
overwrite: list[bool] = []
|
||||
|
||||
# Add new requests.
|
||||
@ -215,9 +318,11 @@ class GPUModelRunner:
|
||||
req_id = new_req_data.req_id
|
||||
self.req_states.add_request(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
prompt_len=len(new_req_data.prompt_token_ids),
|
||||
prefill_token_ids=new_req_data.prefill_token_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
sampling_params=new_req_data.sampling_params,
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
@ -250,21 +355,30 @@ class GPUModelRunner:
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
|
||||
def prepare_inputs(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
use_cudagraph: bool,
|
||||
padded_num_tokens: int | None,
|
||||
) -> InputBatch:
|
||||
num_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert num_tokens > 0
|
||||
num_reqs = len(scheduler_output.num_scheduled_tokens)
|
||||
|
||||
# Decode first, then prefill.
|
||||
# batch_idx -> req_id
|
||||
req_ids = sorted(scheduler_output.num_scheduled_tokens,
|
||||
key=scheduler_output.num_scheduled_tokens.get)
|
||||
req_ids = sorted(
|
||||
scheduler_output.num_scheduled_tokens,
|
||||
key=scheduler_output.num_scheduled_tokens.get,
|
||||
)
|
||||
num_scheduled_tokens = np.array(
|
||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
|
||||
dtype=np.int32)
|
||||
|
||||
# TODO(woosuk): Support CUDA graphs.
|
||||
num_tokens_after_padding = num_tokens
|
||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
|
||||
)
|
||||
if use_cudagraph:
|
||||
assert padded_num_tokens is not None
|
||||
num_tokens_after_padding = padded_num_tokens
|
||||
else:
|
||||
num_tokens_after_padding = num_tokens
|
||||
|
||||
idx_mapping_list = [
|
||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
||||
@ -277,9 +391,9 @@ class GPUModelRunner:
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
max_query_len, max_seq_len = prepare_inputs(
|
||||
prepare_inputs(
|
||||
idx_mapping_np,
|
||||
self.req_states.prompt_token_ids,
|
||||
self.req_states.prefill_token_ids,
|
||||
self.req_states.num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
self.input_buffers.input_ids,
|
||||
@ -290,10 +404,9 @@ class GPUModelRunner:
|
||||
)
|
||||
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
|
||||
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_np = query_start_loc.np[: num_reqs + 1]
|
||||
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
|
||||
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens.
|
||||
@ -303,56 +416,33 @@ class GPUModelRunner:
|
||||
self.req_states.last_sampled_tokens,
|
||||
query_start_loc_gpu,
|
||||
seq_lens_gpu,
|
||||
self.req_states.num_tokens.copy_to_gpu(),
|
||||
self.req_states.prefill_len.copy_to_gpu(),
|
||||
)
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
|
||||
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
|
||||
)
|
||||
|
||||
num_computed_tokens_cpu = torch.from_numpy(
|
||||
self.req_states.num_computed_tokens[idx_mapping_np])
|
||||
|
||||
# Whether the request is chunked-prefilling or not.
|
||||
is_chunked_prefilling = (
|
||||
seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
|
||||
self.req_states.num_computed_tokens[idx_mapping_np]
|
||||
)
|
||||
|
||||
# Logits indices to sample next token from.
|
||||
logits_indices = query_start_loc_gpu[1:] - 1
|
||||
num_logits_indices = logits_indices.size(0)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = self.kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens_gpu,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
logits_indices_padded=None,
|
||||
num_logits_indices=num_logits_indices,
|
||||
causal=True,
|
||||
encoder_seq_lens=None,
|
||||
)
|
||||
|
||||
attn_metadata_builder = self.attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc=self.input_buffers.query_start_loc,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
|
||||
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
|
||||
@ -364,7 +454,10 @@ class GPUModelRunner:
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens_after_padding,
|
||||
is_chunked_prefilling=is_chunked_prefilling,
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens_gpu,
|
||||
seq_lens_np=seq_lens_np,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
@ -375,102 +468,221 @@ class GPUModelRunner:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
num_reqs = logits.shape[0]
|
||||
|
||||
# When the batch size is large enough, use DP sampler.
|
||||
tp_group = get_tp_group()
|
||||
tp_size = tp_group.world_size
|
||||
n = (num_reqs + tp_size - 1) // tp_size
|
||||
use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune.
|
||||
if use_dp_sampler:
|
||||
# NOTE(woosuk): Make sure that no rank gets zero requests.
|
||||
tp_rank = tp_group.rank
|
||||
start, end = evenly_split(num_reqs, tp_size, tp_rank)
|
||||
logits = logits[start:end]
|
||||
pos = pos[start:end]
|
||||
idx_mapping_np = idx_mapping_np[start:end]
|
||||
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
idx_mapping_np, pos)
|
||||
sampler_output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
needs_prompt_logprobs = np.any(
|
||||
self.req_states.needs_prompt_logprobs[idx_mapping_np])
|
||||
assert not needs_prompt_logprobs
|
||||
|
||||
if use_dp_sampler:
|
||||
# All-gather the outputs.
|
||||
sampler_output = all_gather_sampler_output(
|
||||
sampler_output,
|
||||
num_reqs,
|
||||
tp_size,
|
||||
)
|
||||
sampler_output = self.sampler.sample(logits, sampling_metadata)
|
||||
return sampler_output
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
) -> dict[str, LogprobsTensors]:
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
|
||||
if not np.any(needs_prompt_logprobs):
|
||||
# No request asks for prompt logprobs.
|
||||
return {}
|
||||
|
||||
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
|
||||
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
|
||||
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
|
||||
# needed for prompt logprobs.
|
||||
includes_prompt = num_computed_tokens < prompt_lens - 1
|
||||
# NOTE(woosuk): If the request was resumed after preemption, its prompt
|
||||
# logprobs must have been computed before preemption. Skip.
|
||||
resumed_after_prompt = (
|
||||
prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
|
||||
)
|
||||
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
|
||||
if not np.any(needs_prompt_logprobs):
|
||||
return {}
|
||||
|
||||
# Just to be safe, clone the input ids.
|
||||
n = input_batch.num_tokens
|
||||
# Shift the input ids by one.
|
||||
token_ids = torch.empty_like(input_batch.input_ids[:n])
|
||||
token_ids[: n - 1] = input_batch.input_ids[1:n]
|
||||
# To avoid out-of-bound access, set the last token id to 0.
|
||||
token_ids[n - 1] = 0
|
||||
|
||||
# Handle chunked prompts.
|
||||
seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
|
||||
is_prompt_chunked = seq_lens < prompt_lens
|
||||
prefill_token_ids = self.req_states.prefill_token_ids
|
||||
query_start_loc = self.input_buffers.query_start_loc.np
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
if not needs_prompt_logprobs[i]:
|
||||
continue
|
||||
if not is_prompt_chunked[i]:
|
||||
continue
|
||||
# The prompt is chunked. Get the next prompt token.
|
||||
req_idx = input_batch.idx_mapping_np[i]
|
||||
next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
|
||||
idx = int(query_start_loc[i + 1] - 1)
|
||||
# Set the next prompt token.
|
||||
# NOTE(woosuk): This triggers a GPU operation.
|
||||
token_ids[idx] = next_prompt_token
|
||||
|
||||
# NOTE(woosuk): We mask out logprobs for negative tokens.
|
||||
prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
|
||||
torch.relu(token_ids),
|
||||
hidden_states[:n],
|
||||
self.model.compute_logits,
|
||||
)
|
||||
prompt_logprobs[:, 0].masked_fill_(token_ids < 0, 0)
|
||||
|
||||
prompt_token_ids = token_ids.unsqueeze(-1)
|
||||
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
if not needs_prompt_logprobs[i]:
|
||||
continue
|
||||
|
||||
start_idx = query_start_loc[i]
|
||||
end_idx = query_start_loc[i + 1]
|
||||
assert start_idx < end_idx, (
|
||||
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
|
||||
)
|
||||
logprobs = LogprobsTensors(
|
||||
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
|
||||
logprobs=prompt_logprobs[start_idx:end_idx],
|
||||
selected_token_ranks=prompt_ranks[start_idx:end_idx],
|
||||
)
|
||||
|
||||
req_extra_data = self.req_states.extra_data[req_id]
|
||||
prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
|
||||
if is_prompt_chunked[i]:
|
||||
# Prompt is chunked. Do not return the logprobs yet.
|
||||
prompt_logprobs_list.append(logprobs)
|
||||
continue
|
||||
|
||||
if prompt_logprobs_list:
|
||||
# Merge the in-progress logprobs.
|
||||
prompt_logprobs_list.append(logprobs)
|
||||
logprobs = LogprobsTensors(
|
||||
logprob_token_ids=torch.cat(
|
||||
[x.logprob_token_ids for x in prompt_logprobs_list]
|
||||
),
|
||||
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
|
||||
selected_token_ranks=torch.cat(
|
||||
[x.selected_token_ranks for x in prompt_logprobs_list]
|
||||
),
|
||||
)
|
||||
prompt_logprobs_list.clear()
|
||||
|
||||
prompt_logprobs_dict[req_id] = logprobs
|
||||
return prompt_logprobs_dict
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
sampler_output: SamplerOutput,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
prompt_logprobs_dict: dict[str, LogprobsTensors],
|
||||
input_batch: InputBatch,
|
||||
) -> AsyncOutput:
|
||||
) -> AsyncOutput | ModelRunnerOutput:
|
||||
# Store the last sampled token ids.
|
||||
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
|
||||
sampler_output.sampled_token_ids)
|
||||
|
||||
sampler_output.sampled_token_ids
|
||||
)
|
||||
# Get the number of sampled tokens.
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
is_chunked_prefilling = input_batch.is_chunked_prefilling
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
is_chunked_prefilling = (
|
||||
input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
|
||||
)
|
||||
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
|
||||
# Increment the number of tokens.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
|
||||
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
|
||||
# Increment the number of computed tokens.
|
||||
self.req_states.num_computed_tokens[idx_mapping_np] += (
|
||||
input_batch.num_scheduled_tokens)
|
||||
input_batch.num_scheduled_tokens
|
||||
)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
|
||||
sampled_token_ids=None,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
kv_connector_output=None,
|
||||
num_nans_in_logits=None,
|
||||
)
|
||||
return AsyncOutput(
|
||||
async_output = AsyncOutput(
|
||||
model_runner_output=model_runner_output,
|
||||
sampler_output=sampler_output,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
copy_stream=self.output_copy_stream,
|
||||
)
|
||||
if self.use_async_scheduling:
|
||||
return async_output
|
||||
return async_output.get_output()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> AsyncOutput:
|
||||
self.update_states(scheduler_output)
|
||||
if scheduler_output.total_num_scheduled_tokens == 0:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
intermediate_tensors: Any | None = None,
|
||||
) -> AsyncOutput | ModelRunnerOutput:
|
||||
assert intermediate_tensors is None
|
||||
|
||||
input_batch = self.prepare_inputs(scheduler_output)
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
with async_barrier(
|
||||
self.input_prep_event if self.use_async_scheduling else None
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
self.update_states(scheduler_output)
|
||||
if scheduler_output.total_num_scheduled_tokens == 0:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
padded_num_tokens = self.cudagraph_manager.get_cudagraph_size(
|
||||
scheduler_output
|
||||
)
|
||||
use_cudagraph = padded_num_tokens is not None
|
||||
input_batch = self.prepare_inputs(
|
||||
scheduler_output,
|
||||
use_cudagraph,
|
||||
padded_num_tokens,
|
||||
)
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
idx_mapping_np, pos
|
||||
)
|
||||
|
||||
sampler_output = self.sample(hidden_states, input_batch)
|
||||
return self.postprocess(sampler_output, input_batch)
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
lora_inputs = self.req_states.make_lora_inputs(
|
||||
input_batch.req_ids,
|
||||
input_batch.idx_mapping_np,
|
||||
input_batch.num_scheduled_tokens,
|
||||
)
|
||||
self._set_active_loras(*lora_inputs)
|
||||
|
||||
# Run model.
|
||||
if use_cudagraph:
|
||||
# Run CUDA graph.
|
||||
# NOTE(woosuk): Here, we don't need to pass the input tensors,
|
||||
# because they are already copied to the CUDA graph input buffers.
|
||||
hidden_states = self.cudagraph_manager.run(padded_num_tokens)
|
||||
else:
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
):
|
||||
# Run PyTorch model in eager mode.
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
|
||||
sampler_output = self.sample(hidden_states, input_batch, sampling_metadata)
|
||||
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
|
||||
output = self.postprocess(
|
||||
sampler_output,
|
||||
sampling_metadata,
|
||||
prompt_logprobs_dict,
|
||||
input_batch,
|
||||
)
|
||||
return output
|
||||
|
||||
@ -1,61 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.config import LogprobsMode
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
class Sampler:
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = "processed_logprobs",
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
):
|
||||
super().__init__()
|
||||
assert logprobs_mode == "processed_logprobs"
|
||||
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def forward(
|
||||
def sample_token(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# Divide logits by temperature, in FP32.
|
||||
logits = apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
return_logits: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
is_greedy = sampling_metadata.temperature == 0
|
||||
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||
logits = logits / temp.view(-1, 1)
|
||||
logits = apply_top_k_top_p(
|
||||
logits,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
||||
)
|
||||
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
# Sample the next token (int64).
|
||||
|
||||
sampled = gumbel_sample(
|
||||
probs,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
)
|
||||
sampled = sampled.to(torch.int64)
|
||||
return sampled, logits if return_logits else None
|
||||
|
||||
logprobs_tensors = None
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
logprobs_tensors = compute_logprobs(
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
sampled, logits = self.sample_token(
|
||||
logits, sampling_metadata, return_logits=True
|
||||
)
|
||||
else:
|
||||
assert self.logprobs_mode == "raw_logprobs"
|
||||
sampled, _ = self.sample_token(
|
||||
logits, sampling_metadata, return_logits=False
|
||||
)
|
||||
|
||||
logprobs_tensors = compute_topk_logprobs(
|
||||
logits,
|
||||
num_logprobs,
|
||||
sampling_metadata.max_num_logprobs,
|
||||
sampled,
|
||||
)
|
||||
else:
|
||||
sampled, _ = self.sample_token(
|
||||
logits, sampling_metadata, return_logits=False
|
||||
)
|
||||
logprobs_tensors = None
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
@ -69,60 +84,7 @@ class Sampler(nn.Module):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _apply_temp_kernel(
|
||||
logits, # bf16[batch_size, vocab_size]
|
||||
logits_stride,
|
||||
output, # fp32[batch_size, vocab_size]
|
||||
output_stride,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
EPSILON: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
|
||||
temp = tl.load(temperature + batch_idx)
|
||||
if temp < EPSILON:
|
||||
# Greedy sampling. Don't apply temperature.
|
||||
# NOTE(woosuk): In this case, we assume that its logprobs are not used.
|
||||
temp = 1.0
|
||||
|
||||
offset = tl.arange(0, BLOCK_SIZE)
|
||||
block = block_idx * BLOCK_SIZE + offset
|
||||
|
||||
# Load the logits.
|
||||
x = tl.load(logits + batch_idx * logits_stride + block,
|
||||
mask=block < vocab_size)
|
||||
x = x.to(tl.float32)
|
||||
x = x / temp
|
||||
tl.store(output + batch_idx * output_stride + block,
|
||||
x,
|
||||
mask=block < vocab_size)
|
||||
|
||||
|
||||
def apply_temperature(
|
||||
logits: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, vocab_size = logits.shape
|
||||
output = torch.empty_like(logits, dtype=torch.float32)
|
||||
BLOCK_SIZE = 8192
|
||||
_apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
output,
|
||||
output.stride(0),
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
EPSILON=_SAMPLING_EPS,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _apply_gumbel_kernel(
|
||||
def _gumbel_sample_kernel(
|
||||
probs_ptr,
|
||||
probs_stride,
|
||||
seeds_ptr,
|
||||
@ -130,18 +92,17 @@ def _apply_gumbel_kernel(
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
EPSILON: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
|
||||
if temp < EPSILON:
|
||||
if temp == 0.0:
|
||||
# Greedy sampling. Don't apply gumbel noise.
|
||||
return
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx).to(tl.uint64)
|
||||
pos = tl.load(pos_ptr + req_idx).to(tl.uint64)
|
||||
gumbel_seed = seed ^ (pos * 0x9E3779B97F4A7C15)
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
pos = tl.load(pos_ptr + req_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
block_id = tl.program_id(1)
|
||||
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
@ -153,42 +114,33 @@ def _apply_gumbel_kernel(
|
||||
q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
|
||||
q = -1.0 * q
|
||||
|
||||
p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
|
||||
mask=r_offset < vocab_size)
|
||||
p = tl.load(
|
||||
probs_ptr + req_idx * probs_stride + r_offset, mask=r_offset < vocab_size
|
||||
)
|
||||
p = p / q
|
||||
|
||||
tl.store(probs_ptr + req_idx * probs_stride + r_offset,
|
||||
p,
|
||||
mask=r_offset < vocab_size)
|
||||
tl.store(
|
||||
probs_ptr + req_idx * probs_stride + r_offset, p, mask=r_offset < vocab_size
|
||||
)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
# fp32[num_reqs, vocab_size]
|
||||
probs: torch.Tensor,
|
||||
# fp32[num_reqs]
|
||||
temperature: torch.Tensor,
|
||||
# int64[num_reqs]
|
||||
seeds: torch.Tensor,
|
||||
# int64[num_reqs]
|
||||
pos: torch.Tensor,
|
||||
probs: torch.Tensor, # [num_reqs, vocab_size]
|
||||
temperature: torch.Tensor, # [num_reqs]
|
||||
seed: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
) -> torch.Tensor:
|
||||
num_reqs = probs.shape[0]
|
||||
vocab_size = probs.shape[1]
|
||||
|
||||
# Update the probs in-place.
|
||||
BLOCK_SIZE = 8192
|
||||
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
||||
num_reqs, vocab_size = probs.shape
|
||||
_gumbel_sample_kernel[(num_reqs,)](
|
||||
probs,
|
||||
probs.stride(0),
|
||||
seeds,
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE,
|
||||
EPSILON=_SAMPLING_EPS,
|
||||
BLOCK_SIZE=8192, # type: ignore
|
||||
)
|
||||
# Sample the next token.
|
||||
return probs.argmax(dim=-1).view(-1)
|
||||
sampled = probs.argmax(dim=-1)
|
||||
return sampled
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -208,54 +160,31 @@ def _topk_log_softmax_kernel(
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block,
|
||||
mask=block < vocab_size,
|
||||
other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(l, max_val))
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32)
|
||||
|
||||
se = 0.0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
e = tl.exp(l - max_val)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
# NOTE(woosuk): Make sure that logits and all following operations are in float32.
|
||||
logits = logits.to(tl.float32)
|
||||
e = tl.exp(logits - max_val)
|
||||
e = tl.where(block < vocab_size, e, 0.0)
|
||||
se += tl.sum(e)
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
k_mask = k_offset < topk
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
|
||||
|
||||
l = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
o = l - max_val - lse
|
||||
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
logits = logits.to(tl.float32)
|
||||
o = logits - max_val - lse
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
def compute_topk_logprobs(
|
||||
logits: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, vocab_size = logits.shape
|
||||
topk = topk_ids.shape[1]
|
||||
output = torch.empty(
|
||||
batch_size,
|
||||
topk,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_topk_log_softmax_kernel[(batch_size, )](
|
||||
output,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
topk_ids,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=1024,
|
||||
PADDED_TOPK=triton.next_power_of_2(topk),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _ranks_kernel(
|
||||
output_ptr,
|
||||
@ -274,14 +203,39 @@ def _ranks_kernel(
|
||||
n = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block,
|
||||
mask=block < vocab_size,
|
||||
other=float("-inf"))
|
||||
n += tl.sum((l > x).to(tl.int32))
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
n += tl.sum((logits > x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
def compute_logprobs(
|
||||
def compute_token_logprobs(
|
||||
logits: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size = logits.shape[0]
|
||||
vocab_size = logits.shape[1]
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
num_logprobs = token_ids.shape[1]
|
||||
logprobs = torch.empty(
|
||||
batch_size,
|
||||
num_logprobs,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_topk_log_softmax_kernel[(batch_size,)](
|
||||
logprobs,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
token_ids,
|
||||
num_logprobs,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
def compute_topk_logprobs(
|
||||
logits: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
@ -293,31 +247,56 @@ def compute_logprobs(
|
||||
else:
|
||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||
logprob_token_ids = torch.cat(
|
||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
|
||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
logprobs = compute_topk_logprobs(
|
||||
logits,
|
||||
logprob_token_ids,
|
||||
)
|
||||
|
||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||
token_ranks = torch.empty(
|
||||
batch_size,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
_ranks_kernel[(batch_size, )](
|
||||
_ranks_kernel[(batch_size,)](
|
||||
token_ranks,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampled_token_ids,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=8192,
|
||||
BLOCK_SIZE=8192, # type: ignore
|
||||
)
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=token_ranks,
|
||||
)
|
||||
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
prompt_token_ids: torch.Tensor,
|
||||
prompt_hidden_states: torch.Tensor,
|
||||
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Since materializing the full prompt logits can take too much memory,
|
||||
# we compute it in chunks.
|
||||
CHUNK_SIZE = 1024
|
||||
logprobs = []
|
||||
ranks = []
|
||||
prompt_token_ids = prompt_token_ids.to(torch.int64)
|
||||
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
|
||||
end_idx = start_idx + CHUNK_SIZE
|
||||
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
|
||||
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
|
||||
prompt_logprobs = compute_topk_logprobs(
|
||||
prompt_logits,
|
||||
0, # num_logprobs
|
||||
prompt_token_ids[start_idx:end_idx],
|
||||
)
|
||||
logprobs.append(prompt_logprobs.logprobs)
|
||||
ranks.append(prompt_logprobs.selected_token_ranks)
|
||||
|
||||
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
|
||||
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
|
||||
return logprobs, ranks
|
||||
|
||||
@ -1,21 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||
NO_LORA_ID = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: torch.Tensor
|
||||
|
||||
top_p: torch.Tensor | None
|
||||
@ -36,12 +37,14 @@ class SamplingMetadata:
|
||||
assert num_reqs > 0
|
||||
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
temperature[0] = 0.5
|
||||
top_p = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
||||
top_p[0] = 0.99
|
||||
top_k = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||
# TODO(woosuk): Use top-p and top-k for dummy sampler.
|
||||
# Currently, they are disabled because of memory usage.
|
||||
top_p = None
|
||||
top_k = None
|
||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
max_num_logprobs = 20
|
||||
|
||||
return cls(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@ -53,7 +56,6 @@ class SamplingMetadata:
|
||||
|
||||
|
||||
class RequestState:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
@ -73,15 +75,15 @@ class RequestState:
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_reqs))
|
||||
self.extra_data: dict[str, ExtraData] = {}
|
||||
|
||||
# NOTE(woosuk): Strictly speaking, it contains prompt + some output
|
||||
# because of preemption.
|
||||
self.prompt_token_ids = np.zeros(
|
||||
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.prefill_token_ids = np.zeros(
|
||||
(self.max_num_reqs, self.max_model_len),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.num_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
|
||||
self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
|
||||
# Last sampled tokens.
|
||||
@ -92,6 +94,10 @@ class RequestState:
|
||||
device=device,
|
||||
)
|
||||
|
||||
# LoRA.
|
||||
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.lora_ids.fill(NO_LORA_ID)
|
||||
|
||||
# Sampling parameters.
|
||||
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
|
||||
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
|
||||
@ -104,16 +110,12 @@ class RequestState:
|
||||
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
||||
|
||||
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
|
||||
return Param(size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
|
||||
|
||||
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
return CpuGpuBuffer(
|
||||
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
|
||||
)
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
@ -122,23 +124,32 @@ class RequestState:
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
prompt_len: int,
|
||||
prefill_token_ids: list[int],
|
||||
num_computed_tokens: int,
|
||||
sampling_params: SamplingParams,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
assert len(self.free_indices) > 0
|
||||
assert len(self.free_indices) > 0, "No free indices"
|
||||
req_idx = self.free_indices.pop()
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
self.extra_data[req_id] = ExtraData(lora_request)
|
||||
|
||||
# NOTE(woosuk): Strictly speaking, "prompt_len" here may include
|
||||
# output tokens, if the request is resumed from preemption.
|
||||
prompt_len = len(prompt_token_ids)
|
||||
self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids
|
||||
self.num_tokens.np[req_idx] = prompt_len
|
||||
self.prompt_len[req_idx] = prompt_len
|
||||
prefill_len = len(prefill_token_ids)
|
||||
assert prefill_len >= prompt_len, (
|
||||
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
|
||||
)
|
||||
self.prefill_len.np[req_idx] = prefill_len
|
||||
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
|
||||
self.num_tokens[req_idx] = prefill_len
|
||||
self.num_computed_tokens[req_idx] = num_computed_tokens
|
||||
# TODO(woosuk): Optimize.
|
||||
self.last_sampled_tokens[req_idx].fill_(-1)
|
||||
|
||||
if lora_request is not None:
|
||||
self.lora_ids[req_idx] = lora_request.lora_int_id
|
||||
else:
|
||||
self.lora_ids[req_idx] = NO_LORA_ID
|
||||
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
self.top_p.np[req_idx] = sampling_params.top_p
|
||||
@ -165,6 +176,7 @@ class RequestState:
|
||||
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.extra_data.pop(req_id, None)
|
||||
req_idx = self.req_id_to_index.pop(req_id, None)
|
||||
if req_idx is None:
|
||||
# Request not found.
|
||||
@ -205,9 +217,25 @@ class RequestState:
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
)
|
||||
|
||||
def make_lora_inputs(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
idx_mapping: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
lora_ids = self.lora_ids[idx_mapping]
|
||||
prompt_lora_mapping = tuple(lora_ids)
|
||||
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
|
||||
|
||||
active_lora_requests: set[LoRARequest] = set()
|
||||
for req_id in req_ids:
|
||||
lora_request = self.extra_data[req_id].lora_request
|
||||
if lora_request is not None:
|
||||
active_lora_requests.add(lora_request)
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
|
||||
class Param:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@ -227,3 +255,9 @@ class Param:
|
||||
n = x.shape[0]
|
||||
self.buffer.np[:n] = x
|
||||
return self.buffer.copy_to_gpu(n)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraData:
|
||||
lora_request: LoRARequest | None
|
||||
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
|
||||
|
||||
@ -42,6 +42,7 @@ from vllm.v1.outputs import (
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
|
||||
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
@ -495,6 +496,8 @@ class Worker(WorkerBase):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
return self.model_runner.execute_model(scheduler_output)
|
||||
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
|
||||
Reference in New Issue
Block a user