@ -61,6 +61,7 @@ class DPMetadata:
|
||||
# num_tokens_across_dp. If there's an incorrect ordering of ARs
|
||||
# across DP ranks, this tensor can end up containing the number
|
||||
# of padded tokens for a DP rank.
|
||||
|
||||
assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
|
||||
|
||||
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
|
||||
|
||||
Reference in New Issue
Block a user