[Experimental] Prefix Caching Support (#1669)
Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
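
Note: the hunks below thread a subquery_lens argument through the sampler tests. With prefix caching, the tokens actually computed for a prompt (the subquery) can be fewer than the full prompt, so ModelRunner._prepare_sample is now given both lengths. A minimal sketch of the updated call pattern, assuming the fixtures (model_runner, sampler, seq_group_metadata_list, input_tensor, prompt_lens) set up earlier in each test:

# Sketch only; mirrors the calls in the hunks below. In these prompt-only
# tests no prefix is cached, so the subquery length equals the prompt length.
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens,
                                                 subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
                         hidden_states=input_tensor,
                         sampling_metadata=sampling_metadata)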
@@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)
@@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)
@@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
             sampling_metadata=sampling_metadata)
@@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)
@@ -234,7 +238,8 @@ def test_sampler_logits_processors(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
                              sampling_metadata=sampling_metadata)
@@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
 
     sample_probs = None