[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)
@@ -6,6 +6,7 @@ import pytest
 import torch
 
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
@@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     logits_processor_output = logits_processor(
         embedding=None,
         hidden_states=input_tensor,
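For context, a minimal sketch of the call-site change this hunk makes. It assumes `model_runner`, `seq_group_metadata_list`, and `prompt_lens` are already constructed as earlier in `test_logits_processors`; the removed `ModelRunner._prepare_sample` call is shown only as a comment for comparison, and the rationale comment is inferred from the commit title.

    # Sketch, assuming model_runner, seq_group_metadata_list, and
    # prompt_lens are built as in test_logits_processors above.
    from vllm.model_executor.sampling_metadata import SamplingMetadata

    # Old (removed) helper on ModelRunner:
    #   sampling_metadata = model_runner._prepare_sample(
    #       seq_group_metadata_list, prompt_lens,
    #       subquery_lens=prompt_lens)

    # New factory method: preparation is moved off ModelRunner
    # (per the commit title, so the sampler refactor can also serve
    # prompt logprobs under chunked prefill), with device placement
    # and memory pinning passed explicitly.
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        prompt_lens,
        subquery_lens=prompt_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)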