@ -371,8 +371,8 @@ def run_low_latency():
|
||||
all_topk_idx = torch.cat(all_topk_idx, dim=0)
|
||||
all_topk_weights = torch.cat(all_topk_weights, dim=0)
|
||||
|
||||
# Verification
|
||||
expert_tok_ids = [[] for _ in range(local_num_experts)]
|
||||
|
||||
expert_range_start = rank * local_num_experts
|
||||
for i, cnt in enumerate(recv_expert_count):
|
||||
for j in range(cnt):
|
||||
@ -386,6 +386,19 @@ def run_low_latency():
|
||||
|
||||
assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
|
||||
|
||||
# Combine
|
||||
combined_hidden_states, event_overlap, hook = low_latency_combine(
|
||||
recv_hidden_states, topk_idx, topk_weights, handle
|
||||
)
|
||||
hook()
|
||||
|
||||
torch.testing.assert_close(
|
||||
combined_hidden_states.to(torch.float32),
|
||||
x * topk_weights.sum(dim=-1).unsqueeze(-1),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.distributed.init_process_group(
|
||||
|
||||
Reference in New Issue
Block a user