Compare commits

...

2 Commits

SHA1        Message                                  Date

a1e3c09cba  wip                                      2025-09-17 22:41:43 +00:00
            Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

90d24dee04  enable piecewise cudagraphs for eagle    2025-09-17 20:48:14 +00:00
            Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
3 changed files with 40 additions and 13 deletions

View File

@@ -321,6 +321,13 @@ def set_forward_context(
                                       attn_metadata, num_tokens or 0,
                                       num_tokens_across_dp)

+    # Convenience: if cudagraphs are used and num_tokens is given, we can just
+    # create a batch descriptor here if not given (there's no harm since if it
+    # doesn't match in the wrapper it'll fall through).
+    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+        batch_descriptor = batch_descriptor or BatchDescriptor(
+            num_tokens=num_tokens)
+
     forward_context = create_forward_context(attn_metadata, vllm_config,
                                              virtual_engine, dp_metadata,
                                              cudagraph_runtime_mode,
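
For reference, the fall-through behaviour described in the new comment can be summarized with a small standalone sketch. The names mirror the diff, but the enum and dataclass below are simplified stand-ins, not the actual vllm.config.CUDAGraphMode or BatchDescriptor types:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class CUDAGraphMode(Enum):
    # simplified stand-in for vllm.config.CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()


@dataclass(frozen=True)
class BatchDescriptor:
    # simplified stand-in keyed only on the token count
    num_tokens: int


def resolve_batch_descriptor(
        cudagraph_runtime_mode: CUDAGraphMode,
        num_tokens: Optional[int],
        batch_descriptor: Optional[BatchDescriptor],
) -> Optional[BatchDescriptor]:
    # Create a descriptor from num_tokens when cudagraphs are in play and the
    # caller did not pass one; an explicit descriptor always takes precedence.
    # A descriptor that matches no captured graph simply falls through to
    # eager execution in the cudagraph wrapper, so over-creating is harmless.
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        return batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
    return batch_descriptor


assert resolve_batch_descriptor(CUDAGraphMode.PIECEWISE, 128,
                                None) == BatchDescriptor(num_tokens=128)
assert resolve_batch_descriptor(CUDAGraphMode.NONE, 128, None) is None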

View File

@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn

 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
@@ -78,6 +78,10 @@ class EagleProposer:
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
+        self.cudagraph_runtime_mode = (CUDAGraphMode.PIECEWISE
+                                       if self.use_cuda_graph else
+                                       CUDAGraphMode.NONE)
+
         self.cudagraph_batch_sizes = list(
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -212,9 +216,12 @@ class EagleProposer:
             inputs_embeds = None
             input_ids = self.input_ids[:num_input_tokens]

-        with set_forward_context(per_layer_attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
+        with set_forward_context(
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+        ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.positions[:num_input_tokens],
@@ -322,9 +329,12 @@ class EagleProposer:
                 input_ids = self.input_ids[:input_batch_size]

             # Run the model.
-            with set_forward_context(per_layer_attn_metadata,
-                                     self.vllm_config,
-                                     num_tokens=input_batch_size):
+            with set_forward_context(
+                    per_layer_attn_metadata,
+                    self.vllm_config,
+                    num_tokens=input_batch_size,
+                    cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+            ):
                 ret_hidden_states = self.model(
                     input_ids=input_ids,
                     positions=self.positions[:input_batch_size],
@@ -478,9 +488,12 @@ class EagleProposer:
         else:
             num_input_tokens = num_tokens
         # Run the model.
-        with set_forward_context(per_layer_attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
+        with set_forward_context(
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+        ):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
@@ -664,9 +677,15 @@ class EagleProposer:
     def dummy_run(
         self,
         num_tokens: int,
+        use_cudagraphs=True,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                cudagraph_runtime_mode=self.cudagraph_runtime_mode \
+                if use_cudagraphs else CUDAGraphMode.NONE,
+        ):
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]

View File

@@ -2997,7 +2997,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            # For warmup runs don't use cudagraphs in drafter
+            self.drafter.dummy_run(num_tokens, use_cudagraphs=False)

         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
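
To make the intent of the warmup change concrete, here is a rough standalone sketch of the toggle, using a hypothetical ToyDrafter in place of EagleProposer and keeping only the mode-selection logic (not the vLLM implementation):

from enum import Enum, auto


class CUDAGraphMode(Enum):
    # simplified stand-in for vllm.config.CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()


class ToyDrafter:
    # hypothetical stand-in for EagleProposer

    def __init__(self, use_cuda_graph: bool) -> None:
        self.cudagraph_runtime_mode = (CUDAGraphMode.PIECEWISE
                                       if use_cuda_graph else
                                       CUDAGraphMode.NONE)

    def dummy_run(self, num_tokens: int,
                  use_cudagraphs: bool = True) -> CUDAGraphMode:
        # Warmup dummy runs pass use_cudagraphs=False so the draft model runs
        # eagerly; cudagraph-capture dummy runs keep the configured mode.
        return (self.cudagraph_runtime_mode
                if use_cudagraphs else CUDAGraphMode.NONE)


drafter = ToyDrafter(use_cuda_graph=True)
assert drafter.dummy_run(256, use_cudagraphs=False) is CUDAGraphMode.NONE  # warmup
assert drafter.dummy_run(256) is CUDAGraphMode.PIECEWISE  # capture-time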