Compare commits
7 Commits
tpu_v1
...
tpu_v1_opt
| Author | SHA1 | Date | |
|---|---|---|---|
| 70b4e46e70 | |||
| 5fb9dbe6f6 | |||
| 996b92ccb4 | |||
| 2b0526fa15 | |||
| 7be649256f | |||
| 627efde813 | |||
| c2867d5bc1 |
@ -5,7 +5,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
|
||||
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
|
||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
||||
@ -34,6 +34,6 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.8.1 # required for compressed-tensors
|
||||
compressed-tensors == 0.9.1 # required for compressed-tensors
|
||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
@ -2,7 +2,7 @@
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
#
|
||||
absl-py==2.1.0
|
||||
# via rouge-score
|
||||
@ -106,9 +106,17 @@ dnspython==2.7.0
|
||||
docutils==0.16
|
||||
# via awscli
|
||||
einops==0.8.0
|
||||
# via -r requirements-test.in
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# encodec
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
einx==0.3.0
|
||||
# via vector-quantize-pytorch
|
||||
email-validator==2.2.0
|
||||
# via pydantic
|
||||
encodec==0.1.1
|
||||
# via vocos
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
fastparquet==2024.11.0
|
||||
@ -125,6 +133,8 @@ filelock==3.16.1
|
||||
# triton
|
||||
fonttools==4.54.1
|
||||
# via matplotlib
|
||||
frozendict==2.4.6
|
||||
# via einx
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
@ -159,6 +169,7 @@ huggingface-hub==0.26.2
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
# vocos
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
@ -261,6 +272,8 @@ numpy==1.26.4
|
||||
# cupy-cuda12x
|
||||
# datasets
|
||||
# decord
|
||||
# einx
|
||||
# encodec
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
@ -283,6 +296,7 @@ numpy==1.26.4
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
# vocos
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@ -455,6 +469,7 @@ pyyaml==6.0.2
|
||||
# responses
|
||||
# timm
|
||||
# transformers
|
||||
# vocos
|
||||
ray[adag]==2.40.0
|
||||
# via -r requirements-test.in
|
||||
redis==5.2.0
|
||||
@ -517,6 +532,7 @@ scipy==1.13.1
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
# statsmodels
|
||||
# vocos
|
||||
sentence-transformers==3.2.1
|
||||
# via -r requirements-test.in
|
||||
sentencepiece==0.2.0
|
||||
@ -540,7 +556,9 @@ sqlitedict==2.1.0
|
||||
statsmodels==0.14.4
|
||||
# via genai-perf
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
# via
|
||||
# einx
|
||||
# torch
|
||||
tabledata==1.3.3
|
||||
# via pytablewriter
|
||||
tabulate==0.9.0
|
||||
@ -568,12 +586,21 @@ torch==2.5.1
|
||||
# -r requirements-test.in
|
||||
# accelerate
|
||||
# bitsandbytes
|
||||
# encodec
|
||||
# lm-eval
|
||||
# peft
|
||||
# sentence-transformers
|
||||
# tensorizer
|
||||
# timm
|
||||
# torchaudio
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
torchaudio==2.5.1
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# encodec
|
||||
# vocos
|
||||
torchvision==0.20.1
|
||||
# via timm
|
||||
tqdm==4.66.6
|
||||
@ -584,13 +611,15 @@ tqdm==4.66.6
|
||||
# lm-eval
|
||||
# nltk
|
||||
# peft
|
||||
# pqdm
|
||||
# sentence-transformers
|
||||
# tqdm-multiprocess
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.47.0
|
||||
transformers==4.48.2
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# genai-perf
|
||||
# lm-eval
|
||||
# peft
|
||||
@ -615,6 +644,7 @@ typing-extensions==4.12.2
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# mistral-common
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# torch
|
||||
@ -626,6 +656,10 @@ urllib3==2.2.3
|
||||
# requests
|
||||
# responses
|
||||
# tritonclient
|
||||
vector-quantize-pytorch==1.21.2
|
||||
# via -r requirements-test.in
|
||||
vocos==0.1.0
|
||||
# via -r requirements-test.in
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
xxhash==3.5.0
|
||||
@ -638,4 +672,4 @@ zstandard==0.23.0
|
||||
# via lm-eval
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
# setuptools
|
||||
@ -13,13 +13,11 @@ ray[default]
|
||||
# Install torch_xla
|
||||
--pre
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.6.0.dev20241126+cpu
|
||||
torchvision==0.20.0.dev20241126+cpu
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
jaxlib==0.4.36.dev20241122
|
||||
jax==0.4.36.dev20241122
|
||||
torch==2.6.0.dev20241216+cpu
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
@ -57,6 +57,14 @@ class BlockTable:
|
||||
src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks_src = self.num_blocks_per_row[src]
|
||||
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
||||
self.num_blocks_per_row[src] = num_blocks_tgt
|
||||
self.num_blocks_per_row[tgt] = num_blocks_src
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
@ -436,3 +436,77 @@ class InputBatch:
|
||||
@property
|
||||
def no_prompt_logprob(self) -> bool:
|
||||
return len(self.prompt_logprob_reqs) == 0
|
||||
|
||||
|
||||
def swap_positions(b: InputBatch, id_1, id_2):
|
||||
assert id_1 != id_2
|
||||
req_id_1 = b.req_ids[id_1]
|
||||
req_id_2 = b.req_ids[id_2]
|
||||
assert req_id_1 is not None
|
||||
assert req_id_2 is not None
|
||||
assert id_1 == b.req_id_to_index[req_id_1]
|
||||
assert id_2 == b.req_id_to_index[req_id_2]
|
||||
|
||||
b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
|
||||
b.req_id_to_index[req_id_1], b.req_id_to_index[
|
||||
req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
|
||||
|
||||
ids = [id_1, id_2]
|
||||
rev_ids = [id_2, id_1]
|
||||
b.num_tokens[ids] = b.num_tokens[rev_ids]
|
||||
b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
|
||||
b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
|
||||
b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]
|
||||
|
||||
b.block_table.swap_row(id_1, id_2)
|
||||
|
||||
b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
|
||||
b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
|
||||
b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
|
||||
b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
|
||||
b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
|
||||
b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]
|
||||
|
||||
b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
|
||||
id_1]
|
||||
b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
|
||||
id_2], b.stop_token_ids[id_1]
|
||||
|
||||
gen_1 = b.generators.pop(id_1, None)
|
||||
gen_2 = b.generators.pop(id_2, None)
|
||||
if gen_1 is not None:
|
||||
b.generators[id_2] = gen_1
|
||||
if gen_2 is not None:
|
||||
b.generators[id_1] = gen_2
|
||||
|
||||
|
||||
def ensure_decodes_first(b: InputBatch):
|
||||
num_reqs = b.num_reqs
|
||||
while True:
|
||||
# Find the first prompt index
|
||||
first_prompt_index = None
|
||||
for i in range(num_reqs):
|
||||
if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
|
||||
first_prompt_index = i
|
||||
break
|
||||
if first_prompt_index is None:
|
||||
break
|
||||
|
||||
# Find the last decode index
|
||||
last_decode_index = None
|
||||
for i in reversed(range(num_reqs)):
|
||||
if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
|
||||
last_decode_index = i
|
||||
break
|
||||
if last_decode_index is None:
|
||||
break
|
||||
|
||||
# Sanity
|
||||
assert first_prompt_index != last_decode_index
|
||||
|
||||
# Check if done
|
||||
if first_prompt_index > last_decode_index:
|
||||
break
|
||||
|
||||
# Swap
|
||||
swap_positions(b, first_prompt_index, last_decode_index)
|
||||
|
||||
@ -3,6 +3,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
@ -20,8 +21,10 @@ from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
|
||||
ensure_decodes_first)
|
||||
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
@ -31,30 +34,24 @@ logger = init_logger(__name__)
|
||||
# Here we utilize the behavior that out-of-bound index is ignored.
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
||||
_ENABLE_TOP_P = False
|
||||
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
||||
# This can significantly affect the performance if too large.
|
||||
_MAX_NUM_SAMPLES = 128
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptInputData:
|
||||
|
||||
req_ids: List
|
||||
prompt_lens: List
|
||||
input_tokens: List
|
||||
input_positions: List
|
||||
attn_metadata: List
|
||||
|
||||
def zipped(self):
|
||||
return zip(self.req_ids, self.prompt_lens, self.input_tokens,
|
||||
self.input_positions, self.attn_metadata)
|
||||
class PromptDecodeInfo:
|
||||
prompt_req_ids: List[str]
|
||||
decode_req_ids: List[str]
|
||||
prompt_scheduled_tokens: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeInputData:
|
||||
req_ids: List
|
||||
class PromptData:
|
||||
input_tokens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
attn_metadata: PallasMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeData:
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
attn_metadata: Optional[PallasMetadata] = None
|
||||
@ -69,266 +66,371 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
):
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
# Persistent batch.
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
)
|
||||
|
||||
# Request states.
|
||||
self.requests: Dict[str, CachedRequestState] = {}
|
||||
|
||||
# KV caches for forward pass
|
||||
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
|
||||
# Used to initialize positions for the individual prefills
|
||||
self.prefill_input_positions = torch.tensor(range(self.max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32).reshape(
|
||||
1, -1)
|
||||
# Cached torch/numpy tensors
|
||||
self.num_swaps = 2
|
||||
self.cur_swap_id = 0
|
||||
self.input_ids_cpu = []
|
||||
self.input_ids_np = []
|
||||
self.input_positions_cpu = []
|
||||
self.input_positions_np = []
|
||||
self.slot_mapping_cpu = []
|
||||
self.slot_mapping_np = []
|
||||
self.prompt_context_lens_cpu = []
|
||||
self.prompt_effective_query_lens_cpu = []
|
||||
self.decode_context_lens_cpu = []
|
||||
self.decode_context_lens_np = []
|
||||
for _ in range(self.num_swaps):
|
||||
self.input_ids_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
|
||||
|
||||
def _prepare_prompt_inputs(
|
||||
self.input_positions_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.input_positions_np.append(
|
||||
self.input_positions_cpu[-1].numpy())
|
||||
|
||||
self.slot_mapping_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu"))
|
||||
self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
|
||||
|
||||
self.prompt_context_lens_cpu.append(
|
||||
torch.empty((1), dtype=torch.int32, device="cpu"))
|
||||
self.prompt_effective_query_lens_cpu.append(
|
||||
torch.empty((1), dtype=torch.int32, device="cpu"))
|
||||
|
||||
self.decode_context_lens_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.decode_context_lens_np.append(
|
||||
self.decode_context_lens_cpu[-1].numpy())
|
||||
|
||||
# Range tensor with values [0 .. self.max_num_tokens - 1].
|
||||
# Used to initialize positions / context_lens / seq_lens
|
||||
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
|
||||
|
||||
def swap_step(self):
|
||||
self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
|
||||
|
||||
def _get_prompts_and_decodes(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> PromptInputData:
|
||||
) -> PromptDecodeInfo:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
req_ids = []
|
||||
prompt_lens = []
|
||||
input_tokens_list = []
|
||||
input_positions_list = []
|
||||
attn_metadata_list = []
|
||||
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||
assert req_id is not None
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
req_state = self.requests[req_id]
|
||||
# Traverse decodes first
|
||||
decode_req_ids = []
|
||||
for i in range(num_reqs):
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
|
||||
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
|
||||
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
num_prompt_tokens = len(req_state.prompt_token_ids)
|
||||
|
||||
# Detect whether this is a prompt (can be full or chunked)
|
||||
if num_computed_tokens >= num_prompt_tokens:
|
||||
# This is a decode => Skip
|
||||
continue
|
||||
if num_computed_tokens < num_prompt_tokens:
|
||||
# This is prompt
|
||||
break
|
||||
|
||||
# This is a prompt
|
||||
req_ids.append(req_id)
|
||||
# This is decode
|
||||
assert num_scheduled_tokens == 1
|
||||
decode_req_ids.append(req_id)
|
||||
|
||||
# Prompt len
|
||||
prompt_len = num_scheduled_tokens
|
||||
prompt_lens.append(prompt_len)
|
||||
padded_prompt_len = _get_padded_prefill_len(prompt_len)
|
||||
assert padded_prompt_len <= self.max_model_len
|
||||
# Traverse prompts
|
||||
prompt_req_ids = []
|
||||
prompt_scheduled_tokens = []
|
||||
for i in range(len(decode_req_ids), num_reqs):
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
|
||||
# Seq len
|
||||
seq_len = num_computed_tokens + prompt_len
|
||||
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
|
||||
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
|
||||
# Input tokens
|
||||
input_tokens = torch.zeros((1, padded_prompt_len),
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_tokens[:, :prompt_len] = torch.from_numpy(
|
||||
self.input_batch.token_ids_cpu[req_index,
|
||||
num_computed_tokens:seq_len])
|
||||
# input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
|
||||
# req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
|
||||
# input_tokens[:, prompt_len:] = 0
|
||||
input_tokens_list.append(input_tokens.to(self.device))
|
||||
# Must be prompt
|
||||
assert num_computed_tokens < num_prompt_tokens
|
||||
|
||||
# Input positions
|
||||
input_positions = torch.zeros((1, padded_prompt_len),
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_positions[:, :
|
||||
prompt_len] = self.prefill_input_positions[:,
|
||||
num_computed_tokens:
|
||||
seq_len]
|
||||
# input_positions[:, prompt_len:] = 0
|
||||
input_positions_list.append(input_positions.to(self.device))
|
||||
prompt_req_ids.append(req_id)
|
||||
prompt_scheduled_tokens.append(num_scheduled_tokens)
|
||||
|
||||
# Slot mapping
|
||||
block_table_cpu_tensor = \
|
||||
self.input_batch.block_table.get_cpu_tensor()
|
||||
block_numbers = block_table_cpu_tensor[req_index,
|
||||
input_positions //
|
||||
self.block_size].reshape(
|
||||
1, -1)
|
||||
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
|
||||
prompt_scheduled_tokens)
|
||||
|
||||
block_offsets = input_positions % self.block_size
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
|
||||
slot_mapping = slot_mapping.long()
|
||||
def _prepare_prompt(self, req_index: int,
|
||||
num_scheduled_tokens: int) -> PromptData:
|
||||
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
|
||||
req_index]
|
||||
num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
|
||||
|
||||
# Block table
|
||||
block_table = None
|
||||
if num_computed_tokens > 0:
|
||||
block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
|
||||
block_table = block_table.to(self.device)
|
||||
# Must be prompt
|
||||
assert num_computed_tokens < num_prompt_tokens
|
||||
|
||||
# Context len
|
||||
context_len = 0
|
||||
if num_computed_tokens > 0:
|
||||
context_len = seq_len
|
||||
context_lens = torch.tensor([context_len],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
# Prompt len
|
||||
prompt_len = num_scheduled_tokens
|
||||
padded_prompt_len = _get_padded_prompt_len(prompt_len)
|
||||
assert padded_prompt_len <= self.max_model_len
|
||||
|
||||
# Effective query len
|
||||
effective_query_lens = torch.tensor([prompt_len],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
# Seq len
|
||||
seq_len = num_computed_tokens + prompt_len
|
||||
padded_seq_len = num_computed_tokens + padded_prompt_len
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata_list.append(
|
||||
PallasMetadata(
|
||||
num_prefills=1,
|
||||
num_prefill_tokens=0, # NOTE: This is not used.
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping.to(self.device),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_table,
|
||||
context_lens=context_lens.to(self.device),
|
||||
effective_query_lens=effective_query_lens.to(self.device),
|
||||
))
|
||||
# DEBUG
|
||||
# print("_prepare_prompt:")
|
||||
# print(" prompt_len = {}".format(prompt_len))
|
||||
# print(" padded_prompt_len = {}".format(padded_prompt_len))
|
||||
# print(" num_computed_tokens = {}".format(num_computed_tokens))
|
||||
# print(" num_prompt_tokens = {}".format(num_prompt_tokens))
|
||||
# print(" seq_len = {}".format(seq_len))
|
||||
# print(" padded_seq_len = {}".format(padded_seq_len))
|
||||
|
||||
# TODO: Remove this
|
||||
# if num_computed_tokens > 0:
|
||||
# print("-------------------")
|
||||
# print("input_tokens.shape = {}".format(input_tokens.shape))
|
||||
# print("input_positions.shape = {}".format(
|
||||
# input_positions.shape))
|
||||
# print("slot_mapping.shape = {}".format(slot_mapping.shape))
|
||||
# print("block_table.shape = {}".format(block_table.shape))
|
||||
# print("context_lens.shape = {} data = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
# print("effective_query_lens.shape = {} data = {}".format(
|
||||
# effective_query_lens.shape, effective_query_lens))
|
||||
# Input tokens
|
||||
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
|
||||
req_index, num_computed_tokens:padded_seq_len]
|
||||
input_tokens_cpu[prompt_len:] = 0
|
||||
|
||||
return PromptInputData(
|
||||
req_ids=req_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
input_tokens=input_tokens_list,
|
||||
input_positions=input_positions_list,
|
||||
attn_metadata=attn_metadata_list,
|
||||
# DEBUG
|
||||
# print(" input_tokens_cpu.shape = {} val = {}".format(
|
||||
# input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[
|
||||
self.cur_swap_id][:padded_prompt_len]
|
||||
np.add(num_computed_tokens,
|
||||
self.arange_np[:padded_prompt_len],
|
||||
out=input_positions_np)
|
||||
input_positions_np[prompt_len:] = 0
|
||||
|
||||
# DEBUG
|
||||
# print(" input_positions_np.shape = {} val = {}".format(
|
||||
# input_positions_np.shape, input_positions_np))
|
||||
|
||||
# Slot mapping
|
||||
block_table_np = \
|
||||
self.input_batch.block_table.get_numpy_array()
|
||||
block_numbers_np = block_table_np[req_index, input_positions_np //
|
||||
self.block_size]
|
||||
block_offsets_np = input_positions_np % self.block_size
|
||||
|
||||
slot_mapping_np = self.slot_mapping_np[
|
||||
self.cur_swap_id][:padded_prompt_len]
|
||||
np.add(block_numbers_np * self.block_size,
|
||||
block_offsets_np,
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
|
||||
|
||||
# DEBUG
|
||||
# print(" slot_mapping_np.shape = {} val = {}".format(
|
||||
# slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
# Block table
|
||||
block_table_cpu = None
|
||||
if num_computed_tokens > 0:
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
block_table_cpu = block_table_cpu[req_index]
|
||||
|
||||
# DEBUG
|
||||
# print(" block_table_cpu = {}".format(block_table_cpu))
|
||||
|
||||
# Context len
|
||||
self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
|
||||
if num_computed_tokens > 0:
|
||||
self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
|
||||
|
||||
# Effective query len
|
||||
self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
|
||||
|
||||
# Get final tensors
|
||||
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
|
||||
input_positions = self.input_positions_cpu[
|
||||
self.cur_swap_id][:padded_prompt_len].reshape(1,
|
||||
-1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[
|
||||
self.cur_swap_id][:padded_prompt_len].reshape(1,
|
||||
-1).to(self.device)
|
||||
block_table = block_table_cpu.reshape(1, -1).to(
|
||||
self.device) if block_table_cpu is not None else None
|
||||
|
||||
context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
|
||||
self.device)
|
||||
effective_query_lens = self.prompt_effective_query_lens_cpu[
|
||||
self.cur_swap_id].to(self.device)
|
||||
|
||||
self.swap_step()
|
||||
|
||||
# DEBUG
|
||||
# print(" input_tokens.shape = {} val = {}".format(
|
||||
# input_tokens.shape, input_tokens))
|
||||
# print(" input_positions.shape = {} val = {}".format(
|
||||
# input_positions.shape, input_positions))
|
||||
# print(" slot_mapping.shape = {} val = {}".format(
|
||||
# slot_mapping.shape, slot_mapping))
|
||||
# print(" block_table = {}".format(block_table))
|
||||
# print(" context_lens.shape = {} val = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
# print(" effective_query_lens.shape = {} val = {}".format(
|
||||
# effective_query_lens.shape, effective_query_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=1,
|
||||
num_prefill_tokens=0, # NOTE: This is not used.
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_table,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=effective_query_lens,
|
||||
)
|
||||
|
||||
def _prepare_decode_inputs(
|
||||
return PromptData(input_tokens, input_positions, attn_metadata)
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> DecodeInputData:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
|
||||
|
||||
req_ids = []
|
||||
req_indices = []
|
||||
input_tokens = []
|
||||
input_positions = []
|
||||
slot_mapping = []
|
||||
context_lens = []
|
||||
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||
assert req_id is not None
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
req_state = self.requests[req_id]
|
||||
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
num_prompt_tokens = len(req_state.prompt_token_ids)
|
||||
|
||||
# Detect whether this is a decode
|
||||
if num_computed_tokens < num_prompt_tokens:
|
||||
# This is a prompt => Skip
|
||||
continue
|
||||
|
||||
# This is a decode
|
||||
req_ids.append(req_id)
|
||||
req_indices.append(req_index)
|
||||
|
||||
# Seq len
|
||||
seq_len = num_computed_tokens + num_scheduled_tokens
|
||||
|
||||
# Sanity check decode
|
||||
assert num_scheduled_tokens == 1
|
||||
assert seq_len == req_state.num_tokens
|
||||
|
||||
# Input token
|
||||
input_tokens.append([
|
||||
self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
|
||||
])
|
||||
|
||||
# Position
|
||||
input_positions.append([num_computed_tokens])
|
||||
|
||||
# Slot mapping
|
||||
block_number = block_table_cpu_tensor[req_index,
|
||||
num_computed_tokens //
|
||||
self.block_size]
|
||||
block_offset = num_computed_tokens % self.block_size
|
||||
slot_id = block_number * self.block_size + block_offset
|
||||
slot_mapping.append([slot_id])
|
||||
|
||||
# Context len
|
||||
context_lens.append(seq_len)
|
||||
|
||||
# Compute padding
|
||||
batch_size = len(input_tokens)
|
||||
decode_req_ids: List[str],
|
||||
) -> DecodeData:
|
||||
# Batch size
|
||||
batch_size = len(decode_req_ids)
|
||||
padded_batch_size = _get_padded_batch_size(batch_size)
|
||||
num_padding = padded_batch_size - batch_size
|
||||
assert padded_batch_size <= self.max_model_len
|
||||
|
||||
# Add padding
|
||||
input_tokens.extend([[0]] * num_padding)
|
||||
input_positions.extend([[0]] * num_padding)
|
||||
slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
|
||||
context_lens.extend([0] * num_padding)
|
||||
req_indices.extend([0] * num_padding)
|
||||
# Init [0 .. batch_size - 1]
|
||||
req_indices_np = self.arange_np[:padded_batch_size]
|
||||
|
||||
# Create tensors
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_positions_tensor = torch.tensor(input_positions,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
context_lens_tensor = torch.tensor(context_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
block_tables_tensor = block_table_cpu_tensor[req_indices]
|
||||
# DEBUG
|
||||
# print("_prepare_decode:")
|
||||
# print(" batch_size = {}".format(batch_size))
|
||||
# print(" padded_batch_size = {}".format(padded_batch_size))
|
||||
# print(" req_indices_np.shape = {} val = {}".format(
|
||||
# req_indices_np.shape, req_indices_np))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
|
||||
0,
|
||||
out=input_positions_np)
|
||||
input_positions_np[batch_size:] = 0
|
||||
input_positions_cpu = self.input_positions_cpu[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
|
||||
# DEBUG
|
||||
# print(" input_positions_cpu.shape = {} data = {}".format(
|
||||
# input_positions_cpu.shape, input_positions_cpu))
|
||||
|
||||
# Input tokens
|
||||
token_indices_np = (
|
||||
input_positions_np +
|
||||
req_indices_np * self.input_batch.token_ids_cpu.shape[1])
|
||||
input_tokens_cpu = self.input_ids_cpu[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices_np),
|
||||
out=input_tokens_cpu)
|
||||
input_tokens_cpu[batch_size:] = 0
|
||||
|
||||
# DEBUG
|
||||
# print(" token_indices_np.shape = {} val = {}".format(
|
||||
# token_indices_np.shape, token_indices_np))
|
||||
# print(" input_tokens_cpu.shape = {} data = {}".format(
|
||||
# input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Slot mapping
|
||||
block_table_indices_np = (
|
||||
req_indices_np * self.max_num_blocks_per_req +
|
||||
input_positions_np // self.block_size)
|
||||
|
||||
# DEBUG
|
||||
# print(
|
||||
# " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
|
||||
# .format(block_table_indices_np.shape, block_table_indices_np,
|
||||
# self.max_num_blocks_per_req))
|
||||
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
|
||||
# DEBUG
|
||||
# print(" block_table_cpu.shape = {} data = {}".format(
|
||||
# block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
|
||||
|
||||
block_numbers_np = block_table_cpu.flatten(
|
||||
)[block_table_indices_np].numpy()
|
||||
|
||||
# DEBUG
|
||||
# print(" block_numbers_np.shape = {} data = {}".format(
|
||||
# block_numbers_np.shape, block_numbers_np))
|
||||
|
||||
block_offsets_np = input_positions_np % self.block_size
|
||||
|
||||
# DEBUG
|
||||
# print(" block_offsets_np.shape = {} data = {}".format(
|
||||
# block_offsets_np.shape, block_offsets_np))
|
||||
|
||||
slot_mapping_np = self.slot_mapping_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(block_numbers_np * self.block_size,
|
||||
block_offsets_np,
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[batch_size:] = _PAD_SLOT_ID
|
||||
|
||||
# DEBUG
|
||||
# print(" slot_mapping_np.shape = {} data = {}".format(
|
||||
# slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
block_table_cpu = block_table_cpu[:padded_batch_size]
|
||||
|
||||
# Context lens
|
||||
context_lens_np = self.decode_context_lens_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
|
||||
1,
|
||||
out=context_lens_np)
|
||||
context_lens_np[batch_size:] = 0
|
||||
|
||||
# Get final tensors
|
||||
input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
|
||||
input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[
|
||||
self.cur_swap_id][:padded_batch_size].reshape(-1,
|
||||
1).to(self.device)
|
||||
block_table = block_table_cpu.to(self.device)
|
||||
context_lens = self.decode_context_lens_cpu[
|
||||
self.cur_swap_id][:padded_batch_size].to(self.device)
|
||||
|
||||
self.swap_step()
|
||||
|
||||
# DEBUG
|
||||
# print(" context_lens.shape = {} val = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=padded_batch_size,
|
||||
slot_mapping=slot_mapping_tensor.to(self.device),
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_tables_tensor.to(self.device),
|
||||
context_lens=context_lens_tensor.to(self.device),
|
||||
block_tables=block_table,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=None,
|
||||
)
|
||||
|
||||
return DecodeInputData(
|
||||
req_ids=req_ids,
|
||||
input_tokens=input_tokens_tensor.to(self.device),
|
||||
input_positions=input_positions_tensor.to(self.device),
|
||||
attn_metadata=attn_metadata)
|
||||
return DecodeData(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
@ -338,18 +440,82 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
# Update cached state
|
||||
self.update_states(scheduler_output)
|
||||
|
||||
# Prepare inputs
|
||||
prompt_data = self._prepare_prompt_inputs(scheduler_output)
|
||||
decode_data = self._prepare_decode_inputs(scheduler_output)
|
||||
# If necessary, swap decodes/prompts to have all decodes on the start
|
||||
ensure_decodes_first(self.input_batch)
|
||||
|
||||
# Prepare prompts/decodes info
|
||||
pd_info = self._get_prompts_and_decodes(scheduler_output)
|
||||
|
||||
# Init
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
sampled_token_ids_list = [0] * num_reqs
|
||||
num_prompts = len(pd_info.prompt_req_ids)
|
||||
num_decodes = len(pd_info.decode_req_ids)
|
||||
decode_data = None
|
||||
sampled_token_ids = [0] * self.input_batch.num_reqs
|
||||
|
||||
# Run each prompt individually
|
||||
is_first = True
|
||||
for i in range(num_prompts):
|
||||
req_id = pd_info.prompt_req_ids[i]
|
||||
req_index = num_decodes + i
|
||||
assert req_index == self.input_batch.req_id_to_index[
|
||||
req_id] # TODO: Remove
|
||||
req_state = self.requests[req_id]
|
||||
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
|
||||
prompt_len = num_scheduled_tokens
|
||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
|
||||
|
||||
# Prepare first prompt
|
||||
if is_first:
|
||||
prompt_data = self._prepare_prompt(req_index,
|
||||
num_scheduled_tokens)
|
||||
is_first = False
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(prompt_data.attn_metadata,
|
||||
self.vllm_config):
|
||||
assert self.model is not None
|
||||
selected_token_ids = self.model(prompt_data.input_tokens,
|
||||
prompt_data.input_positions,
|
||||
prompt_data.attn_metadata,
|
||||
self.kv_caches)
|
||||
|
||||
# In parallel to TPU execution, prepare the next iteration
|
||||
if i < num_prompts - 1:
|
||||
# There is next prompt => prepare it
|
||||
prompt_data = self._prepare_prompt(
|
||||
req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
|
||||
elif i == num_prompts - 1 and num_decodes > 0:
|
||||
# There is next decode => prepare it
|
||||
decode_data = self._prepare_decode(pd_info.decode_req_ids)
|
||||
|
||||
# Update cached state (if prompt is fully done)
|
||||
if seq_len >= len(req_state.prompt_token_ids):
|
||||
# Transfer sampled tokens from TPU to CPU
|
||||
selected_token_ids_cpu = selected_token_ids.cpu()
|
||||
|
||||
# Get output token
|
||||
token_id = selected_token_ids_cpu[prompt_len - 1].item()
|
||||
sampled_token_ids[req_index] = token_id
|
||||
|
||||
# DEBUG
|
||||
# print(
|
||||
# " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
|
||||
# .format(token_id, prompt_len, req_id, req_index,
|
||||
# selected_token_ids_cpu))
|
||||
|
||||
# Add output token to the request
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
# Run decodes (a single batch)
|
||||
if len(decode_data.req_ids) > 0:
|
||||
# Forward
|
||||
if num_decodes > 0:
|
||||
|
||||
# Prepare decode (if was not yet prepared)
|
||||
if decode_data is None:
|
||||
decode_data = self._prepare_decode(pd_info.decode_req_ids)
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(decode_data.attn_metadata,
|
||||
self.vllm_config):
|
||||
assert self.model is not None
|
||||
@ -359,59 +525,31 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
self.kv_caches)
|
||||
|
||||
# Transfer sampled tokens from TPU to CPU
|
||||
selected_token_ids_list = selected_token_ids.cpu().tolist()
|
||||
decode_token_ids_cpu = selected_token_ids.cpu()
|
||||
# Convert to list
|
||||
decode_token_ids_list = decode_token_ids_cpu.tolist()
|
||||
|
||||
# Update cached state
|
||||
for i, req_id in enumerate(decode_data.req_ids):
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
# Update cached state for each decode request
|
||||
for i in range(num_decodes):
|
||||
req_id = pd_info.decode_req_ids[i]
|
||||
req_index = i
|
||||
assert req_index == self.input_batch.req_id_to_index[
|
||||
req_id] # TODO: Remove
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = req_state.num_computed_tokens + 1
|
||||
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
token_id = selected_token_ids_list[i]
|
||||
token_id = decode_token_ids_list[i]
|
||||
sampled_token_ids[req_index] = token_id
|
||||
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
sampled_token_ids_list[req_index] = token_id
|
||||
|
||||
# Run each prompt
|
||||
for (req_id, prompt_len, input_tokens, input_positions,
|
||||
attn_metadata) in prompt_data.zipped():
|
||||
assert req_id is not None
|
||||
req_state = self.requests[req_id]
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
# Forward
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
assert self.model is not None
|
||||
selected_token_ids = self.model(input_tokens, input_positions,
|
||||
attn_metadata, self.kv_caches)
|
||||
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
if seq_len >= len(req_state.prompt_token_ids):
|
||||
# Transfer sampled tokens from TPU to CPU
|
||||
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
|
||||
sampled_token_ids_list[req_index] = token_id
|
||||
|
||||
# Update cached state
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
# Get req_ids
|
||||
assert all(
|
||||
req_id is not None for req_id in
|
||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
|
||||
|
||||
# Create output
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids_list,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprob_token_ids_cpu=None,
|
||||
logprobs_cpu=None,
|
||||
)
|
||||
@ -551,60 +689,81 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||
|
||||
# TODO: Remove the attn_metadata above
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
assert self.model is not None
|
||||
self.model(token_ids, position_ids, None, kv_caches)
|
||||
self.model(token_ids, position_ids, attn_metadata, kv_caches)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
|
||||
logger.info("Compiling the model with different input shapes.")
|
||||
|
||||
# Capture prefill shapes
|
||||
start = time.perf_counter()
|
||||
# Prefill
|
||||
logger.info(
|
||||
"Compiling the model with different input shapes for prefill:")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while True:
|
||||
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
||||
ExecutionMode.PREFILL)
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" -- batch_size: %d, seq_len: %d", batch_size,
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
|
||||
if seq_len >= self.model_config.max_model_len:
|
||||
break
|
||||
|
||||
num_tokens = batch_size * seq_len
|
||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||
break
|
||||
|
||||
# Move to next seq_len
|
||||
seq_len = seq_len * 2
|
||||
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation for prefill shapes is done in %.2f [secs].",
|
||||
end = time.time()
|
||||
logger.info(" -- Compilation for prefill done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
# Capture decode shapes.
|
||||
# Prefix prefill
|
||||
if self.scheduler_config.enable_chunked_prefill:
|
||||
logger.info("Compiling the model with different input shapes for "
|
||||
"prefix prefill:")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.PREFIX_PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if (num_tokens
|
||||
>= self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
end = time.time()
|
||||
logger.info(
|
||||
" -- Compilation for prefix prefill done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
# Decode
|
||||
logger.info(
|
||||
"Compiling the model with different input shapes for decode:")
|
||||
start = time.time()
|
||||
seq_len = 1
|
||||
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||
while True:
|
||||
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
||||
ExecutionMode.DECODE)
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.DECODE)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d",
|
||||
batch_size, seq_len,
|
||||
self.scheduler_config.max_num_seqs)
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
|
||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for decode shapes is done in %.2f [secs].",
|
||||
logger.info(" -- Compilation for decode done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
@ -673,7 +832,7 @@ class ModelWrapperV1(nn.Module):
|
||||
memory profiling at initialization.
|
||||
"""
|
||||
# Skip this in memory profiling at initialization.
|
||||
if attn_metadata is not None:
|
||||
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
@ -710,7 +869,7 @@ class ModelWrapperV1(nn.Module):
|
||||
return argmax_token_ids
|
||||
|
||||
|
||||
def _get_padded_prefill_len(x: int) -> int:
|
||||
def _get_padded_prompt_len(x: int) -> int:
|
||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||
# length to be a multiple of 16. We pad the prompt length to the nearest
|
||||
# multiple of 16. This is also good for performance.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""A TPU worker class."""
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -13,10 +13,13 @@ from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -74,20 +77,29 @@ class TPUWorker(WorkerBase):
|
||||
def determine_available_memory(self) -> int:
|
||||
assert self.model_runner is not None
|
||||
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
||||
device=self.device),
|
||||
torch.tensor([], dtype=torch.float32,
|
||||
device=self.device))
|
||||
for _ in range(num_layers)]
|
||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
runner_kv_caches = []
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
runner_kv_caches)
|
||||
|
||||
self.model_runner.dummy_run(
|
||||
kv_caches,
|
||||
runner_kv_caches,
|
||||
num_tokens=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
exec_mode=ExecutionMode.PREFILL,
|
||||
|
||||
Reference in New Issue
Block a user