@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -15,12 +14,12 @@ else:
|
||||
|
||||
|
||||
class LogprobsLists(NamedTuple):
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: np.ndarray
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprobs: np.ndarray
|
||||
# [num_reqs]
|
||||
sampled_token_ranks: np.ndarray
|
||||
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
|
||||
logprob_token_ids: list[list[int]]
|
||||
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
|
||||
logprobs: list[list[float]]
|
||||
# [num_reqs x num_generated_tokens]
|
||||
sampled_token_ranks: list[int]
|
||||
# [num_reqs]
|
||||
# Used for slicing the logprobs in cases like speculative
|
||||
# decoding where the number of generated tokens may be
|
||||
@ -54,9 +53,9 @@ class LogprobsTensors(NamedTuple):
|
||||
|
||||
def tolists(self, cu_num_generated_tokens: list[int] | None = None):
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids.cpu().numpy(),
|
||||
self.logprobs.cpu().numpy(),
|
||||
self.selected_token_ranks.cpu().numpy(),
|
||||
self.logprob_token_ids.tolist(),
|
||||
self.logprobs.tolist(),
|
||||
self.selected_token_ranks.tolist(),
|
||||
cu_num_generated_tokens,
|
||||
)
|
||||
|
||||
@ -130,6 +129,7 @@ class KVConnectorOutput:
|
||||
|
||||
|
||||
# ModelRunnerOutput is serialized and sent to the scheduler process.
|
||||
# Serializing a torch.Tensor is expensive, so prefer plain lists instead.
|
||||
@dataclass
|
||||
class ModelRunnerOutput:
|
||||
# [num_reqs]
|
||||
|
||||
Reference in New Issue
Block a user