Rename variables and methods (#91)

This commit is contained in:
Woosuk Kwon
2023-05-10 00:58:31 -07:00
committed by GitHub
parent ce26e57fd3
commit 8d66a7b6d7
7 changed files with 64 additions and 83 deletions

View File

@ -97,15 +97,15 @@ class BlockSpaceManager:
for seq in seq_group.seqs:
self.block_tables[seq.seq_id] = block_table.copy()
def can_append(self, seq_group: SequenceGroup) -> bool:
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for the new token."""
def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]
@ -156,7 +156,7 @@ class BlockSpaceManager:
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
# NOTE: Conservatively, we assume that every sequence will allocate
# at least one free block right after the swap-in.
# NOTE: This should match the logic in can_append().
# NOTE: This should match the logic in can_append_slot().
num_required_blocks = len(blocks) + num_swapped_seqs
return num_free_blocks - num_required_blocks >= self.watermark_blocks

View File

@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceGroupMetadata
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus
@ -105,7 +105,7 @@ class Scheduler:
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.pop(0)
while not self.block_manager.can_append(seq_group):
while not self.block_manager.can_append_slot(seq_group):
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop(-1)
@ -119,7 +119,7 @@ class Scheduler:
break
else:
# Append new slots to the sequence group.
self._append(seq_group, blocks_to_copy)
self._append_slot(seq_group, blocks_to_copy)
running.append(seq_group)
self.running = running
@ -143,7 +143,7 @@ class Scheduler:
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
self._append_slot(seq_group, blocks_to_copy)
self.running.append(seq_group)
num_batched_tokens = sum(
@ -252,7 +252,7 @@ class Scheduler:
prompt_group_ids = scheduler_output[3]
# Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running:
@ -274,7 +274,7 @@ class Scheduler:
# sequence length
seq_len = seq.get_len()
input_seq_group = SequenceGroupInputs(
seq_group_metadata = SequenceGroupMetadata(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
@ -283,14 +283,14 @@ class Scheduler:
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)
seq_group_metadata_list.append(seq_group_metadata)
# Execute the first stage of the pipeline.
if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.controllers[0].execute_stage(
input_seq_groups,
seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
@ -330,7 +330,7 @@ class Scheduler:
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)
seq.append_token(output.output_token, output.logprobs)
# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
@ -360,13 +360,13 @@ class Scheduler:
if seq_group.group_id not in self.num_steps:
self.num_steps[seq_group.group_id] = 0
def _append(
def _append_slot(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
ret = self.block_manager.append(seq)
ret = self.block_manager.append_slot(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy: