Compare commits

2 commits: `main` ... `lwilkinson`

| Author | SHA1 | Date |
|---|---|---|
| | a1e3c09cba | |
| | 90d24dee04 | |
```diff
@@ -321,6 +321,13 @@ def set_forward_context(
                                           attn_metadata, num_tokens or 0,
                                           num_tokens_across_dp)
 
+    # Convenience: if cudagraph is used, and num_tokens is given, we can just
+    # create a batch descriptor here if not given (there's no harm since if it
+    # doesn't match in the wrapper it'll fall through).
+    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+        batch_descriptor = batch_descriptor or BatchDescriptor(
+            num_tokens=num_tokens)
+
     forward_context = create_forward_context(attn_metadata, vllm_config,
                                              virtual_engine, dp_metadata,
                                              cudagraph_runtime_mode,
```
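For illustration (not part of the diff): the fallback added above can be read as a small pure function. The sketch below uses stand-in `CUDAGraphMode` and `BatchDescriptor` classes and a hypothetical `resolve_batch_descriptor` helper rather than vLLM's own types.

```python
# Illustrative sketch only: mirrors the fallback logic added to
# set_forward_context, using stand-in types rather than vLLM's classes.
from dataclasses import dataclass
from enum import Enum, auto


class CUDAGraphMode(Enum):      # stand-in for vllm.config.CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()


@dataclass(frozen=True)
class BatchDescriptor:          # stand-in for vLLM's BatchDescriptor
    num_tokens: int


def resolve_batch_descriptor(batch_descriptor, cudagraph_runtime_mode, num_tokens):
    """If cudagraphs are in play and num_tokens is known, default the descriptor."""
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        # Same "or" fallback as the diff: an explicitly passed descriptor wins.
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
    return batch_descriptor


print(resolve_batch_descriptor(None, CUDAGraphMode.PIECEWISE, 128))
# -> BatchDescriptor(num_tokens=128)
print(resolve_batch_descriptor(None, CUDAGraphMode.NONE, 128))
# -> None (no cudagraphs, so nothing is filled in)
```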
```diff
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 
 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
```
```diff
@@ -78,6 +78,10 @@ class EagleProposer:
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
+        self.cudagraph_runtime_mode = (CUDAGraphMode.PIECEWISE
+                                       if self.use_cuda_graph else
+                                       CUDAGraphMode.NONE)
+
         self.cudagraph_batch_sizes = list(
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
```
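For illustration (not part of the diff): the new `cudagraph_runtime_mode` attribute is derived from the existing `use_cuda_graph` condition. A minimal sketch of that selection rule, with stand-in enums and a hypothetical `select_drafter_cudagraph_mode` helper:

```python
# Illustrative sketch only: the mode-selection rule from the hunk above,
# written as a pure function over stand-in enums (not vLLM's own classes).
from enum import Enum, auto


class CompilationLevel(Enum):   # stand-in for vllm.config.CompilationLevel
    NO_COMPILATION = auto()
    PIECEWISE = auto()


class CUDAGraphMode(Enum):      # stand-in for vllm.config.CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()


def select_drafter_cudagraph_mode(level: CompilationLevel,
                                  enforce_eager: bool) -> CUDAGraphMode:
    """Piecewise cudagraphs only when piecewise compilation is enabled and
    eager mode is not forced, mirroring EagleProposer.__init__ above."""
    use_cuda_graph = (level == CompilationLevel.PIECEWISE and not enforce_eager)
    return CUDAGraphMode.PIECEWISE if use_cuda_graph else CUDAGraphMode.NONE


assert select_drafter_cudagraph_mode(CompilationLevel.PIECEWISE, False) \
    == CUDAGraphMode.PIECEWISE
assert select_drafter_cudagraph_mode(CompilationLevel.PIECEWISE, True) \
    == CUDAGraphMode.NONE
```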
```diff
@@ -212,9 +216,12 @@ class EagleProposer:
             inputs_embeds = None
             input_ids = self.input_ids[:num_input_tokens]
 
-        with set_forward_context(per_layer_attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
+        with set_forward_context(
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+        ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.positions[:num_input_tokens],
```
```diff
@@ -322,9 +329,12 @@ class EagleProposer:
             input_ids = self.input_ids[:input_batch_size]
 
             # Run the model.
-            with set_forward_context(per_layer_attn_metadata,
-                                     self.vllm_config,
-                                     num_tokens=input_batch_size):
+            with set_forward_context(
+                    per_layer_attn_metadata,
+                    self.vllm_config,
+                    num_tokens=input_batch_size,
+                    cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+            ):
                 ret_hidden_states = self.model(
                     input_ids=input_ids,
                     positions=self.positions[:input_batch_size],
```
```diff
@@ -478,9 +488,12 @@ class EagleProposer:
         else:
             num_input_tokens = num_tokens
         # Run the model.
-        with set_forward_context(per_layer_attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
+        with set_forward_context(
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+        ):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
```
```diff
@@ -664,9 +677,15 @@ class EagleProposer:
     def dummy_run(
         self,
         num_tokens: int,
+        use_cudagraphs=True,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                cudagraph_runtime_mode=self.cudagraph_runtime_mode \
+                    if use_cudagraphs else CUDAGraphMode.NONE,
+        ):
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
```
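For illustration (not part of the diff): `dummy_run` now takes a `use_cudagraphs` flag that downgrades the runtime mode to `NONE` when disabled, which is what the model-runner warmup in the next hunk relies on. A minimal sketch of that toggle, with a stand-in enum and a hypothetical `dummy_run_mode` helper:

```python
# Illustrative sketch only: the use_cudagraphs toggle from dummy_run above,
# shown with a stand-in enum; dummy_run_mode is not vLLM API.
from enum import Enum, auto


class CUDAGraphMode(Enum):      # stand-in for vllm.config.CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()


def dummy_run_mode(configured_mode: CUDAGraphMode,
                   use_cudagraphs: bool = True) -> CUDAGraphMode:
    """Mode passed to set_forward_context during a dummy run: the proposer's
    configured mode by default, or NONE when the caller opts out."""
    return configured_mode if use_cudagraphs else CUDAGraphMode.NONE


assert dummy_run_mode(CUDAGraphMode.PIECEWISE) == CUDAGraphMode.PIECEWISE
assert dummy_run_mode(CUDAGraphMode.PIECEWISE, use_cudagraphs=False) \
    == CUDAGraphMode.NONE
```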
```diff
@@ -2997,7 +2997,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            # For warmup runs don't use cudagraphs in drafter
+            self.drafter.dummy_run(num_tokens, use_cudagraphs=False)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
```