OpenAI Compatible Frontend (#116)
@@ -1,6 +1,6 @@
 import copy
 import enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from cacheflow.block import LogicalTokenBlock
 from cacheflow.sampling_params import SamplingParams
@@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum):
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
-    FINISHED = enum.auto()
+    FINISHED_STOPPED = enum.auto()
+    FINISHED_LENGTH_CAPPED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "SequenceStatus") -> bool:
+        return status in [
+            SequenceStatus.FINISHED_STOPPED,
+            SequenceStatus.FINISHED_LENGTH_CAPPED,
+        ]
+
+    @staticmethod
+    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
+        if status == SequenceStatus.FINISHED_STOPPED:
+            finish_reason = "stop"
+        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
+            finish_reason = "length"
+        else:
+            finish_reason = None
+        return finish_reason


 class SequenceData:
@@ -20,7 +37,6 @@ class SequenceData:
         prompt_token_ids: List[int],
     ) -> None:
         self.prompt_token_ids = prompt_token_ids
-
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

@@ -166,7 +182,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
+        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
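For context, a minimal usage sketch of the new helpers: a frontend can map a sequence's status to the OpenAI-style "finish_reason" field. The import path cacheflow.sequence and the helper finish_reason_for are assumptions for illustration; only SequenceStatus and its static methods come from the diff above.

# Sketch only: assumes SequenceStatus lives in cacheflow.sequence
# (the diff does not show the file path).
from cacheflow.sequence import SequenceStatus

def finish_reason_for(status: SequenceStatus):
    # Hypothetical helper: unfinished sequences report no reason.
    if not SequenceStatus.is_finished(status):
        return None
    # "stop" when a stop condition fired, "length" when the length cap
    # was reached, matching the OpenAI completion API's finish_reason.
    return SequenceStatus.get_finished_reason(status)

print(finish_reason_for(SequenceStatus.FINISHED_STOPPED))        # "stop"
print(finish_reason_for(SequenceStatus.FINISHED_LENGTH_CAPPED))  # "length"
print(finish_reason_for(SequenceStatus.RUNNING))                 # None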