Compare commits
1 Commits
use-uv-pyt
...
low_latenc
| Author | SHA1 | Date | |
|---|---|---|---|
| 79acf80471 |
@ -10,12 +10,12 @@ prompts = [
|
|||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
]
|
]
|
||||||
# Create a sampling params object.
|
# Create a sampling params object.
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m")
|
llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
|
||||||
# Generate texts from the prompts.
|
# Generate texts from the prompts.
|
||||||
# The output is a list of RequestOutput objects
|
# The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
|
|||||||
@ -85,6 +85,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_MOE_PADDING: bool = True
|
VLLM_ROCM_MOE_PADDING: bool = True
|
||||||
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
||||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||||
|
VLLM_ENABLE_V1_ADVANCE_STEP: bool = False
|
||||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||||
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
||||||
Q_SCALE_CONSTANT: int = 200
|
Q_SCALE_CONSTANT: int = 200
|
||||||
@ -600,6 +601,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
|
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
|
||||||
"VLLM_DISABLE_COMPILE_CACHE":
|
"VLLM_DISABLE_COMPILE_CACHE":
|
||||||
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
|
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
|
||||||
|
"VLLM_ENABLE_V1_ADVANCE_STEP":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_ADVANCE_STEP", "0"))),
|
||||||
|
|
||||||
# If set, vllm will run in development mode, which will enable
|
# If set, vllm will run in development mode, which will enable
|
||||||
# some additional endpoints for developing and debugging,
|
# some additional endpoints for developing and debugging,
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -36,6 +37,9 @@ class BlockTable:
|
|||||||
self.block_table_np = self.block_table_cpu.numpy()
|
self.block_table_np = self.block_table_cpu.numpy()
|
||||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
|
|
||||||
|
self.prev_num_reqs = 0
|
||||||
|
self.is_updated = True
|
||||||
|
|
||||||
def append_row(
|
def append_row(
|
||||||
self,
|
self,
|
||||||
block_ids: list[int],
|
block_ids: list[int],
|
||||||
@ -48,16 +52,22 @@ class BlockTable:
|
|||||||
self.num_blocks_per_row[row_idx] += num_blocks
|
self.num_blocks_per_row[row_idx] += num_blocks
|
||||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||||
|
|
||||||
|
self.is_updated = True
|
||||||
|
|
||||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||||
self.num_blocks_per_row[row_idx] = 0
|
self.num_blocks_per_row[row_idx] = 0
|
||||||
self.append_row(block_ids, row_idx)
|
self.append_row(block_ids, row_idx)
|
||||||
|
|
||||||
|
self.is_updated = True
|
||||||
|
|
||||||
def move_row(self, src: int, tgt: int) -> None:
|
def move_row(self, src: int, tgt: int) -> None:
|
||||||
num_blocks = self.num_blocks_per_row[src]
|
num_blocks = self.num_blocks_per_row[src]
|
||||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
||||||
src, :num_blocks]
|
src, :num_blocks]
|
||||||
self.num_blocks_per_row[tgt] = num_blocks
|
self.num_blocks_per_row[tgt] = num_blocks
|
||||||
|
|
||||||
|
self.is_updated = True
|
||||||
|
|
||||||
def swap_row(self, src: int, tgt: int) -> None:
|
def swap_row(self, src: int, tgt: int) -> None:
|
||||||
num_blocks_src = self.num_blocks_per_row[src]
|
num_blocks_src = self.num_blocks_per_row[src]
|
||||||
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
||||||
@ -66,14 +76,28 @@ class BlockTable:
|
|||||||
|
|
||||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||||
|
|
||||||
|
self.is_updated = True
|
||||||
|
|
||||||
def commit(self, num_reqs: int) -> None:
|
def commit(self, num_reqs: int) -> None:
|
||||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
|
||||||
non_blocking=True)
|
# Incremental copy
|
||||||
|
if self.prev_num_reqs != num_reqs or self.is_updated:
|
||||||
|
self.block_table[:num_reqs].copy_(
|
||||||
|
self.block_table_cpu[:num_reqs], non_blocking=True)
|
||||||
|
|
||||||
|
self.prev_num_reqs = num_reqs
|
||||||
|
self.is_updated = False
|
||||||
|
else:
|
||||||
|
# Always copy
|
||||||
|
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.block_table.fill_(0)
|
self.block_table.fill_(0)
|
||||||
self.block_table_cpu.fill_(0)
|
self.block_table_cpu.fill_(0)
|
||||||
|
|
||||||
|
self.is_updated = True
|
||||||
|
|
||||||
def get_device_tensor(self) -> torch.Tensor:
|
def get_device_tensor(self) -> torch.Tensor:
|
||||||
"""Ruturns the device tensor of the block table."""
|
"""Ruturns the device tensor of the block table."""
|
||||||
return self.block_table
|
return self.block_table
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.attention import AttentionType, get_attn_backend
|
from vllm.attention import AttentionType, get_attn_backend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import (CompilationLevel, VllmConfig,
|
from vllm.config import (CompilationLevel, VllmConfig,
|
||||||
@ -142,6 +143,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
weakref.proxy(self))
|
weakref.proxy(self))
|
||||||
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
|
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
|
||||||
|
|
||||||
|
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
|
||||||
|
logger.info("Advance_step is enabled")
|
||||||
|
if self.cascade_attn_enabled:
|
||||||
|
logger.warning(
|
||||||
|
"Disabling cascade attn (since advance_step is on)")
|
||||||
|
self.cascade_attn_enabled = False
|
||||||
|
else:
|
||||||
|
logger.info("Advance_step is disabled")
|
||||||
|
|
||||||
# Multi-modal data support
|
# Multi-modal data support
|
||||||
self.mm_registry = MULTIMODAL_REGISTRY
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
self.uses_mrope = model_config.uses_mrope
|
self.uses_mrope = model_config.uses_mrope
|
||||||
@ -271,16 +281,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
self.slot_mapping_gpu = torch.zeros(self.max_num_tokens,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||||
|
self.query_start_loc_gpu = torch.zeros(self.max_num_reqs + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
self.seq_lens_gpu = torch.zeros(self.max_num_reqs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
# Cached
|
||||||
|
self.prev_num_reqs = 0
|
||||||
|
self.req_indices_gpu = torch.arange(self.max_num_reqs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
self.req_indices_block_table_offsets_gpu = (
|
||||||
|
self.req_indices_gpu * self.max_num_blocks_per_req)
|
||||||
|
|
||||||
|
self.num_scheduled_tokens_gpu = torch.ones(self.max_num_reqs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
self.cu_num_tokens_gpu = torch.cumsum(self.num_scheduled_tokens_gpu, 0)
|
||||||
|
|
||||||
|
self.query_start_loc_gpu[0] = 0
|
||||||
|
self.query_start_loc_gpu[1:self.max_num_reqs +
|
||||||
|
1] = self.cu_num_tokens_gpu
|
||||||
|
|
||||||
|
self.logits_indices_gpu = self.query_start_loc_gpu[1:] - 1
|
||||||
|
|
||||||
|
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||||
|
self.prev_attn_metadata = None
|
||||||
|
self.is_first_advance_decode = True
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
@ -485,6 +530,119 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if batch_changed or batch_reordered:
|
if batch_changed or batch_reordered:
|
||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
|
def _advance_decode_step(
|
||||||
|
self,
|
||||||
|
scheduler_output,
|
||||||
|
num_scheduled_tokens,
|
||||||
|
):
|
||||||
|
# print(" -- inside advance_decode_step")
|
||||||
|
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
assert num_reqs > 0
|
||||||
|
|
||||||
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
assert total_num_scheduled_tokens == num_reqs
|
||||||
|
|
||||||
|
# TODO: Add if needed
|
||||||
|
# Get request indices.
|
||||||
|
# E.g., num_reqs == 3 -> [0, 1, 2]
|
||||||
|
# req_indices_gpu = self.req_indices_gpu[:num_reqs]
|
||||||
|
# Get cu_sums
|
||||||
|
# cu_num_tokens = self.cu_num_tokens_gpu[:num_reqs]
|
||||||
|
|
||||||
|
# Increment positions
|
||||||
|
positions_gpu = self.positions[:total_num_scheduled_tokens]
|
||||||
|
positions_gpu[:total_num_scheduled_tokens] += 1
|
||||||
|
|
||||||
|
# TODO: Verify MROPE is ok here
|
||||||
|
# Calculate M-RoPE positions.
|
||||||
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
|
if self.uses_mrope:
|
||||||
|
self._calc_mrope_positions(scheduler_output)
|
||||||
|
|
||||||
|
# Set next tokens
|
||||||
|
# (prev iteration tokens are cached in prev_sampled_token_ids tensor)
|
||||||
|
assert self.prev_sampled_token_ids is not None
|
||||||
|
self.input_ids[:total_num_scheduled_tokens] = \
|
||||||
|
self.prev_sampled_token_ids[:,0]
|
||||||
|
|
||||||
|
# Calculate the slot mapping
|
||||||
|
block_table_indices_gpu = (
|
||||||
|
self.req_indices_block_table_offsets_gpu[:num_reqs] +
|
||||||
|
positions_gpu // self.block_size)
|
||||||
|
block_table_gpu = self.input_batch.block_table.get_device_tensor()
|
||||||
|
# Note: The block table tensor is async copied from CPU to GPU
|
||||||
|
# (inside the .commit() call) if was previously modified
|
||||||
|
block_numbers_gpu = block_table_gpu.flatten()[block_table_indices_gpu]
|
||||||
|
|
||||||
|
block_offsets_gpu = positions_gpu % self.block_size
|
||||||
|
|
||||||
|
slot_mapping_gpu = self.slot_mapping_gpu[:total_num_scheduled_tokens]
|
||||||
|
slot_mapping_gpu[:] = (block_numbers_gpu * self.block_size +
|
||||||
|
block_offsets_gpu)
|
||||||
|
|
||||||
|
# Prepare the attention metadata.
|
||||||
|
|
||||||
|
# query_start_loc is always the same for all decode iterations
|
||||||
|
query_start_loc_gpu = self.query_start_loc_gpu[:num_reqs + 1]
|
||||||
|
|
||||||
|
if self.uses_mrope:
|
||||||
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
|
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||||
|
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
|
# TODO: Add cascade attn support
|
||||||
|
# Verify cascade attention is disabled
|
||||||
|
assert not self.cascade_attn_enabled
|
||||||
|
|
||||||
|
# TODO: Add support for other attn backends
|
||||||
|
assert self.prev_attn_metadata is not None
|
||||||
|
assert isinstance(self.prev_attn_metadata, FlashAttentionMetadata)
|
||||||
|
|
||||||
|
attn_metadata = self.prev_attn_metadata
|
||||||
|
attn_metadata.max_seq_len += 1
|
||||||
|
attn_metadata.query_start_loc = query_start_loc_gpu
|
||||||
|
attn_metadata.seq_lens += 1
|
||||||
|
attn_metadata.slot_mapping = slot_mapping_gpu
|
||||||
|
|
||||||
|
# print("attn_metadata.seq_lens: shape = {} data = {}".format(
|
||||||
|
# attn_metadata.seq_lens.shape, attn_metadata.seq_lens))
|
||||||
|
|
||||||
|
use_spec_decode = len(
|
||||||
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||||
|
if not use_spec_decode:
|
||||||
|
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
||||||
|
# partial requests. While we should not sample any token
|
||||||
|
# from these partial requests, we do so for simplicity.
|
||||||
|
# We will ignore the sampled tokens from the partial requests.
|
||||||
|
# TODO: Support prompt logprobs.
|
||||||
|
logits_indices = self.logits_indices_gpu[:num_reqs]
|
||||||
|
spec_decode_metadata = None
|
||||||
|
else:
|
||||||
|
# TODO: Check if spec_decode can be enabled here
|
||||||
|
raise Exception("advance_step has no support for spec_decode yet")
|
||||||
|
# # Get the number of draft tokens for each request.
|
||||||
|
# # Iterate over the dictionary rather than all requests since
|
||||||
|
# # not all requests have draft tokens.
|
||||||
|
# num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||||
|
# for req_id, draft_token_ids in (
|
||||||
|
# scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||||
|
# req_idx = self.input_batch.req_id_to_index[req_id]
|
||||||
|
# num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||||
|
|
||||||
|
# spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||||
|
# num_draft_tokens, cu_num_tokens)
|
||||||
|
# logits_indices = spec_decode_metadata.logits_indices
|
||||||
|
|
||||||
|
# Hot-Swap lora model
|
||||||
|
if self.lora_config:
|
||||||
|
# TODO: Check if this works
|
||||||
|
raise Exception("advance_step has no LORA support yet")
|
||||||
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||||
|
|
||||||
|
return attn_metadata, logits_indices, spec_decode_metadata
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@ -505,6 +663,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||||
max_num_scheduled_tokens = max(tokens)
|
max_num_scheduled_tokens = max(tokens)
|
||||||
|
|
||||||
|
# Determine if advance step can be used
|
||||||
|
use_spec_decode = len(
|
||||||
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||||
|
|
||||||
|
is_flash_attn = self.prev_attn_metadata is not None and isinstance(
|
||||||
|
self.prev_attn_metadata, FlashAttentionMetadata)
|
||||||
|
|
||||||
|
is_advance_decode = (envs.VLLM_ENABLE_V1_ADVANCE_STEP
|
||||||
|
and self.prev_num_reqs == num_reqs
|
||||||
|
and max_num_scheduled_tokens == 1
|
||||||
|
and not use_spec_decode
|
||||||
|
and not self.cascade_attn_enabled
|
||||||
|
and is_flash_attn)
|
||||||
|
|
||||||
|
if is_advance_decode:
|
||||||
|
if self.is_first_advance_decode:
|
||||||
|
# The first time advance_step can be used,
|
||||||
|
# we run the usual prepare, so that positions tensor
|
||||||
|
# is initialized
|
||||||
|
self.is_first_advance_decode = False
|
||||||
|
else:
|
||||||
|
# This is the fast-path advance_step
|
||||||
|
# (all tensors are on the GPU and are updated on the GPU)
|
||||||
|
(attn_metadata, logits_indices,
|
||||||
|
spec_decode_metadata) = self._advance_decode_step(
|
||||||
|
scheduler_output, num_scheduled_tokens)
|
||||||
|
return attn_metadata, logits_indices, spec_decode_metadata
|
||||||
|
else:
|
||||||
|
self.is_first_advance_decode = True
|
||||||
|
|
||||||
|
self.prev_num_reqs = num_reqs
|
||||||
|
|
||||||
# Get request indices.
|
# Get request indices.
|
||||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||||
@ -523,6 +713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||||
|
|
||||||
# Get positions.
|
# Get positions.
|
||||||
|
|
||||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||||
arange,
|
arange,
|
||||||
@ -599,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
)
|
)
|
||||||
|
self.prev_attn_metadata = attn_metadata
|
||||||
|
|
||||||
use_spec_decode = len(
|
use_spec_decode = len(
|
||||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||||
@ -1177,6 +1369,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Get the valid generated tokens.
|
# Get the valid generated tokens.
|
||||||
sampled_token_ids = sampler_output.sampled_token_ids
|
sampled_token_ids = sampler_output.sampled_token_ids
|
||||||
|
self.prev_sampled_token_ids = sampled_token_ids
|
||||||
|
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
if max_gen_len == 1:
|
if max_gen_len == 1:
|
||||||
# No spec decode tokens.
|
# No spec decode tokens.
|
||||||
|
|||||||
Reference in New Issue
Block a user