Compare commits
1 Commits
amd_dev
...
bind_kv_ca
| Author | SHA1 | Date | |
|---|---|---|---|
| bfff9bcd1d |
@ -10,10 +10,14 @@ prompts = [
|
|||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
]
|
]
|
||||||
# Create a sampling params object.
|
# Create a sampling params object.
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m")
|
# llm = LLM(model="facebook/opt-125m")
|
||||||
|
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_model_len=128,
|
||||||
|
enforce_eager=True)
|
||||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
@ -21,4 +25,4 @@ outputs = llm.generate(prompts, sampling_params)
|
|||||||
for output in outputs:
|
for output in outputs:
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import torch.distributed as dist
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
@ -33,14 +34,16 @@ class DPMetadata:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardContext:
|
class ForwardContext:
|
||||||
# copy from vllm_config.compilation_config.static_forward_context
|
# Copy from vllm_config.compilation_config.static_forward_context
|
||||||
no_compile_layers: dict[str, Any]
|
no_compile_layers: dict[str, Any]
|
||||||
# TODO: extend to support per-layer dynamic forward context
|
# TODO: Extend to support per-layer dynamic forward context
|
||||||
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
|
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
|
||||||
# TODO: remove after making all virtual_engines share the same kv cache
|
# TODO: Remove after making all virtual_engines share the same kv cache
|
||||||
virtual_engine: int # set dynamically for each forward pass
|
virtual_engine: int # set dynamically for each forward pass
|
||||||
# set dynamically for each forward pass
|
# Set dynamically for each forward pass
|
||||||
dp_metadata: Optional[DPMetadata] = None
|
dp_metadata: Optional[DPMetadata] = None
|
||||||
|
# Whether this is a profile run (before KV cache init)
|
||||||
|
is_profile_run: bool = False,
|
||||||
|
|
||||||
|
|
||||||
_forward_context: Optional[ForwardContext] = None
|
_forward_context: Optional[ForwardContext] = None
|
||||||
@ -58,7 +61,8 @@ def get_forward_context() -> ForwardContext:
|
|||||||
def set_forward_context(attn_metadata: Any,
|
def set_forward_context(attn_metadata: Any,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
num_tokens: int = 0):
|
num_tokens: int = 0,
|
||||||
|
is_profile_run: bool = False):
|
||||||
"""A context manager that stores the current forward context,
|
"""A context manager that stores the current forward context,
|
||||||
can be attention metadata, etc.
|
can be attention metadata, etc.
|
||||||
Here we can inject common logic for every model forward pass.
|
Here we can inject common logic for every model forward pass.
|
||||||
@ -93,12 +97,15 @@ def set_forward_context(attn_metadata: Any,
|
|||||||
|
|
||||||
global _forward_context
|
global _forward_context
|
||||||
prev_context = _forward_context
|
prev_context = _forward_context
|
||||||
|
|
||||||
_forward_context = ForwardContext(
|
_forward_context = ForwardContext(
|
||||||
no_compile_layers=vllm_config.compilation_config.
|
no_compile_layers=vllm_config.compilation_config.
|
||||||
static_forward_context,
|
static_forward_context,
|
||||||
virtual_engine=virtual_engine,
|
virtual_engine=virtual_engine,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
dp_metadata=dp_metadata)
|
dp_metadata=dp_metadata,
|
||||||
|
is_profile_run=is_profile_run)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@ -111,10 +118,17 @@ def set_forward_context(attn_metadata: Any,
|
|||||||
else:
|
else:
|
||||||
# for v1 attention backends
|
# for v1 attention backends
|
||||||
batchsize = attn_metadata.num_input_tokens
|
batchsize = attn_metadata.num_input_tokens
|
||||||
|
|
||||||
# we use synchronous scheduling right now,
|
# we use synchronous scheduling right now,
|
||||||
# adding a sync point here should not affect
|
# adding a sync point here should not affect
|
||||||
# scheduling of the next batch
|
# scheduling of the next batch
|
||||||
torch.cuda.synchronize()
|
if current_platform.is_tpu():
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
xm.mark_step()
|
||||||
|
xm.wait_device_ops()
|
||||||
|
else:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
# time measurement is in milliseconds
|
# time measurement is in milliseconds
|
||||||
batchsize_forward_time[batchsize].append(
|
batchsize_forward_time[batchsize].append(
|
||||||
|
|||||||
@ -30,7 +30,6 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
|||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||||
from vllm.v1.utils import bind_kv_cache
|
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -104,9 +103,6 @@ class TPUModelRunner:
|
|||||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||||
self.encoder_cache_size = encoder_cache_size
|
self.encoder_cache_size = encoder_cache_size
|
||||||
|
|
||||||
# Lazy initialization
|
|
||||||
# self.model: nn.Module # Set after load_model
|
|
||||||
self.kv_caches: list[torch.Tensor] = []
|
|
||||||
# req_id -> (input_id -> encoder_output)
|
# req_id -> (input_id -> encoder_output)
|
||||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
@ -582,7 +578,6 @@ class TPUModelRunner:
|
|||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=self.position_ids,
|
positions=self.position_ids,
|
||||||
kv_caches=self.kv_caches,
|
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
hidden_states = hidden_states[:total_num_scheduled_tokens]
|
hidden_states = hidden_states[:total_num_scheduled_tokens]
|
||||||
@ -680,8 +675,8 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
def _dummy_run(
|
def _dummy_run(
|
||||||
self,
|
self,
|
||||||
kv_caches,
|
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
is_profile_run: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
@ -728,15 +723,28 @@ class TPUModelRunner:
|
|||||||
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||||
|
|
||||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
with set_forward_context(attn_metadata,
|
||||||
|
self.vllm_config,
|
||||||
|
0,
|
||||||
|
is_profile_run=is_profile_run):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=position_ids,
|
positions=position_ids,
|
||||||
kv_caches=kv_caches,
|
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This is used before KV cache init
|
||||||
|
def profile_run(self, num_tokens) -> None:
|
||||||
|
self._dummy_run(num_tokens=num_tokens, is_profile_run=True)
|
||||||
|
|
||||||
|
# This is used after KV cache init
|
||||||
|
def dummy_run(
|
||||||
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
"""Compile the model."""
|
"""Compile the model."""
|
||||||
|
|
||||||
@ -745,7 +753,7 @@ class TPUModelRunner:
|
|||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
num_tokens = 16
|
num_tokens = 16
|
||||||
while True:
|
while True:
|
||||||
self._dummy_run(self.kv_caches, num_tokens)
|
self.dummy_run(num_tokens)
|
||||||
logger.info(" -- num_tokens: %d", num_tokens)
|
logger.info(" -- num_tokens: %d", num_tokens)
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
@ -769,6 +777,7 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
kv_caches: dict[str, torch.Tensor] = {}
|
kv_caches: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
kv_cache_shape_prev = None
|
||||||
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
|
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
|
||||||
tensor_config = kv_cache_config.tensors[layer_name]
|
tensor_config = kv_cache_config.tensors[layer_name]
|
||||||
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
||||||
@ -779,6 +788,12 @@ class TPUModelRunner:
|
|||||||
layer_spec.head_size)
|
layer_spec.head_size)
|
||||||
dtype = layer_spec.dtype
|
dtype = layer_spec.dtype
|
||||||
|
|
||||||
|
# Ensure all "kv_cache_shape" are the same across the model
|
||||||
|
if kv_cache_shape_prev is None:
|
||||||
|
kv_cache_shape_prev = kv_cache_shape
|
||||||
|
else:
|
||||||
|
assert kv_cache_shape == kv_cache_shape_prev
|
||||||
|
|
||||||
tpu_k_cache = torch.zeros(kv_cache_shape,
|
tpu_k_cache = torch.zeros(kv_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
@ -788,10 +803,16 @@ class TPUModelRunner:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
bind_kv_cache(
|
# ModelWrapperV1 needs to know the KV cache shape
|
||||||
kv_caches,
|
self.model.set_kv_cache_shape(kv_cache_shape_prev)
|
||||||
self.vllm_config.compilation_config.static_forward_context,
|
|
||||||
self.kv_caches)
|
# Associates each attention layer in the `forward_context` with the
|
||||||
|
# initialized KV cache.
|
||||||
|
forward_context = self.vllm_config.compilation_config \
|
||||||
|
.static_forward_context
|
||||||
|
for layer_name, kv_cache in kv_caches.items():
|
||||||
|
# NOTE: Use list because of v0 PP virtual engine.
|
||||||
|
forward_context[layer_name].kv_cache = [kv_cache]
|
||||||
|
|
||||||
|
|
||||||
class ModelWrapperV1(nn.Module):
|
class ModelWrapperV1(nn.Module):
|
||||||
@ -799,12 +820,15 @@ class ModelWrapperV1(nn.Module):
|
|||||||
def __init__(self, model: nn.Module):
|
def __init__(self, model: nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.kv_cache_shape = None
|
||||||
|
|
||||||
|
def set_kv_cache_shape(self, kv_cache_shape):
|
||||||
|
self.kv_cache_shape = kv_cache_shape
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Executes the forward pass of the model and samples the next token.
|
"""Executes the forward pass of the model and samples the next token.
|
||||||
@ -817,16 +841,20 @@ class ModelWrapperV1(nn.Module):
|
|||||||
inputs_embeds: The input embeddings of shape [num_tokens,
|
inputs_embeds: The input embeddings of shape [num_tokens,
|
||||||
hidden_size]. It is used for multimodal models.
|
hidden_size]. It is used for multimodal models.
|
||||||
"""
|
"""
|
||||||
# Skip this in memory profiling at initialization.
|
forward_context = get_forward_context()
|
||||||
if kv_caches[0][0].numel() > 0:
|
attn_metadata = forward_context.attn_metadata
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
|
||||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||||
# work, we need to flatten the first three dimensions and modify
|
# work, we need to flatten the first three dimensions and modify
|
||||||
# the slot_mapping accordingly.
|
# the slot_mapping accordingly.
|
||||||
# kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
|
#
|
||||||
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
|
# Note: We skip this step during first profiling run (before KV init)
|
||||||
|
if not forward_context.is_profile_run:
|
||||||
|
assert self.kv_cache_shape # Ensure initialized
|
||||||
|
num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape
|
||||||
|
|
||||||
slot_mapping = attn_metadata.slot_mapping
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
slot_mapping = slot_mapping.flatten()
|
slot_mapping = slot_mapping.flatten()
|
||||||
head_indicies = torch.arange(0,
|
head_indicies = torch.arange(0,
|
||||||
|
|||||||
@ -21,7 +21,6 @@ from vllm.v1.core.scheduler import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.utils import bind_kv_cache
|
|
||||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -128,18 +127,19 @@ class TPUWorker:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
runner_kv_caches: list[torch.Tensor] = []
|
# Associates each attention layer in the `forward_context` with the
|
||||||
bind_kv_cache(
|
# initialized KV cache.
|
||||||
kv_caches,
|
forward_context = self.vllm_config.compilation_config \
|
||||||
self.vllm_config.compilation_config.static_forward_context,
|
.static_forward_context
|
||||||
runner_kv_caches)
|
for layer_name, kv_cache in kv_caches.items():
|
||||||
|
# NOTE: Use list because of v0 PP virtual engine.
|
||||||
|
forward_context[layer_name].kv_cache = [kv_cache]
|
||||||
|
|
||||||
self.model_runner._dummy_run(
|
self.model_runner.profile_run(
|
||||||
runner_kv_caches,
|
num_tokens=self.scheduler_config.max_num_batched_tokens)
|
||||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Synchronize before measuring the memory usage.
|
# Synchronize before measuring the memory usage.
|
||||||
|
xm.mark_step()
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
|
|
||||||
# Get the maximum amount of memory used by the model weights and
|
# Get the maximum amount of memory used by the model weights and
|
||||||
|
|||||||
Reference in New Issue
Block a user