Compare commits

...

115 Commits

Author SHA1 Message Date
1c5c866559 uint64 2025-10-30 16:54:10 -07:00
5c8049d990 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-10-30 16:40:09 -07:00
5666a25efb fix 2025-10-30 16:38:16 -07:00
09e4b2f6eb update 2025-10-30 16:30:06 -07:00
110770170f Merge branch 'main' into woosuk/model-runner-v2 2025-10-30 22:19:50 +00:00
866eef50ca minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-24 15:29:27 +00:00
ad2cf805ad Merge branch 'main' into woosuk/model-runner-v2 2025-09-24 08:19:25 -07:00
704def253c Merge branch 'main' into woosuk/model-runner-v2 2025-09-23 21:08:15 +00:00
42f99150c1 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-23 09:23:21 -07:00
17c2c106b1 Merge branch 'main' into woosuk/model-runner-v2 2025-09-23 09:22:58 -07:00
72f0a71939 assert
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-21 19:37:18 -07:00
fe5472dc03 Merge branch 'main' into woosuk/model-runner-v2 2025-09-21 18:56:48 -07:00
bc73f674bb compute_logits
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-21 11:26:33 -07:00
631b5b47c1 Merge branch 'main' into woosuk/model-runner-v2 2025-09-21 11:25:18 -07:00
42ffdd9179 wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-20 22:15:07 +00:00
8aee6e97e6 64-bit for gumbel seed
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-20 11:43:01 +00:00
913b8e9569 Merge branch 'main' into woosuk/model-runner-v2 2025-09-20 11:18:35 +00:00
158a46888e random uuid
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-20 11:17:45 +00:00
98ef239486 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 23:55:46 +00:00
a66aa37f40 minor:
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 23:47:20 +00:00
6f038fc4fb Merge branch 'main' into woosuk/model-runner-v2 2025-09-19 20:30:04 +00:00
010e39ec7d minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 19:07:46 +00:00
396bbe67d3 Merge branch 'main' into woosuk/model-runner-v2 2025-09-19 18:53:18 +00:00
c7f3e84b34 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-19 09:49:40 -07:00
a8e7071924 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-19 08:33:47 -07:00
4be2c66e37 fix
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 09:35:38 +00:00
d30c0d50a6 refactor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 07:17:53 +00:00
9c75d896a8 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 07:11:37 +00:00
37478c18cf async output
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 07:10:42 +00:00
33672774f5 Merge branch 'main' into woosuk/model-runner-v2 2025-09-19 06:52:46 +00:00
0d3de9e082 fix
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 06:50:56 +00:00
b405d78c07 DP sampler
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 06:46:46 +00:00
8af87986aa fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 18:37:30 -07:00
af65838d1f dummy run
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 18:29:18 -07:00
52ca2f517a sample
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 17:39:43 -07:00
8deedfa42b -inf
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 17:24:00 -07:00
b9c74487d2 logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 17:23:02 -07:00
31619ff412 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:38:56 -07:00
d2be62378b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:33:18 -07:00
86dade710d fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:32:00 -07:00
efda08481b minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:31:01 -07:00
82da219ff9 Implement topk_logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:29:38 -07:00
323a05b3c5 update
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 15:51:36 -07:00
a98eff0762 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 15:21:30 -07:00
67d8c0c21b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 15:15:31 -07:00
2bb2cb13f4 revert
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:54:19 -07:00
e171e5bb67 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:53:32 -07:00
8407fa02ed fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:52:23 -07:00
82e591f7eb remove
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:35:25 -07:00
330058f9b8 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:30:29 -07:00
aabfaa08cf fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:14:03 -07:00
bc6463ac97 hash
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:49:52 -07:00
a4962833f9 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:20:37 -07:00
3f50030cc8 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:11:46 -07:00
cbdb47dc01 working
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:10:35 -07:00
92f337faeb minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 12:44:21 -07:00
9050087250 update
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 12:37:29 -07:00
c1d83f2bae merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 12:13:56 -07:00
91510260b2 task
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-16 01:06:10 -07:00
c320a33c59 skip warmup
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-16 07:21:25 +00:00
83d11373a4 wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-16 07:21:25 +00:00
dfc84b11a9 wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-16 07:21:25 +00:00
9f2becd3e6 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-16 00:16:42 -07:00
e107680d8a wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 21:19:18 +00:00
f1981db101 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 19:53:58 +00:00
69b17891a3 chunked prefilling
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 19:41:17 +00:00
67852c1036 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 19:23:54 +00:00
8b3c13c485 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-15 11:17:54 -07:00
9a6fcca030 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-14 15:56:42 -07:00
633f9f006d Merge branch 'main' into woosuk/input-prep 2025-09-14 08:03:28 -07:00
eb3742c72a fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:19:40 -07:00
e47bb9970b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:19:07 -07:00
5c133fc860 reorder
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:17:40 -07:00
caf963f2e9 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:13:08 -07:00
9314a83b56 Merge branch 'main' into woosuk/input-prep 2025-09-14 00:44:56 +00:00
7a50a54390 Merge branch 'main' into woosuk/input-prep 2025-09-13 21:33:54 +00:00
787e59629c wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-08 16:42:26 -07:00
5f95309a6d rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-07 12:01:45 -07:00
286eeb91e8 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-07 11:16:37 -07:00
6283995a6c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 21:18:16 -07:00
0c56069c7e merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 16:35:45 -07:00
8e6cb9aa4a minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 12:23:02 -07:00
ead95fe5dc merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 10:56:27 -07:00
23eae07ea5 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-04 20:19:22 -07:00
b16e2d9602 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 02:10:48 -07:00
4c2a337e67 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:45:29 -07:00
cc340e26af top_p top_k
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:30:08 -07:00
01bf16ede4 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:16:26 -07:00
af7b6c5dd4 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 23:50:20 -07:00
62d23b3006 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 21:00:16 -07:00
ba1a58f51b MAX_SPEC_LEN
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 20:43:25 -07:00
22771e5d83 work
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 20:41:38 -07:00
c11d1e6781 optimize spec
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 16:40:54 -07:00
e696f78e05 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 13:29:58 -07:00
efcb786d52 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 10:44:36 -07:00
9ee9d0e274 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 15:02:07 -07:00
405578121c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 13:19:10 -07:00
19c0dfc469 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 13:08:07 -07:00
e451045a66 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 12:55:13 -07:00
efba25e21a minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 12:39:15 -07:00
b21393cd98 Merge branch 'main' into woosuk/input-prep 2025-08-28 09:58:08 -07:00
d6d719fb24 Merge branch 'main' into woosuk/input-prep 2025-08-28 09:57:49 -07:00
e570b0a4de merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-27 21:45:11 -07:00
a851aaa0fc simplify
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 09:23:05 -07:00
b1d52734f7 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 08:55:12 -07:00
65f93694be merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 08:54:32 -07:00
7b4b72e551 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-24 18:49:23 -07:00
da9cd26c78 Merge branch 'main' into woosuk/input-prep 2025-08-24 18:36:33 -07:00
a1e3745150 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-24 18:36:18 -07:00
48bca9a109 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-23 11:30:29 -07:00
64c8cced18 rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-22 01:48:35 -07:00
79e5eb3643 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-22 01:37:43 -07:00
c472982746 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-21 21:40:44 -07:00
699bd7928e Merge branch 'main' into woosuk/input-prep 2025-08-17 19:28:38 -07:00
33a3a26ca5 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-17 14:38:24 -07:00
18 changed files with 2360 additions and 12 deletions

View File

@ -16,7 +16,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m", compilation_config={"level": 0, "cudagraph_mode": "full_decode_only"})
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

View File

@ -10,6 +10,7 @@ torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.4.1
apache-tvm-ffi==0.1.0b15

View File

@ -46,6 +46,8 @@ def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
def kernel_warmup(worker: "Worker"):
return
# Deep GEMM warmup
do_deep_gemm_warmup = (
envs.VLLM_USE_DEEP_GEMM

View File

@ -412,6 +412,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._workspace_buffer
def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
self._workspace_buffer = workspace_buffer
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(

View File

@ -34,6 +34,7 @@ else:
class NewRequestData:
req_id: str
prompt_token_ids: list[int] | None
prefill_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
@ -51,6 +52,7 @@ class NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
prefill_token_ids=request._all_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
@ -173,6 +175,7 @@ class SchedulerOutput:
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
preempted_req_ids: set[str]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.

View File

@ -606,6 +606,9 @@ class Scheduler(SchedulerInterface):
)
# Construct the scheduler output.
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
scheduled_resumed_reqs = []
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids()
@ -635,6 +638,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
@ -720,14 +724,6 @@ class Scheduler(SchedulerInterface):
req.num_computed_tokens : req.num_computed_tokens + num_tokens
]
new_token_ids.append(token_ids)
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
if idx >= num_running_reqs:
assert not scheduled_in_prev_step
resumed_req_ids.add(req_id)
if not scheduled_in_prev_step:
all_token_ids[req_id] = req.all_token_ids[
: req.num_computed_tokens + num_tokens
]
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
@ -933,7 +929,8 @@ class Scheduler(SchedulerInterface):
# to avoid expensive operations inside the loop.
stopped_running_reqs: set[Request] = set()
stopped_preempted_reqs: set[Request] = set()
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
for req_index, req_id in enumerate(model_runner_output.req_ids):
num_tokens_scheduled = num_scheduled_tokens[req_id]
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure

View File

@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass, fields
from functools import cached_property
from math import prod
import torch
@ -395,3 +396,10 @@ class KVCacheConfig:
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
kv_cache_groups: list[KVCacheGroupSpec]
@cached_property
def block_sizes(self) -> list[int]:
return [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in self.kv_cache_groups
]

View File

View File

@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
ModelRunnerOutput,
SamplerOutput,
)
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: np.ndarray,
copy_stream: torch.cuda.Stream,
):
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_stream = copy_stream
self.copy_event = torch.cuda.Event()
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream)
# NOTE(woosuk): We should keep the CPU tensors unfreed, until the copy completes.
self.sampled_token_ids = sampler_output.sampled_token_ids.to(
"cpu", non_blocking=True
)
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
else:
self.logprobs_tensors = None
self.prompt_logprobs_dict = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
self.copy_event.record(self.copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# NOTE(woosuk): The following code ensures compatibility with OSS vLLM.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids_np = self.sampled_token_ids.numpy()
sampled_token_ids = sampled_token_ids_np.tolist()
for i, tokens in enumerate(sampled_token_ids):
del tokens[self.num_sampled_tokens[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
@contextmanager
def async_barrier(event: torch.cuda.Event | None):
if event is not None:
event.synchronize()
try:
yield
finally:
if event is not None:
event.record()

View File

@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
SlidingWindowSpec,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.utils import bind_kv_cache
def get_kv_cache_spec(
vllm_config: VllmConfig,
kv_cache_dtype: torch.dtype,
) -> dict[str, KVCacheSpec]:
block_size = vllm_config.cache_config.block_size
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
assert attn_module.attn_type == AttentionType.DECODER
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
)
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
):
attn_backends: dict[str, AttentionBackend] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
flashinfer_workspace: torch.Tensor | None = None
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names
any_layer_name = next(iter(layer_names))
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
)
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
if "FLASHINFER" in attn_backend.get_name():
if flashinfer_workspace is None:
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
else:
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
kv_cache_config: KVCacheConfig,
device: torch.device,
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
dtype = kv_cache_spec.dtype
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
) -> None:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_computed_tokens_cpu: torch.Tensor,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
seq_lens_cpu = seq_lens.cpu[:num_reqs]
max_seq_len = int(seq_lens.np[:num_reqs].max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata

View File

@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import triton
import triton.language as tl
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
PAD_SLOT_ID = -1
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = torch.zeros(
self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.uint64,
device="cpu",
pin_memory=self.pin_memory,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: list[list[int]],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: list[list[int]],
# [num_reqs]
overwrite: list[bool],
) -> None:
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.input_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
if do_overwrite:
dst_start_idx = 0
else:
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
for i in range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
)
tl.store(
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)

View File

@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
self.padded_sizes = self._init_padded_sizes()
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def _init_padded_sizes(self) -> dict[int, int]:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
# CUDA graphs are disabled.
return {}
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
raise NotImplementedError("Piecewise CUDA graphs are not supported")
if self.compilation_config.level != 0:
raise NotImplementedError("Dynamo is not used. Compilation level must be 0")
padded_sizes: dict[int, int] = {}
assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
padded_sizes[i] = x
break
return padded_sizes
def needs_capture(self) -> bool:
return len(self.padded_sizes) > 0
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
if max(scheduler_output.num_scheduled_tokens.values()) > 1:
# Prefill is included.
return None
return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
def capture_graph(
self,
batch_size: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert batch_size not in self.graphs
# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=batch_size,
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:batch_size] = hidden_states
self.graphs[batch_size] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert self.needs_capture()
# Capture larger graphs first.
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with freeze_gc(), graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)
def run(self, batch_size: int) -> torch.Tensor:
assert batch_size in self.graphs
self.graphs[batch_size].replay()
return self.hidden_states[:batch_size]
@contextmanager
def freeze_gc():
gc.collect()
gc.freeze()
try:
yield
finally:
gc.unfreeze()

View File

@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.v1.outputs import SamplerOutput
def evenly_split(
n: int,
tp_size: int,
tp_rank: int,
) -> tuple[int, int]:
q = n // tp_size
r = n % tp_size
start = q * tp_rank + min(tp_rank, r)
end = start + q + (1 if tp_rank < r else 0)
return start, end
def pad_and_all_gather(
x: torch.Tensor,
padded_size: int,
) -> torch.Tensor:
n = x.shape[0]
if n != padded_size:
padded_x = torch.empty(
(padded_size, *x.shape[1:]),
dtype=x.dtype,
device=x.device,
)
padded_x[:n] = x
else:
padded_x = x
x = tensor_model_parallel_all_gather(padded_x)
return x
def all_gather_sampler_output(
sampler_output: SamplerOutput,
num_reqs: int,
tp_size: int,
) -> SamplerOutput:
n = (num_reqs + tp_size - 1) // tp_size
sampler_output.sampled_token_ids = pad_and_all_gather(
sampler_output.sampled_token_ids, n)[:num_reqs]
# TODO(woosuk): 3 small all-gathers, could be merged into one.
logprobs_tensors = sampler_output.logprobs_tensors
if logprobs_tensors is not None:
logprobs_tensors.logprob_token_ids = pad_and_all_gather(
logprobs_tensors.logprob_token_ids, n)[:num_reqs]
logprobs_tensors.logprobs = pad_and_all_gather(
logprobs_tensors.logprobs, n)[:num_reqs]
logprobs_tensors.selected_token_ranks = pad_and_all_gather(
logprobs_tensors.selected_token_ranks, n)[:num_reqs]
return sampler_output

View File

@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numba
import numba.types as types
import numpy as np
import torch
import triton
import triton.language as tl
from vllm.utils import random_uuid
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
seq_lens_np: np.ndarray
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# [num_reqs]
logits_indices: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
input_buffers.seq_lens.np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
positions = input_buffers.positions.copy_to_gpu(num_tokens)
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=None,
logits_indices=logits_indices,
)
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int64[:], # positions
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
)
],
nopython=True,
cache=True,
)
def _prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
def prepare_inputs(
idx_mapping: np.ndarray,
prefill_token_ids: np.ndarray,
num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
input_ids: CpuGpuBuffer,
positions: CpuGpuBuffer,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_tokens: int,
) -> None:
_prepare_inputs(
idx_mapping,
prefill_token_ids,
num_computed_tokens,
num_scheduled_tokens,
input_ids.np,
positions.np,
query_start_loc.np,
seq_lens.np,
)
input_ids.copy_to_gpu(num_tokens)
positions.copy_to_gpu(num_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu()
return
@triton.jit
def _combine_last_token_ids_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_token_ids_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens.
return
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)
def combine_last_token_ids(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_token_ids: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_token_ids,
query_start_loc,
seq_lens,
prefill_len,
)
return input_ids

View File

@ -0,0 +1,688 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from copy import deepcopy
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
get_kv_cache_spec,
init_attn_backend,
init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_last_token_ids,
prepare_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = self.dtype
if self.cache_config.cache_dtype != "auto":
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype
]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.hidden_size = self.model_config.get_hidden_size()
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
self.input_prep_event = torch.cuda.Event()
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
vllm_config=self.vllm_config,
device=self.device,
)
def get_supported_tasks(self) -> tuple[str]:
return ("generate",)
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
)
if self.lora_config:
self.model = self.load_lora_model(
self.model,
self.vllm_config,
self.device,
)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info(
"Model loading took %.4f GiB and %.6f seconds",
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load,
)
def get_model(self) -> nn.Module:
return self.model
def get_kv_cache_spec(self):
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
block_sizes = kv_cache_config.block_sizes
self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
)
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config,
self.vllm_config,
self.device,
)
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
self.device,
)
@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
*args,
input_batch: InputBatch | None = None,
skip_attn: bool = True,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if input_batch is None:
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
if not skip_attn:
block_tables = self.block_tables.gather_block_tables(
input_batch.idx_mapping
)
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.query_start_loc,
input_batch.positions,
)
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=None,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_batch.attn_metadata = attn_metadata
with self.maybe_dummy_run_with_lora(
self.lora_config, input_batch.num_scheduled_tokens
):
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@torch.inference_mode()
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
) -> None:
num_reqs = hidden_states.shape[0]
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
logits = self.model.compute_logits(hidden_states)
self.sampler.sample(logits, sampling_metadata)
@torch.inference_mode()
def profile_run(self) -> None:
input_batch = InputBatch.make_dummy(
num_reqs=self.max_num_reqs,
num_tokens=self.max_num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens,
input_batch=input_batch,
skip_attn=True,
)
self._dummy_sampler_run(sample_hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def reset_mm_cache(self) -> None:
pass
@torch.inference_mode()
def capture_model(self) -> int:
if not self.cudagraph_manager.needs_capture():
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture(
model=self.model,
input_buffers=self.input_buffers,
block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info(
"Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time,
cuda_graph_size / (1 << 30),
)
return cuda_graph_size
def warmup_for_prefill(self) -> None:
# For FlashInfer, we would like to execute a dummy prefill run to trigger JIT compilation.
if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
self._dummy_run(self.max_num_tokens, skip_attn=False)
torch.cuda.synchronize()
def update_states(self, scheduler_output: SchedulerOutput) -> None:
for req_id in scheduler_output.preempted_req_ids:
self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = []
cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups)
)
new_block_ids = tuple([] for _ in range(self.block_tables.num_kv_cache_groups))
overwrite: list[bool] = []
# Add new requests.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
self.req_states.add_request(
req_id=req_id,
prompt_len=len(new_req_data.prompt_token_ids),
prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params,
lora_request=new_req_data.lora_request,
)
req_index = self.req_states.req_id_to_index[req_id]
req_indices.append(req_index)
for i, block_ids in enumerate(new_req_data.block_ids):
x = cu_num_new_blocks[i][-1]
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
req_index = self.req_states.req_id_to_index[req_id]
req_new_block_ids = cached_reqs.new_block_ids[i]
if req_new_block_ids is not None:
req_indices.append(req_index)
for group_id, block_ids in enumerate(req_new_block_ids):
x = cu_num_new_blocks[group_id][-1]
cu_num_new_blocks[group_id].append(x + len(block_ids))
new_block_ids[group_id].extend(block_ids)
overwrite.append(False)
if req_indices:
self.block_tables.append_block_ids(
req_indices=req_indices,
cu_num_new_blocks=cu_num_new_blocks,
new_block_ids=new_block_ids,
overwrite=overwrite,
)
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
use_cudagraph: bool,
padded_num_tokens: int | None,
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(
scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get,
)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
)
if use_cudagraph:
assert padded_num_tokens is not None
num_tokens_after_padding = padded_num_tokens
else:
num_tokens_after_padding = num_tokens
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
prepare_inputs(
idx_mapping_np,
self.req_states.prefill_token_ids,
self.req_states.num_computed_tokens,
num_scheduled_tokens,
self.input_buffers.input_ids,
self.input_buffers.positions,
self.input_buffers.query_start_loc,
self.input_buffers.seq_lens,
num_tokens,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_np = query_start_loc.np[: num_reqs + 1]
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
# Some input token ids are directly read from the last sampled tokens.
combine_last_token_ids(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens_gpu,
self.req_states.prefill_len.copy_to_gpu(),
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
)
num_computed_tokens_cpu = torch.from_numpy(
self.req_states.num_computed_tokens[idx_mapping_np]
)
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=num_computed_tokens_cpu,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens_gpu,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
)
def sample(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
sampler_output = self.sampler.sample(logits, sampling_metadata)
return sampler_output
def compute_prompt_logprobs(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> dict[str, LogprobsTensors]:
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# No request asks for prompt logprobs.
return {}
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
includes_prompt = num_computed_tokens < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = (
prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
)
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
if not np.any(needs_prompt_logprobs):
return {}
# Just to be safe, clone the input ids.
n = input_batch.num_tokens
# Shift the input ids by one.
token_ids = torch.empty_like(input_batch.input_ids[:n])
token_ids[: n - 1] = input_batch.input_ids[1:n]
# To avoid out-of-bound access, set the last token id to 0.
token_ids[n - 1] = 0
# Handle chunked prompts.
seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
is_prompt_chunked = seq_lens < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
if not is_prompt_chunked[i]:
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
token_ids[idx] = next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
torch.relu(token_ids),
hidden_states[:n],
self.model.compute_logits,
)
prompt_logprobs[:, 0].masked_fill_(token_ids < 0, 0)
prompt_token_ids = token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc[i]
end_idx = query_start_loc[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
req_extra_data = self.req_states.extra_data[req_id]
prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
prompt_logprobs_list.append(logprobs)
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
),
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
selected_token_ranks=torch.cat(
[x.selected_token_ranks for x in prompt_logprobs_list]
),
)
prompt_logprobs_list.clear()
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
def postprocess(
self,
sampler_output: SamplerOutput,
sampling_metadata: SamplingMetadata,
prompt_logprobs_dict: dict[str, LogprobsTensors],
input_batch: InputBatch,
) -> AsyncOutput | ModelRunnerOutput:
# Store the last sampled token ids.
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
sampler_output.sampled_token_ids
)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
idx_mapping_np = input_batch.idx_mapping_np
is_chunked_prefilling = (
input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
)
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens.
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
# Increment the number of computed tokens.
self.req_states.num_computed_tokens[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None,
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens,
copy_stream=self.output_copy_stream,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: SchedulerOutput,
intermediate_tensors: Any | None = None,
) -> AsyncOutput | ModelRunnerOutput:
assert intermediate_tensors is None
with async_barrier(
self.input_prep_event if self.use_async_scheduling else None
):
self.update_states(scheduler_output)
if scheduler_output.total_num_scheduled_tokens == 0:
return EMPTY_MODEL_RUNNER_OUTPUT
padded_num_tokens = self.cudagraph_manager.get_cudagraph_size(
scheduler_output
)
use_cudagraph = padded_num_tokens is not None
input_batch = self.prepare_inputs(
scheduler_output,
use_cudagraph,
padded_num_tokens,
)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos
)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
# Run model.
if use_cudagraph:
# Run CUDA graph.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
hidden_states = self.cudagraph_manager.run(padded_num_tokens)
else:
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
):
# Run PyTorch model in eager mode.
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sampler_output = self.sample(hidden_states, input_batch, sampling_metadata)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
output = self.postprocess(
sampler_output,
sampling_metadata,
prompt_logprobs_dict,
input_batch,
)
return output

View File

@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import triton
import triton.language as tl
from vllm.config.model import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
class Sampler:
def __init__(
self,
logprobs_mode: LogprobsMode = "raw_logprobs",
):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
def sample_token(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
return_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
is_greedy = sampling_metadata.temperature == 0
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits = logits / temp.view(-1, 1)
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
sampled = gumbel_sample(
probs,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
)
sampled = sampled.to(torch.int64)
return sampled, logits if return_logits else None
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
sampled, logits = self.sample_token(
logits, sampling_metadata, return_logits=True
)
else:
assert self.logprobs_mode == "raw_logprobs"
sampled, _ = self.sample_token(
logits, sampling_metadata, return_logits=False
)
logprobs_tensors = compute_topk_logprobs(
logits,
sampling_metadata.max_num_logprobs,
sampled,
)
else:
sampled, _ = self.sample_token(
logits, sampling_metadata, return_logits=False
)
logprobs_tensors = None
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
@triton.jit
def _gumbel_sample_kernel(
probs_ptr,
probs_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
temp = tl.load(temp_ptr + req_idx)
if temp == 0.0:
# Greedy sampling. Don't apply gumbel noise.
return
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = tl.randint(seed, pos)
block_id = tl.program_id(1)
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
q = tl.rand(gumbel_seed, r_offset)
# NOTE(woosuk): This logic makes sure q is not 0.
RMAX = 0.9999999403953552
RMAX_LOG = -5.960464477539063e-08
q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
q = -1.0 * q
p = tl.load(
probs_ptr + req_idx * probs_stride + r_offset, mask=r_offset < vocab_size
)
p = p / q
tl.store(
probs_ptr + req_idx * probs_stride + r_offset, p, mask=r_offset < vocab_size
)
def gumbel_sample(
probs: torch.Tensor, # [num_reqs, vocab_size]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
) -> torch.Tensor:
num_reqs, vocab_size = probs.shape
_gumbel_sample_kernel[(num_reqs,)](
probs,
probs.stride(0),
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
sampled = probs.argmax(dim=-1)
return sampled
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32)
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations are in float32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
n += tl.sum((logits > x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
if num_logprobs == 0:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
def compute_prompt_logprobs(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks

View File

@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
top_p = None
top_k = None
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
self.prefill_token_ids = np.zeros(
(self.max_num_reqs, self.max_model_len),
dtype=np.int32,
)
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs,
1,
dtype=torch.int64,
device=device,
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
lora_request: LoRARequest | None,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
self.prompt_len[req_idx] = prompt_len
prefill_len = len(prefill_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
self.num_tokens[req_idx] = prefill_len
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_sampling_metadata(
self,
idx_mapping: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping]
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k.np[idx_mapping]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
seeds = self.seeds.np[idx_mapping]
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping]
max_num_logprobs = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
return SamplingMetadata(
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)

View File

@ -42,7 +42,9 @@ from vllm.v1.outputs import (
ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
@ -494,6 +496,8 @@ class Worker(WorkerBase):
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
return self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens