[New Model] mBART model (#22883)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART.

This script is refactored to allow model selection via command-line arguments.
"""

import argparse
from typing import NamedTuple, Optional

from vllm import LLM, SamplingParams
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
    TextPrompt,
    TokensPrompt,
    zip_enc_dec_prompts,
)

class ModelRequestData(NamedTuple):
    """
    Holds the configuration for a specific model, including its
    HuggingFace ID and the prompts to use for the demo.
    """

    model_id: str
    encoder_prompts: list
    decoder_prompts: list
    hf_overrides: Optional[dict] = None


def get_bart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/bart-large-cnn.
    This uses the exact test cases from the original script.
    """
    encoder_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "An encoder prompt",
    ]
    decoder_prompts = [
        "A decoder prompt",
        "Another decoder prompt",
    ]
    return ModelRequestData(
        model_id="facebook/bart-large-cnn",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
    )


def get_mbart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/mbart-large-en-ro.
    This uses prompts suitable for an English-to-Romanian translation task.
    """
    encoder_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "How are you today?",
    ]
    decoder_prompts = ["", ""]
    # Override the architecture in the HuggingFace config so that vLLM
    # resolves its MBartForConditionalGeneration implementation.
    hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
    return ModelRequestData(
        model_id="facebook/mbart-large-en-ro",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
        hf_overrides=hf_overrides,
    )


MODEL_GETTERS = {
    "bart": get_bart_config,
    "mbart": get_mbart_config,
}

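# A hypothetical sketch of how a third model could be wired in: define a
# getter that returns ModelRequestData and register it under a short name.
# The model ID and helper below are illustrative only; this commit does not
# add them.
#
#     def get_mbart50_config() -> ModelRequestData:
#         return ModelRequestData(
#             model_id="facebook/mbart-large-50-many-to-many-mmt",
#             encoder_prompts=["A sentence to translate."],
#             decoder_prompts=[""],
#             hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
#         )
#
#     MODEL_GETTERS["mbart50"] = get_mbart50_config
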
def create_all_prompt_types(
    encoder_prompts_raw: list,
    decoder_prompts_raw: list,
    tokenizer,
) -> list:
    """
    Generates a list of diverse prompt types for demonstration.
    This function is generic and uses the provided raw prompts
    to create various vLLM input objects.
    """
    # - Helpers for building prompts
    text_prompt_raw = encoder_prompts_raw[0]
    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(
            encoder_prompts_raw[2 % len(encoder_prompts_raw)]
        )
    )
    decoder_tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
    )

    # - Pass a single prompt to the encoder/decoder model
    #   (implicitly the encoder input prompt);
    #   the decoder input prompt is assumed to be None
    single_prompt_examples = [
        text_prompt_raw,  # Pass a string directly
        text_prompt,  # Pass a TextPrompt
        tokens_prompt,  # Pass a TokensPrompt
    ]

    # - Pass explicit encoder and decoder input prompts within one data
    #   structure. Encoder and decoder prompts can both independently be text
    #   or tokens, with no requirement that they be the same prompt type.
    #   Some example prompt-type combinations are shown below; note that
    #   these are not exhaustive.
    explicit_pair_examples = [
        ExplicitEncoderDecoderPrompt(
            # Pass encoder prompt string directly, and
            # pass decoder prompt tokens
            encoder_prompt=text_prompt_raw,
            decoder_prompt=decoder_tokens_prompt,
        ),
        ExplicitEncoderDecoderPrompt(
            # Pass TextPrompt to encoder, and
            # pass decoder prompt string directly
            encoder_prompt=text_prompt,
            decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
        ),
        ExplicitEncoderDecoderPrompt(
            # Pass encoder prompt tokens directly, and
            # pass TextPrompt to decoder
            encoder_prompt=tokens_prompt,
            decoder_prompt=text_prompt,
        ),
    ]

    # - Finally, here's a useful helper function for zipping encoder and
    #   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
    #   instances
    zipped_prompt_list = zip_enc_dec_prompts(
        encoder_prompts_raw,
        decoder_prompts_raw,
    )

    # Put all of the above example prompts together into one list
    # which we will pass to the encoder/decoder LLM.
    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list

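# For reference, zip_enc_dec_prompts pairs each encoder prompt with the
# decoder prompt at the same index (zip-style, truncating to the shorter
# list). Assuming that pairing, the last two entries returned for the BART
# config above would be roughly:
#
#     ExplicitEncoderDecoderPrompt(
#         encoder_prompt="Hello, my name is",
#         decoder_prompt="A decoder prompt",
#     )
#     ExplicitEncoderDecoderPrompt(
#         encoder_prompt="The president of the United States is",
#         decoder_prompt="Another decoder prompt",
#     )
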
def create_sampling_params() -> SamplingParams:
    """Create a sampling params object."""
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=30,
    )

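# Note on create_sampling_params: temperature=0 selects greedy decoding, so
# for a given model the demo output should be deterministic across runs.
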
def print_outputs(outputs: list):
    """Formats and prints the generation outputs."""
    print("-" * 80)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {i + 1}:")
        print(f"Encoder Prompt: {encoder_prompt!r}")
        print(f"Decoder Prompt: {prompt!r}")
        print(f"Generated Text: {generated_text!r}")
        print("-" * 80)


def main(args):
    """Main execution function."""
    model_key = args.model
    if model_key not in MODEL_GETTERS:
        raise ValueError(
            f"Unknown model: {model_key}. "
            f"Available models: {list(MODEL_GETTERS.keys())}"
        )

    config_getter = MODEL_GETTERS[model_key]
    model_config = config_getter()

    # Create an encoder/decoder model instance
    print(f"🚀 Running demo for model: {model_config.model_id}")
    llm = LLM(
        model=model_config.model_id,
        dtype="float",
        hf_overrides=model_config.hf_overrides,
    )

    # Get the tokenizer and build the full set of demo prompts
    tokenizer = llm.llm_engine.get_tokenizer_group()
    prompts = create_all_prompt_types(
        encoder_prompts_raw=model_config.encoder_prompts,
        decoder_prompts_raw=model_config.decoder_prompts,
        tokenizer=tokenizer,
    )
    sampling_params = create_sampling_params()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    print_outputs(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A flexible demo for vLLM encoder-decoder models."
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="bart",
        choices=MODEL_GETTERS.keys(),
        help="The short name of the model to run.",
    )
    args = parser.parse_args()
    main(args)
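A quick usage sketch, assuming the script lives at the path below (the exact
location of this example in the repo may differ):

    python examples/offline_inference/encoder_decoder.py              # defaults to BART
    python examples/offline_inference/encoder_decoder.py -m mbart     # English-to-Romanian mBART demo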