[Feature] Expert Parallelism Load Balancer (EPLB) (#18343)
Signed-off-by: Bowen Wang <abmfy@icloud.com>
This commit is contained in:
292
tests/distributed/test_eplb_algo.py
Normal file
292
tests/distributed/test_eplb_algo.py
Normal file
@ -0,0 +1,292 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.eplb.rebalance_algo import rebalance_experts
|
||||
|
||||
|
||||
def test_basic_rebalance():
    """Rebalance the reference example from https://github.com/deepseek-ai/eplb
    and check shapes, value ranges, replica totals, and the exact mapping
    produced by the deterministic algorithm.
    """
    # Two layers of per-expert load statistics, 12 logical experts each.
    weight = torch.tensor([
        [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
        [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
    ])

    num_layers = weight.shape[0]
    num_replicas = 16
    num_groups = 4
    num_nodes = 2
    num_gpus = 8

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Shape checks: 16 physical slots and 12 logical experts per layer.
    assert phy2log.shape == (
        2, 16), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
    assert (log2phy.shape[0] == 2
            ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
    assert (
        log2phy.shape[1] == 12
    ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
    assert logcnt.shape == (
        2, 12), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"

    # Every physical slot must point at a valid logical expert id.
    assert torch.all(phy2log >= 0) and torch.all(
        phy2log < 12), "Physical to logical mapping should be in range [0, 12)"

    # Each logical expert keeps at least one replica, and the replica
    # counts add up to the requested total across all layers.
    assert torch.all(
        logcnt >= 1), "Each logical expert should have at least 1 replica"
    assert (
        torch.sum(logcnt, dim=1).sum() == num_replicas *
        num_layers), f"Total replicas should be {num_replicas * num_layers}"

    # Pin the exact output of the (deterministic) algorithm.
    expected_phy2log = torch.tensor([
        [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
        [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
    ])
    assert torch.all(phy2log == expected_phy2log)

    expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
                                    [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
    assert torch.all(logcnt == expected_logcnt)
|
||||
|
||||
|
||||
def test_single_gpu_case():
    """Degenerate setup: one node, one group, one GPU, no redundancy."""
    weight = torch.tensor([[10, 20, 30, 40]])
    num_replicas = 4
    num_groups = 1
    num_nodes = 1
    num_gpus = 1

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # One layer; 4 physical slots back 4 logical experts.
    assert phy2log.shape == (1, 4)
    assert log2phy.shape[0] == 1
    assert log2phy.shape[1] == 4
    assert logcnt.shape == (1, 4)

    # Without redundancy each logical expert must appear exactly once.
    assert set(phy2log[0].tolist()) == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_equal_weights():
    """With a perfectly uniform load, no expert should be replicated."""
    weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]])
    num_replicas = 8
    num_groups = 2
    num_nodes = 2
    num_gpus = 4

    phy2log, _, logcnt = rebalance_experts(weight, num_replicas, num_groups,
                                           num_nodes, num_gpus)

    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 8)

    # 8 replicas for 8 experts leaves no slots for redundancy, so every
    # expert keeps exactly one copy.
    assert torch.all(
        logcnt == 1
    ), "With equal weights and no replication, " \
        "each expert should have exactly 1 replica"
|
||||
|
||||
|
||||
def test_extreme_weight_imbalance():
    """A single overloaded expert should attract the extra replicas."""
    # Expert 0 carries ~99% of the load.
    weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]])
    num_replicas = 12
    num_groups = 2
    num_nodes = 2
    num_gpus = 4

    phy2log, _, logcnt = rebalance_experts(weight, num_replicas, num_groups,
                                           num_nodes, num_gpus)

    # 12 physical slots backing 8 logical experts.
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)

    # The hot expert must receive more replicas than a cold one.
    assert (
        logcnt[0, 0]
        > logcnt[0, 1]), "Expert with highest weight should have more replicas"
|
||||
|
||||
|
||||
def test_multiple_layers():
    """Each layer is rebalanced independently of the others.

    Uses three layers with increasing, decreasing, and uniform load
    patterns and checks per-layer mapping validity and replica totals.
    """
    weight = torch.tensor([
        [10, 20, 30, 40, 50, 60],  # First layer
        [60, 50, 40, 30, 20, 10],  # Second layer (opposite weight pattern)
        [25, 25, 25, 25, 25, 25],  # Third layer (equal weights)
    ])
    num_replicas = 8
    num_groups = 2
    num_nodes = 2
    num_gpus = 4

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify shapes: 3 layers, 8 physical slots, 6 logical experts.
    assert phy2log.shape == (3, 8)
    assert logcnt.shape == (3, 6)

    # Verify expert allocation is reasonable for each layer.
    for layer in range(3):
        # Fix: the original message concatenated
        # "...mapping" + "should..." without a separating space.
        assert torch.all(phy2log[layer] >= 0) and torch.all(
            phy2log[layer] < 6
        ), f"Layer {layer} physical to logical mapping " \
            "should be in range [0, 6)"
        assert (torch.sum(logcnt[layer]) == num_replicas
                ), f"Layer {layer} total replicas should be {num_replicas}"
|
||||
|
||||
|
||||
def test_parameter_validation():
    """Invalid parameter combinations are either tolerated or rejected."""
    weight = torch.tensor([[10, 20, 30, 40]])

    # num_groups=3 does not divide num_nodes=2: this is handled without
    # raising, because the function falls back to the global
    # load-balancing strategy.
    phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 4)

    # num_physical_experts=7 not divisible by num_gpus=4 is a hard error.
    with pytest.raises(AssertionError):
        rebalance_experts(weight, 7, 2, 2, 4)  # 7 not divisible by 4
|
||||
|
||||
|
||||
def test_small_scale_hierarchical():
    """Hierarchical balancing on a small 2-node, 4-GPU topology."""
    # 8 experts in 4 groups of 2, spread over 2 nodes and 4 GPUs.
    weight = torch.tensor([[100, 50, 200, 75, 150, 25, 300, 80]])
    num_replicas = 12
    num_groups = 4
    num_nodes = 2
    num_gpus = 4

    phy2log, _, logcnt = rebalance_experts(weight, num_replicas, num_groups,
                                           num_nodes, num_gpus)

    # Basic constraints: shapes, replica total, at least one copy each.
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)
    assert torch.sum(logcnt) == num_replicas
    assert torch.all(logcnt >= 1)

    # The hottest expert must be replicated at least twice.
    max_weight_expert = torch.argmax(weight[0])
    assert (logcnt[0, max_weight_expert]
            >= 2), "Highest weight expert should have multiple replicas"
|
||||
|
||||
|
||||
def test_global_load_balance_fallback():
    """num_groups % num_nodes != 0 triggers the global balancing fallback."""
    weight = torch.tensor([[10, 20, 30, 40, 50, 60]])
    num_replicas = 8
    # 3 groups cannot be divided evenly across 2 nodes, so the function
    # should silently fall back to global load balancing.
    num_groups = 3
    num_nodes = 2
    num_gpus = 4

    phy2log, _, logcnt = rebalance_experts(weight, num_replicas, num_groups,
                                           num_nodes, num_gpus)

    # The fallback still produces a well-formed assignment.
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 6)
    assert torch.sum(logcnt) == num_replicas
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_device_compatibility(device):
    """Inputs on different devices are accepted without error."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weight = torch.tensor([[10, 20, 30, 40]], device=device)
    num_replicas = 6
    num_groups = 2
    num_nodes = 1
    num_gpus = 2

    # The implementation moves data to CPU internally, so both CPU and
    # CUDA inputs should be handled transparently.
    phy2log, _, logcnt = rebalance_experts(weight, num_replicas, num_groups,
                                           num_nodes, num_gpus)

    assert phy2log.shape == (1, 6)
    assert logcnt.shape == (1, 4)
|
||||
|
||||
|
||||
def test_additional_cases():
    """Extra edge cases and parameter combinations."""

    # Case 1: large-scale setup — 16 experts, 24 slots, 8 groups,
    # 4 nodes, 8 GPUs.
    weight1 = torch.tensor(
        [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
    phy2log1, _, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)

    assert phy2log1.shape == (1, 24)
    assert logcnt1.shape == (1, 16)
    assert torch.sum(logcnt1) == 24

    # Case 2: mirrored weight distributions across two layers.
    weight2 = torch.tensor([
        [200, 150, 100, 50, 25, 12],  # Decreasing weights
        [12, 25, 50, 100, 150, 200],  # Increasing weights
    ])
    phy2log2, _, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)

    assert phy2log2.shape == (2, 10)
    assert logcnt2.shape == (2, 6)

    # The hottest expert of each layer must be replicated.
    for layer in range(2):
        hottest = torch.argmax(weight2[layer])
        assert logcnt2[layer, hottest] >= 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run mirroring the deepseek-ai/eplb reference example.
    weight = torch.tensor([
        [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
        [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
    ])

    num_replicas = 16
    num_groups = 4
    num_nodes = 2
    num_gpus = 8

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)
    print(phy2log)

    test_basic_rebalance()
|
||||
504
tests/distributed/test_eplb_execute.py
Normal file
504
tests/distributed/test_eplb_execute.py
Normal file
@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed.eplb.rebalance_execute import (
|
||||
rearrange_expert_weights_inplace)
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
get_tp_group,
|
||||
init_distributed_environment)
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
|
||||
def distributed_run(fn, world_size):
    """Spawn `world_size` worker processes running `fn` and wait for them.

    Each worker receives the torch.distributed rendezvous settings
    (RANK, WORLD_SIZE, MASTER_ADDR/PORT, ...) as a dict argument, since
    `multiprocessing.Process` cannot inherit per-rank environment directly.
    """
    workers: list[multiprocessing.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        proc = multiprocessing.Process(target=fn, args=(env, ))
        workers.append(proc)
        proc.start()

    for proc in workers:
        proc.join()

    # Any non-zero exit code (assertion failure, crash) fails the test.
    for proc in workers:
        assert proc.exitcode == 0
|
||||
|
||||
|
||||
def worker_fn_wrapper(fn):
    """Adapt `fn` into a `multiprocessing.Process` target.

    `multiprocessing.Process` cannot accept environment variables
    directly, so the wrapper receives them as a dict argument, applies
    them, binds the worker to its GPU, initializes the distributed
    environment, and seeds the RNGs identically on every rank before
    invoking `fn`.
    """

    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))
        init_distributed_environment()

        # Identical seeds so every rank derives the same random expert
        # indices and redundancy configs.
        random.seed(42)
        torch.manual_seed(42)

        fn()

    return wrapped_fn
|
||||
|
||||
|
||||
def create_expert_indices_with_redundancy(
    num_layers: int,
    num_logical_experts: int,
    total_physical_experts: int,
    redundancy_config: list[int],  # redundancy for each logical expert
) -> torch.Tensor:
    """
    Build physical-to-logical expert indices with redundancy.

    Logical expert `i` occupies `redundancy_config[i]` physical slots;
    each layer's slot order is shuffled independently.

    Args:
        num_layers: number of layers
        num_logical_experts: number of logical experts
        total_physical_experts: total number of physical experts
        redundancy_config: redundancy for each logical expert

    Returns:
        indices: Shape (num_layers, total_physical_experts)
    """
    assert sum(redundancy_config) == total_physical_experts
    assert len(redundancy_config) == num_logical_experts

    # One row of logical ids, each id repeated by its replica count.
    row = [
        logical_id for logical_id, count in enumerate(redundancy_config)
        for _ in range(count)
    ]
    indices = torch.tensor(row, dtype=torch.long).repeat(num_layers, 1)

    # Independently permute each layer's slot assignment.
    for layer in range(num_layers):
        perm = torch.randperm(total_physical_experts)
        indices[layer] = indices[layer][perm]

    return indices
|
||||
|
||||
|
||||
def create_expert_weights(
    num_layers: int,
    num_local_experts: int,
    hidden_sizes: list[int],
    rank: int,
    device: torch.device,
    physical_to_logical_mapping: torch.Tensor,
) -> list[list[torch.Tensor]]:
    """
    Build predictable fake expert weights for testing.

    The row for each local physical expert is an `arange` whose start is
    derived from its logical expert id, the layer, and the weight index,
    so all replicas of the same logical expert carry identical values.

    Args:
        physical_to_logical_mapping: Shape (num_layers, num_local_experts)
            mapping[layer, physical_pos] = logical_expert_id
    """
    expert_weights = []
    for layer in range(num_layers):
        layer_weights = []
        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = torch.zeros(num_local_experts,
                                        hidden_size,
                                        device=device,
                                        dtype=torch.float32)

            for local_expert in range(num_local_experts):
                # Position of this local expert in the global physical
                # expert numbering, then its logical id.
                global_pos = rank * num_local_experts + local_expert
                logical_id = int(
                    physical_to_logical_mapping[layer, global_pos].item())

                # Deterministic base value keyed on (logical id, layer,
                # weight matrix) so replicas of the same logical expert
                # are identical.
                base = logical_id * 1000 + layer * 100 + weight_idx * 10
                weight_tensor[local_expert] = torch.arange(
                    base,
                    base + hidden_size,
                    device=device,
                    dtype=torch.float32)

            layer_weights.append(weight_tensor)
        expert_weights.append(layer_weights)

    return expert_weights
|
||||
|
||||
|
||||
def create_redundancy_config(
    num_logical_experts: int,
    num_physical_experts: int,
) -> list[int]:
    """Randomly distribute physical expert slots over logical experts.

    Every logical expert gets at least one slot; each of the remaining
    `num_physical_experts - num_logical_experts` slots is assigned to a
    uniformly random logical expert.

    Returns:
        A list of length `num_logical_experts` whose entries are >= 1
        and sum to `num_physical_experts`.
    """
    redundancy_config = [1] * num_logical_experts
    remaining = num_physical_experts - num_logical_experts
    # Hand out the surplus slots one at a time, uniformly at random.
    # `randrange` replaces the original `choice(range(...))`, which built
    # a throwaway sequence for every draw; both draw from the same
    # underlying `_randbelow`, so seeded runs are unchanged.
    for _ in range(remaining):
        redundancy_config[random.randrange(num_logical_experts)] += 1
    return redundancy_config
|
||||
|
||||
|
||||
def verify_expert_weights_after_shuffle(
    expert_weights: list[list[torch.Tensor]],
    new_indices: torch.Tensor,
    hidden_sizes: list[int],
    ep_rank: int,
    num_local_experts: int,
):
    """Verify the weights after shuffling are correct.

    Each local expert's weights are compared against the `arange` pattern
    implied by `new_indices` (the same pattern `create_expert_weights`
    generates), so the check passes iff the rearrangement delivered the
    right logical expert to every physical slot on this rank.
    """
    num_layers = len(expert_weights)

    for layer in range(num_layers):
        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = expert_weights[layer][weight_idx]

            for local_expert in range(num_local_experts):
                # Global physical position of this local expert.
                global_pos = ep_rank * num_local_experts + local_expert
                expected_logical_expert = new_indices[layer, global_pos].item()

                # Reconstruct the expected deterministic pattern.
                actual_weights = weight_tensor[local_expert]
                expected_base = (expected_logical_expert * 1000 + layer * 100 +
                                 weight_idx * 10)
                expected_weights = torch.arange(expected_base,
                                                expected_base + hidden_size,
                                                device=actual_weights.device,
                                                dtype=actual_weights.dtype)

                # Fix: the original message's implicit concatenation ran
                # words together ("weight 0,local expert", "weightsdo").
                torch.testing.assert_close(
                    actual_weights,
                    expected_weights,
                    msg=f"Layer {layer}, weight {weight_idx}, "
                    f"local expert {local_expert}: "
                    f"weights do not match. "
                    f"Expected logical expert {expected_logical_expert}")
|
||||
|
||||
|
||||
def verify_redundant_experts_have_same_weights(
    expert_weights: list[list[torch.Tensor]],
    indices: torch.Tensor,
    hidden_sizes: list[int],
    world_size: int,
    num_local_experts: int,
):
    """
    Verify that all replicas of the same logical expert have the same weights.

    Must be called collectively on every rank: it uses
    `torch.distributed.all_gather` (on the default group) to assemble the
    full weight tables before comparing replicas.

    Args:
        expert_weights: per-layer list of per-weight-matrix tensors,
            each of shape [num_local_experts, hidden_size].
        indices: (num_layers, total_physical_experts) physical-to-logical
            expert mapping shared by all ranks.
        hidden_sizes: hidden size of each weight matrix.
        world_size: number of EP ranks participating in the gather.
        num_local_experts: physical experts held by each rank.
    """
    num_layers = len(expert_weights)
    total_physical_experts = world_size * num_local_experts

    for layer in range(num_layers):
        # Collect weights for all physical experts for each weight matrix
        all_weights: list[torch.Tensor] = []

        for weight_idx, hidden_size in enumerate(hidden_sizes):
            # Create tensor to store all expert weights
            # Shape: [total_physical_experts, hidden_size]
            gathered_weights = torch.zeros(
                total_physical_experts,
                hidden_size,
                device=expert_weights[layer][weight_idx].device,
                dtype=expert_weights[layer][weight_idx].dtype)

            # Use all_gather to collect expert weights from current node
            # expert_weights[layer][weight_idx] shape:
            # [num_local_experts, hidden_size]
            local_weights = expert_weights[layer][
                weight_idx]  # [num_local_experts, hidden_size]

            # Split tensor along dim 0 into a list for all_gather.
            # `torch.chunk` returns views into `gathered_weights`, so the
            # gathered data lands directly in the full buffer above.
            gathered_weights_list = torch.chunk(gathered_weights,
                                                world_size,
                                                dim=0)

            torch.distributed.all_gather(
                # Output list: each element corresponds to one rank's weights
                list(gathered_weights_list),
                local_weights  # Input: current rank's local weights
            )

            all_weights.append(gathered_weights)

        # Verify that all replicas of the same logical expert have the same
        # weights: the first replica encountered becomes the reference
        # that later replicas are compared against.
        logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {}

        for physical_pos in range(total_physical_experts):
            logical_expert_id = int(indices[layer, physical_pos].item())

            if logical_expert_id not in logical_expert_weights:
                # First time encountering this logical expert, save its weights
                logical_expert_weights[logical_expert_id] = {
                    weight_idx: all_weights[weight_idx][physical_pos]
                    for weight_idx in range(len(hidden_sizes))
                }
            else:
                # Verify that current physical expert's weights match the
                # previously saved logical expert weights
                for weight_idx in range(len(hidden_sizes)):
                    torch.testing.assert_close(
                        all_weights[weight_idx][physical_pos],
                        logical_expert_weights[logical_expert_id][weight_idx],
                        msg=f"Layer {layer}, weight {weight_idx},"
                        f"logical expert {logical_expert_id}: "
                        f"Physical expert {physical_pos} has different weights"
                        f"than expected")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        # 2 GPU, 2 experts per GPU
        # 3 logical experts, 4 physical experts, 1 redundant experts
        (2, 1, 2, 3),
        # 2 GPU, 3 experts per GPU
        # 4 logical experts, 6 physical experts, 2 redundant experts
        (2, 2, 3, 4),
        # 2 GPU, 8 experts per GPU
        # 16 logical experts, 16 physical experts, 0 redundant experts
        (2, 4, 8, 16),
        # 4 GPU, 2 experts per GPU
        # 6 logical experts, 8 physical experts, 2 redundant experts
        (4, 1, 2, 6),
        # 4 GPU, 2 experts per GPU
        # 5 logical experts, 8 physical experts, 3 redundant experts
        (4, 2, 2, 5),
        # 4 GPU, 8 experts per GPU
        # 16 logical experts, 32 physical experts, 16 redundant experts
        (4, 8, 8, 16),
    ])
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
                                                  num_local_experts,
                                                  num_logical_experts):
    """Test the functionality of rearranging expert weights with redundancy.

    Spawns `world_size` GPU workers; each builds deterministic fake expert
    weights from a random old placement, rearranges them to a new random
    placement in place, then checks both the per-slot weights and replica
    consistency across ranks.
    """

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        # Initialize model parallel (using tensor parallel as an entrypoint
        # to expert parallel)
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        # The TP CPU group doubles as the EP communication group here.
        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        # Test parameters
        total_physical_experts = world_size * num_local_experts
        hidden_sizes = [32, 64]  # Two different weight matrices

        # Create old expert indices (with redundancy).
        # NOTE: `worker_fn_wrapper` seeds every rank identically (42), so
        # the random configs/indices below are the same on all ranks.
        redundancy_config = create_redundancy_config(num_logical_experts,
                                                     total_physical_experts)

        old_indices = create_expert_indices_with_redundancy(
            num_layers,
            num_logical_experts,
            total_physical_experts,
            redundancy_config,
        )

        # Create new expert indices (with redundancy)
        new_redundancy_config = create_redundancy_config(
            num_logical_experts, total_physical_experts)
        new_indices = create_expert_indices_with_redundancy(
            num_layers,
            num_logical_experts,
            total_physical_experts,
            new_redundancy_config,
        )

        # Create expert weights (deterministic per logical expert, so
        # correctness can be checked after the shuffle)
        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               old_indices)

        # Execute weight rearrangement
        rearrange_expert_weights_inplace(
            old_indices,
            new_indices,
            expert_weights,
            ep_group,
            is_profile=False,
        )

        # Verify the rearrangement result on this rank's local slots
        verify_expert_weights_after_shuffle(
            expert_weights,
            new_indices,
            hidden_sizes,
            ep_rank,
            num_local_experts,
        )

        # Cross-rank check: all replicas of a logical expert agree
        verify_redundant_experts_have_same_weights(
            expert_weights,
            new_indices,
            hidden_sizes,
            world_size,
            num_local_experts,
        )

    distributed_run(worker_fn, world_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
    """
    Test that when the indices do not change, the weights should remain
    unchanged.
    """

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        # The TP CPU group doubles as the EP communication group here.
        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        num_layers = 2
        num_local_experts = 2
        total_physical_experts = world_size * num_local_experts
        num_logical_experts = total_physical_experts // 2  # Some redundancy
        hidden_sizes = [32, 64]

        # Create redundancy configuration: every logical expert has
        # exactly two replicas.
        redundancy_config = [2] * num_logical_experts

        # Same indices - no change.
        # (Identical on all ranks thanks to the shared seed set in
        # `worker_fn_wrapper`.)
        indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            redundancy_config)

        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               indices)

        # Save original weights (deep copy, since the rearrangement
        # mutates the tensors in place)
        original_weights = []
        for layer_weights in expert_weights:
            layer_copy = []
            for weight in layer_weights:
                layer_copy.append(weight.clone())
            original_weights.append(layer_copy)

        # Execute rearrangement (should be no change)
        rearrange_expert_weights_inplace(
            indices,
            indices,  # Same indices
            expert_weights,
            ep_group,
            is_profile=False)

        # Verify that the weights have not changed
        for layer in range(num_layers):
            for weight_idx in range(len(hidden_sizes)):
                torch.testing.assert_close(
                    expert_weights[layer][weight_idx],
                    original_weights[layer][weight_idx],
                    msg=f"Layer {layer}, weight {weight_idx} should remain "
                    f"unchanged")

    distributed_run(worker_fn, world_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_profile_mode(world_size):
    """Test profile mode (should not copy actual weights)"""

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        # The TP CPU group doubles as the EP communication group here.
        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        num_layers = 1
        num_local_experts = 2
        total_physical_experts = world_size * num_local_experts
        num_logical_experts = total_physical_experts // 2
        hidden_sizes = [32]

        # Create different index distributions so a real (non-profile)
        # rearrangement would have to move weights.
        # (Identical on all ranks thanks to the shared seed set in
        # `worker_fn_wrapper`.)
        old_redundancy = create_redundancy_config(num_logical_experts,
                                                  total_physical_experts)
        new_redundancy = create_redundancy_config(num_logical_experts,
                                                  total_physical_experts)

        old_indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            old_redundancy)
        new_indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            new_redundancy)

        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               old_indices)

        # Save original weights (deep copy, to compare after the call)
        original_weights = []
        for layer_weights in expert_weights:
            layer_copy = []
            for weight in layer_weights:
                layer_copy.append(weight.clone())
            original_weights.append(layer_copy)

        # Execute profile mode rearrangement
        rearrange_expert_weights_inplace(
            old_indices,
            new_indices,
            expert_weights,
            ep_group,
            is_profile=True  # Profile mode
        )

        # In profile mode, the weights should remain unchanged
        for layer in range(num_layers):
            for weight_idx in range(len(hidden_sizes)):
                torch.testing.assert_close(
                    expert_weights[layer][weight_idx],
                    original_weights[layer][weight_idx],
                    msg="In profile mode, the weights should remain unchanged")

    distributed_run(worker_fn, world_size)
|
||||
Reference in New Issue
Block a user