OpenAI Compatible Frontend (#116)

This commit is contained in:
Zhuohan Li
2023-05-23 21:39:50 -07:00
committed by GitHub
parent e86717833d
commit 057daef778
20 changed files with 644 additions and 169 deletions

View File

@ -1,6 +1,6 @@
import copy
import enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from cacheflow.block import LogicalTokenBlock
from cacheflow.sampling_params import SamplingParams
@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum):
WAITING = enum.auto()
RUNNING = enum.auto()
SWAPPED = enum.auto()
FINISHED = enum.auto()
FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto()
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
return status in [
SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED,
]
@staticmethod
def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
if status == SequenceStatus.FINISHED_STOPPED:
finish_reason = "stop"
elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
finish_reason = "length"
else:
finish_reason = None
return finish_reason
class SequenceData:
@ -20,7 +37,6 @@ class SequenceData:
prompt_token_ids: List[int],
) -> None:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprob = 0.0
@ -166,7 +182,7 @@ class SequenceGroup:
raise ValueError(f'Sequence {seq_id} not found.')
def is_finished(self) -> bool:
return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "