Support beam search & parallel generation (#7)

Woosuk Kwon
2023-03-10 09:58:21 -08:00
committed by GitHub
parent 04e5acc08e
commit 1a7eb7da61
16 changed files with 660 additions and 161 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple
from transformers import AutoTokenizer
@@ -25,12 +25,35 @@ class Frontend:
def query(
self,
prompt: str,
sampling_params: Optional[SamplingParams] = None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 1.0,
use_beam_search: bool = False,
stop_token_ids: Set[int] = set(),
max_num_steps: int = 16, # From OpenAI API.
num_logprobs: int = 0,
context_window_size: Optional[int] = None,
) -> None:
if sampling_params is None:
sampling_params = SamplingParams()
token_ids: List[int] = self.tokenizer.encode(prompt)
# Stop when we see an EOS token.
stop_token_ids.add(self.tokenizer.eos_token_id)
sampling_params = SamplingParams(
n=n,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop_token_ids=stop_token_ids,
max_num_steps=max_num_steps,
num_logprobs=num_logprobs,
context_window_size=context_window_size,
)
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)
def _add_query(
self,
token_ids: List[int],
sampling_params: SamplingParams,
) -> None:
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
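The rewritten query() takes the sampling knobs as explicit keyword arguments, folds the tokenizer's EOS token into stop_token_ids, and then fans the prompt out into n sibling sequences that form one group (the parallel-generation part of this change). Below is a minimal, self-contained sketch of that flow with simplified stand-ins for SamplingParams and the tokenizer; every name in it is illustrative, not the repository's exact code.

```python
import itertools
from dataclasses import dataclass, field
from typing import List, Optional, Set

@dataclass
class SamplingParams:
    # Simplified stand-in for cacheflow.sampling_params.SamplingParams.
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0
    use_beam_search: bool = False
    stop_token_ids: Set[int] = field(default_factory=set)
    max_num_steps: int = 16
    num_logprobs: int = 0

seq_counter = itertools.count()   # mirrors Frontend's self.seq_counter
EOS_TOKEN_ID = 2                  # illustrative; the real value comes from the tokenizer

def query(prompt_token_ids: List[int], *, n: int = 1,
          use_beam_search: bool = False,
          stop_token_ids: Optional[Set[int]] = None) -> List[dict]:
    # Build SamplingParams from explicit keyword arguments; always stop on EOS.
    stop = set(stop_token_ids or ())
    stop.add(EOS_TOKEN_ID)
    params = SamplingParams(n=n, use_beam_search=use_beam_search,
                            stop_token_ids=stop)
    # Parallel generation: one sequence per requested sample, sharing the prompt.
    return [{"seq_id": next(seq_counter), "token_ids": list(prompt_token_ids)}
            for _ in range(params.n)]

print(query([5, 7, 11], n=4, use_beam_search=True))   # four sibling sequences
```

The sketch copies the incoming stop_token_ids before adding EOS, so the caller's set (or a shared default) is never mutated in place.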

View File

@@ -1,10 +1,12 @@
from typing import Dict, List, Tuple
from typing import Dict, List
from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus
_MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,7 +68,7 @@ class Scheduler:
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
@@ -74,7 +76,10 @@
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
blocks_to_copy[src_block] = dst_block
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
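blocks_to_copy widens from Dict[int, int] to Dict[int, List[int]] because, with beam search, several forked sequences can grow past the same shared block in one step, so a single source block may need to be copied to more than one destination. A tiny illustrative sketch of that accumulation (the block numbers are made up):

```python
from typing import Dict, List

blocks_to_copy: Dict[int, List[int]] = {}

def record_copy(src_block: int, dst_block: int) -> None:
    # Same accumulation as Scheduler._append: one source block can fan out
    # to several destination blocks within a single scheduling step.
    if src_block in blocks_to_copy:
        blocks_to_copy[src_block].append(dst_block)
    else:
        blocks_to_copy[src_block] = [dst_block]

# Two forked sequences both grow past block 3 in the same step.
record_copy(3, 17)
record_copy(3, 18)
assert blocks_to_copy == {3: [17, 18]}
```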
def _swap_in(
self,
@@ -83,9 +88,8 @@
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.SWAPPED:
seq.status = SequenceStatus.RUNNING
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
def _swap_out(
@@ -96,16 +100,15 @@
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.RUNNING:
seq.status = SequenceStatus.SWAPPED
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)
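Both swap paths now ask the group for just the sequences in the relevant state via get_seqs(status=...) instead of filtering inside the loop. A minimal sketch of what such a status-filtered accessor might look like; the classes below are simplified stand-ins, not cacheflow's Sequence/SequenceGroup:

```python
from enum import Enum, auto
from typing import List, Optional

class SequenceStatus(Enum):
    PENDING = auto()
    RUNNING = auto()
    SWAPPED = auto()
    FINISHED = auto()

class Seq:
    def __init__(self, seq_id: int, status: SequenceStatus) -> None:
        self.seq_id = seq_id
        self.status = status

class Group:
    def __init__(self, seqs: List[Seq]) -> None:
        self.seqs = seqs

    def get_seqs(self, status: Optional[SequenceStatus] = None) -> List[Seq]:
        # Return all sequences, or only those in the requested state.
        if status is None:
            return self.seqs
        return [seq for seq in self.seqs if seq.status == status]

group = Group([Seq(0, SequenceStatus.SWAPPED), Seq(1, SequenceStatus.FINISHED)])
for seq in group.get_seqs(status=SequenceStatus.SWAPPED):
    seq.status = SequenceStatus.RUNNING        # swap-in transition
assert [s.status for s in group.seqs] == [SequenceStatus.RUNNING,
                                          SequenceStatus.FINISHED]
```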
def step(self) -> None:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
# 1. Reserve new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@
# All swapped sequences are swapped in.
self.swapped.clear()
# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
@@ -152,7 +159,6 @@
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
if not self.swapped:
# FIXME(woosuk): Acquire a lock to protect pending.
self._fetch_inputs()
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,39 +174,45 @@
else:
self.pending.clear()
# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
# 4. Create input data structures.
prompt_tokens: Dict[int, List[int]] = {}
generation_tokens: Dict[int, int] = {}
context_lens: Dict[int, int] = {}
block_tables: Dict[int, List[int]] = {}
input_seq_groups: List[SequenceGroupInputs] = []
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]
# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
for seq in seq_group.seqs:
if seq.status != SequenceStatus.RUNNING:
continue
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
prompt_tokens[seq_id] = seq.get_token_ids()
input_tokens[seq_id] = seq.get_token_ids()
else:
generation_tokens[seq_id] = seq.get_token_ids()[-1]
context_lens[seq_id] = seq.get_len()
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()
input_seq_group = SequenceGroupInputs(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)
# 5. Execute the first stage of the pipeline.
self.controllers[0].execute_stage(
prompt_tokens,
generation_tokens,
context_lens,
block_tables,
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
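The four parallel dictionaries previously handed to execute_stage (prompt tokens, generation tokens, context lengths, block tables) are folded into one SequenceGroupInputs record per group: per-sequence input tokens and cumulative log-probabilities, a single shared context length, the group's sampling parameters, and per-sequence block tables. A simplified, self-contained sketch of that record for a generation step; the field names follow the diff, but the concrete values and the dict standing in for SamplingParams are illustrative:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class SequenceGroupInputs:
    # Simplified stand-in for cacheflow.sequence.SequenceGroupInputs.
    group_id: int
    is_prompt: bool
    input_tokens: Dict[int, List[int]]   # seq_id -> token ids fed this step
    context_len: int                     # shared by all seqs in the group
    seq_logprobs: Dict[int, float]       # seq_id -> cumulative logprob
    sampling_params: dict                # stands in for SamplingParams
    block_tables: Dict[int, List[int]]   # seq_id -> physical block numbers

# A generation (non-prompt) step: each running sequence contributes only its
# last token, its running logprob, and its block table.
inputs = SequenceGroupInputs(
    group_id=0,
    is_prompt=False,
    input_tokens={0: [42], 1: [17]},
    context_len=9,
    seq_logprobs={0: -3.2, 1: -4.1},
    sampling_params={"n": 2, "use_beam_search": True},
    block_tables={0: [0, 5], 1: [0, 6]},
)
print(inputs)
```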
@@ -208,7 +220,7 @@ class Scheduler:
def post_step(
self,
next_tokens: Dict[int, Tuple[int, int]],
seq_outputs: Dict[int, SequenceOutputs],
) -> None:
# Update the running sequences and free blocks.
for seq_group in self.running:
@@ -216,25 +228,32 @@
self.num_steps[group_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids
# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
parent_seq_id, next_token = next_tokens[seq.seq_id]
if seq.seq_id != parent_seq_id:
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(parent_seq_id)
seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Append a new token to the sequence.
seq.append([next_token])
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)
# Check if the sequence has generated a stop token.
if next_token in stop_token_ids:
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
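
post_step() now consumes SequenceOutputs records instead of bare (parent_id, token) tuples. When an output's parent_seq_id differs from the slot it lands in, beam search has re-parented that slot: its blocks are freed, it is forked from the parent, and only afterwards does every live sequence get its new token and log-probability appended. A compact, self-contained sketch of that fork-then-append flow, using simplified stand-in classes rather than the repository's exact ones:

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SequenceOutputs:
    # Simplified stand-in for cacheflow.sequence.SequenceOutputs.
    seq_id: int
    parent_seq_id: int          # which existing beam this token extends
    output_token: int
    logprobs: Dict[int, float]  # token id -> log-probability

@dataclass
class Seq:
    seq_id: int
    token_ids: List[int] = field(default_factory=list)
    cumulative_logprobs: float = 0.0

    def fork_from(self, parent: "Seq") -> None:
        # Inherit the parent's tokens and score before the new token is appended.
        self.token_ids = list(parent.token_ids)
        self.cumulative_logprobs = parent.cumulative_logprobs

    def append(self, token: int, logprob: float) -> None:
        self.token_ids.append(token)
        self.cumulative_logprobs += logprob

seqs = {0: Seq(0, [5, 9], -1.0), 1: Seq(1, [5, 3], -2.5)}
# The sampler decided that both surviving beams should extend sequence 0.
outputs = {
    0: SequenceOutputs(0, parent_seq_id=0, output_token=7, logprobs={7: -0.4}),
    1: SequenceOutputs(1, parent_seq_id=0, output_token=8, logprobs={8: -0.9}),
}
# First pass: resolve beam-search forks before any token is appended,
# mirroring the order post_step() uses above.
for seq in seqs.values():
    out = outputs[seq.seq_id]
    if seq.seq_id != out.parent_seq_id:
        seq.fork_from(seqs[out.parent_seq_id])
# Second pass: append each sampled token and its logprob.
for seq in seqs.values():
    out = outputs[seq.seq_id]
    seq.append(out.output_token, out.logprobs[out.output_token])
print([s.token_ids for s in seqs.values()])   # [[5, 9, 7], [5, 9, 8]]
```

Doing the forks in a separate first pass matters: a child must copy its parent's state before the parent receives its own new token, which is why the diff processes "beam search results before processing the next tokens."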