Implement presence and frequency penalties (#95)
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_all_reduce_launcher,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceGroupMetadata
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
|
||||
SequenceOutputs)
|
||||
from cacheflow.worker.cache_engine import CacheEngine
|
||||
|
||||
|
||||
@ -72,7 +72,6 @@ class Worker:
|
||||
self.cache_events = self.cache_engine.events
|
||||
self.gpu_cache = self.cache_engine.gpu_cache
|
||||
|
||||
|
||||
def init_distributed_environment(self,
|
||||
distributed_init_method: str,
|
||||
rank: int,
|
||||
@ -96,7 +95,6 @@ class Worker:
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
||||
seq_logprobs: Dict[int, float] = {}
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
@ -107,15 +105,15 @@ class Worker:
|
||||
if not seq_group_metadata.is_prompt:
|
||||
continue
|
||||
|
||||
seq_ids = list(seq_group_metadata.input_tokens.keys())
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
seq_logprobs.update(seq_group_metadata.seq_logprobs)
|
||||
|
||||
# Use any sequence in the group.
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
prompt_tokens = seq_group_metadata.input_tokens[seq_id]
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
prompt_len = len(prompt_tokens)
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
@ -141,27 +139,26 @@ class Worker:
|
||||
if seq_group_metadata.is_prompt:
|
||||
continue
|
||||
|
||||
seq_ids = list(seq_group_metadata.input_tokens.keys())
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
seq_logprobs.update(seq_group_metadata.seq_logprobs)
|
||||
|
||||
for seq_id in seq_ids:
|
||||
assert len(seq_group_metadata.input_tokens[seq_id]) == 1
|
||||
generation_token = seq_group_metadata.input_tokens[seq_id][0]
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append(generation_token)
|
||||
|
||||
position = seq_group_metadata.context_len - 1
|
||||
context_len = seq_data.get_len()
|
||||
position = context_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
generation_block_tables.append(block_table)
|
||||
|
||||
max_context_len = max(
|
||||
max_context_len, seq_group_metadata.context_len)
|
||||
max_context_len = max(max_context_len, context_len)
|
||||
max_num_blocks_per_seq = max(
|
||||
max_num_blocks_per_seq, len(block_table))
|
||||
context_lens.append(seq_group_metadata.context_len)
|
||||
context_lens.append(context_len)
|
||||
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
@ -188,9 +185,13 @@ class Worker:
|
||||
block_tables_tensor = torch.tensor(
|
||||
padded_block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_logprobs=seq_logprobs,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
context_lens=context_lens_tensor,
|
||||
|
||||
Reference in New Issue
Block a user