# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Usage:
Single node:
python examples/offline_inference/data_parallel.py \
--model="ibm-research/PowerMoE-3b" \
--dp-size=2 \
--tp-size=2
Multi-node:
Node 0 (assume the node has ip of 10.99.48.128):
python examples/offline_inference/data_parallel.py \
--model="ibm-research/PowerMoE-3b" \
--dp-size=2 \
--tp-size=2 \
--node-size=2 \
--node-rank=0 \
--master-addr=10.99.48.128 \
--master-port=13345
Node 1:
python examples/offline_inference/data_parallel.py \
--model="ibm-research/PowerMoE-3b" \
--dp-size=2 \
--tp-size=2 \
--node-size=2 \
--node-rank=1 \
--master-addr=10.99.48.128 \
--master-port=13345
"""
import os
from time import sleep

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser, get_open_port


def parse_args():
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.set_defaults(model="ibm-research/PowerMoE-3b")
    parser.add_argument("--dp-size",
                        type=int,
                        default=2,
                        help="Data parallel size")
    parser.add_argument("--tp-size",
                        type=int,
                        default=2,
                        help="Tensor parallel size")
    parser.add_argument("--node-size",
                        type=int,
                        default=1,
                        help="Total number of nodes")
    parser.add_argument("--node-rank",
                        type=int,
                        default=0,
                        help="Rank of the current node")
    parser.add_argument("--master-addr",
                        type=str,
                        default="",
                        help="Master node IP address")
    parser.add_argument("--master-port",
                        type=int,
                        default=0,
                        help="Master node port")
    return parser.parse_args()


def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         dp_master_port, GPUs_per_dp_rank):
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 5
    # import random
    # import string
    # prompts = [''.join(random.choices(string.ascii_letters, k=128))
    #            for _ in range(2048)]

    # With DP, each rank should process different prompts.
    # Usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    floor = len(prompts) // dp_size
    remainder = len(prompts) % dp_size

    # Distribute prompts into even groups.
    def start(rank):
        return rank * floor + min(rank, remainder)
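    # For example: with the 20 prompts above and dp_size=3, floor=6 and
    # remainder=2, so start() evaluates to 0, 7, 14, 20 for rank = 0..3 and
    # the three ranks receive 7, 7, and 6 prompts (the first `remainder`
    # ranks each take one extra prompt).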
    prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
    if len(prompts) == 0:
        # If any rank has no prompts to process,
        # we need to set a placeholder prompt.
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object.
    # Since we are doing data parallel, every rank can have different
    # sampling params. Here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(
        temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
    )

    # These two are set explicitly below, so drop the CLI-provided values to
    # avoid passing duplicate keyword arguments to the LLM constructor.
    args.pop("tensor_parallel_size")
    args.pop("enable_expert_parallel")

    # Create an LLM.
    llm = LLM(
        tensor_parallel_size=GPUs_per_dp_rank,
        enable_expert_parallel=True,
        **args,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for i, output in enumerate(outputs):
        if i >= 5:
            # Print only 5 outputs per rank.
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )

    # Give engines time to pause their processing loops before exiting.
    sleep(1)
if __name__ == "__main__":
args = vars(parse_args())
dp_size = args.pop("dp_size")
tp_size = args.pop("tp_size")
node_size = args.pop("node_size")
node_rank = args.pop("node_rank")
if node_size == 1:
dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port()
args.pop("master_addr")
args.pop("master_port")
else:
dp_master_ip = args.pop("master_addr")
dp_master_port = args.pop("master_port")
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
dp_per_node = dp_size // node_size
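    # For example: dp_size=4 across node_size=2 nodes gives dp_per_node=2, so
    # node 0 launches global DP ranks 0-1 and node 1 launches ranks 2-3, each
    # as local ranks 0-1 on its own node.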

    from multiprocessing import Process

    procs = []
    for local_dp_rank, global_dp_rank in enumerate(
            range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
        proc = Process(target=main,
                       args=(
                           args,
                           dp_size,
                           local_dp_rank,
                           global_dp_rank,
                           dp_master_ip,
                           dp_master_port,
                           tp_size,
                       ))
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        proc.join(timeout=1200)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that didn't stop within 20 minutes.")
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    exit(exit_code)