low latency combine

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Zhuohan Li
2025-10-20 22:00:01 -07:00
parent da26dce7b2
commit 99e2379b16

View File

@ -371,8 +371,8 @@ def run_low_latency():
all_topk_idx = torch.cat(all_topk_idx, dim=0)
all_topk_weights = torch.cat(all_topk_weights, dim=0)
# Verification
expert_tok_ids = [[] for _ in range(local_num_experts)]
expert_range_start = rank * local_num_experts
for i, cnt in enumerate(recv_expert_count):
for j in range(cnt):
@ -386,6 +386,19 @@ def run_low_latency():
assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
# Combine
combined_hidden_states, event_overlap, hook = low_latency_combine(
recv_hidden_states, topk_idx, topk_weights, handle
)
hook()
torch.testing.assert_close(
combined_hidden_states.to(torch.float32),
x * topk_weights.sum(dim=-1).unsqueeze(-1),
atol=1e-2,
rtol=1e-2,
)
if __name__ == "__main__":
torch.distributed.init_process_group(