[Core] Make encoder-decoder inputs a nested structure to be more composable (#9604)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
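Note on the change: the previous encoder-decoder path used a single flat dict that mixed decoder and encoder fields (plus a multi_modal_data entry), and each Sequence was told which half to read via a from_decoder_prompt flag. This PR nests the two halves so that each one is an ordinary singleton input built with token_inputs. A rough before/after sketch with illustrative values only (the exact TypedDict fields live in vllm.inputs):

    from vllm.inputs import EncoderDecoderInputs, token_inputs

    # Before: one flat dict mixing decoder and encoder fields.
    old_inputs = {
        "prompt": "0 1 2",
        "prompt_token_ids": [0, 1, 2],
        "encoder_prompt": "2 1 0",
        "encoder_prompt_token_ids": [2, 1, 0],
        "multi_modal_data": None,
    }

    # After: one nested dict; each half is a plain token-inputs dict, so the
    # same code paths can consume either side without extra flags.
    new_inputs: EncoderDecoderInputs = {
        "decoder": token_inputs([0, 1, 2], prompt="0 1 2"),
        "encoder": token_inputs([2, 1, 0], prompt="2 1 0"),
    }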
@@ -4,6 +4,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 
 from vllm import SamplingParams
+from vllm.inputs import EncoderDecoderInputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob, Sequence, SequenceGroup
@@ -27,10 +28,7 @@ def create_dummy_prompt(
     prompt_tokens = list(range(prompt_length))
     prompt_str = " ".join([str(t) for t in prompt_tokens])
     prompt = Sequence(int(request_id),
-                      inputs={
-                          "prompt": prompt_str,
-                          "prompt_token_ids": prompt_tokens,
-                      },
+                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
                       block_size=block_size)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[prompt],
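For readers unfamiliar with the helper: token_inputs is, roughly, a small factory in vllm.inputs that builds the per-sequence token dict, filling in only the fields it is given. A minimal sketch (not the real implementation, which may carry additional fields such as multimodal data):

    # Illustrative approximation of what token_inputs(prompt_tokens,
    # prompt=prompt_str) produces; the actual TypedDict is defined in
    # vllm.inputs.
    def token_inputs_sketch(prompt_token_ids, prompt=None):
        inputs = {"prompt_token_ids": prompt_token_ids}
        if prompt is not None:
            inputs["prompt"] = prompt
        return inputs

The call in the hunk above therefore plays the same role as the hand-written {"prompt": ..., "prompt_token_ids": ...} dict it replaces.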
@@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
     encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
 
-    inputs = {
-        "prompt": decoder_prompt_str,
-        "prompt_token_ids": decoder_prompt_tokens,
-        "encoder_prompt": encoder_prompt_str,
-        "encoder_prompt_token_ids": encoder_prompt_tokens,
-        "multi_modal_data": None,
+    inputs: EncoderDecoderInputs = {
+        "decoder": token_inputs(decoder_prompt_tokens,
+                                prompt=decoder_prompt_str),
+        "encoder": token_inputs(encoder_prompt_tokens,
+                                prompt=encoder_prompt_str),
     }
 
     decoder_prompt = Sequence(int(request_id),
-                              inputs=inputs,
-                              block_size=block_size,
-                              from_decoder_prompt=True)
+                              inputs=inputs["decoder"],
+                              block_size=block_size)
 
     encoder_prompt = Sequence(int(request_id),
-                              inputs=inputs,
-                              block_size=block_size,
-                              from_decoder_prompt=False)
+                              inputs=inputs["encoder"],
+                              block_size=block_size)
 
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
                               sampling_params=SamplingParams(best_of=best_of),
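The nesting is also what makes the from_decoder_prompt flag unnecessary: each Sequence now receives only the half it owns. A hypothetical helper illustrating how the two halves can be consumed uniformly (make_seq and the role string are assumptions for this sketch, not part of the PR):

    from vllm.inputs import EncoderDecoderInputs
    from vllm.sequence import Sequence

    def make_seq(request_id: int, inputs: EncoderDecoderInputs, role: str,
                 block_size: int) -> Sequence:
        # role is "decoder" or "encoder"; the Sequence sees a plain
        # token-inputs dict either way, so no per-sequence flag is needed.
        return Sequence(request_id, inputs=inputs[role], block_size=block_size)

Both sequences constructed in the hunk above follow exactly this pattern, which is the composability the commit title refers to.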
@@ -108,7 +104,7 @@ def create_seq_group(
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         seq = Sequence(
             seq_id=seq_id_start + seq_id_offset,
-            inputs={"prompt_token_ids": prompt_token_ids},
+            inputs=token_inputs(prompt_token_ids),
             block_size=16,
         )
@@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(
 
     prompt_token_ids = [0] * seq_prompt_len
 
-    inputs = {
-        "prompt": "",
-        "prompt_token_ids": prompt_token_ids,
-        "encoder_prompt": "",
-        "encoder_prompt_token_ids": prompt_token_ids,
-        "multi_modal_data": None,
+    inputs: EncoderDecoderInputs = {
+        "decoder": token_inputs(prompt_token_ids),
+        "encoder": token_inputs(prompt_token_ids),
     }
 
     seqs = []
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         # Construct decoder input sequences
-        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
-                       inputs=inputs,
-                       block_size=16,
-                       from_decoder_prompt=True)
+        seq = Sequence(
+            seq_id=seq_id_start + seq_id_offset,
+            inputs=inputs["decoder"],
+            block_size=16,
+        )
 
         for i in range(output_len):
             seq.append_token_id(
@@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder(
         seqs.append(seq)
 
     # Encoder input sequence
-    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
-                           inputs=inputs,
-                           block_size=16,
-                           from_decoder_prompt=False)
+    encoder_seq = Sequence(
+        seq_id=seq_id_start + len(seq_output_lens),
+        inputs=inputs["encoder"],
+        block_size=16,
+    )
 
     return SequenceGroup(request_id=request_id,
                          seqs=seqs,