[Experimental] Prefix Caching Support (#1669)

Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
shiyi.c_98
2024-01-17 16:32:10 -08:00
committed by GitHub
parent 14cc317ba4
commit d10f8e1d43
20 changed files with 1356 additions and 71 deletions

View File

@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -234,7 +238,8 @@ def test_sampler_logits_processors(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sample_probs = None