Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-10-30 16:40:09 -07:00
parent 5666a25efb
commit 5c8049d990

View File

@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch
if TYPE_CHECKING:
@ -15,12 +14,12 @@ else:
class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: np.ndarray
# [num_reqs, max_num_logprobs + 1]
logprobs: np.ndarray
# [num_reqs]
sampled_token_ranks: np.ndarray
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: list[list[int]]
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: list[list[float]]
# [num_reqs x num_generated_tokens]
sampled_token_ranks: list[int]
# [num_reqs]
# Used for slicing the logprobs in cases like speculative
# decoding where the number of generated tokens may be
@ -54,9 +53,9 @@ class LogprobsTensors(NamedTuple):
# Build a LogprobsLists from this object's logprob_token_ids, logprobs and
# selected_token_ranks, forwarding cu_num_generated_tokens unchanged as the
# final argument.
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were stripped, so two versions of the first three arguments
# appear together. Presumably the three `.cpu().numpy()` lines are the removed
# side and the three `.tolist()` lines are their replacements (the hunk header
# reports 9 lines before and 9 after) — confirm against the commit itself.
def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists(
# Variant that moves each value to CPU and converts it to a NumPy array.
self.logprob_token_ids.cpu().numpy(),
self.logprobs.cpu().numpy(),
self.selected_token_ranks.cpu().numpy(),
# Variant that converts each value to (nested) plain Python lists.
self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
# Passed through as-is; may be None (the parameter default).
cu_num_generated_tokens,
)
@ -130,6 +129,7 @@ class KVConnectorOutput:
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
# [num_reqs]