[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)

SangBin Cho authored on 2024-04-26 22:02:02 +09:00, committed by GitHub
parent a88081bf76
commit 603ad84815
18 changed files with 859 additions and 630 deletions

@@ -6,6 +6,7 @@
 import pytest
 import torch
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
@@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     logits_processor_output = logits_processor(
         embedding=None,
         hidden_states=input_tensor,
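
For readers following the refactor, below is a sketch of what the new call site looks like end to end. Only the SamplingMetadata.prepare call and its arguments are taken from the diff above; the surrounding setup (batch_size, the SequenceGroupMetadata construction, and the model_runner fixture) mirrors the test this hunk comes from and should be treated as illustrative, not authoritative.

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata

batch_size = 4  # illustrative value; the real test parameterizes this

# Build one prompt-only sequence group per request, mirroring the test setup.
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
    seq_group_metadata_list.append(
        SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0.0),
            block_tables={0: [1]}))
    prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

# New entry point introduced by this commit: SamplingMetadata.prepare replaces
# the private ModelRunner._prepare_sample helper, so device and pin_memory are
# passed explicitly (model_runner here is assumed to be the test fixture).
sampling_metadata = SamplingMetadata.prepare(
    seq_group_metadata_list,
    prompt_lens,
    subquery_lens=prompt_lens,  # equal to prompt_lens when the prompt is not chunked
    device=model_runner.device,
    pin_memory=model_runner.pin_memory)

Decoupling metadata preparation from ModelRunner is what lets the same code path serve chunked prefill, where each scheduled chunk needs its own subquery lengths and prompt logprobs, per the commit title.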