Compare commits
1 commit
fix-precom ... reduce_sca

| Author | SHA1 | Date |
|---|---|---|
|  | 3679753af5 |  |
@@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
                                               input_size[dim + 1:])
         return output_tensor

+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        assert input_tensor.shape[0] % world_size == 0
+        chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output_tensor = torch.empty(output_shape,
+                                    dtype=input_tensor.dtype,
+                                    device=input_tensor.device)
+
+        # Perform reduce-scatter operation
+        torch.distributed.reduce_scatter_tensor(output_tensor,
+                                                input_tensor,
+                                                group=self.device_group)
+
+        # Reshape before returning
+        return output_tensor.movedim(0, dim).contiguous()
+
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
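For reference, a minimal single-process sketch of the semantics the base implementation above provides: every rank contributes a full tensor, the elementwise sum is formed, and rank r keeps the r-th chunk along `dim`. The names, shapes, and values below are illustrative only, not part of the diff.

```python
import torch

# Single-process emulation of reduce-scatter semantics; world_size, rank,
# and shapes are made up for the example.
world_size, dim, rank = 4, -1, 2
inputs = [torch.randn(2, 8) for _ in range(world_size)]  # one tensor per rank

summed = torch.stack(inputs).sum(dim=0)             # the "reduce" step
my_shard = summed.chunk(world_size, dim=dim)[rank]  # the "scatter" step
assert my_shard.shape == (2, 2)
```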
@@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
         torch.distributed.all_reduce(out, group=self.device_group)
         return out

+    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        assert input_tensor.shape[0] % world_size == 0
+        chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output = torch.empty(output_shape,
+                             dtype=input_tensor.dtype,
+                             device=input_tensor.device)
+
+        pynccl_comm.reduce_scatter(output, input_tensor)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
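The CUDA path delegates the collective to pynccl rather than `torch.distributed`, but the shape bookkeeping is identical to the base class. A small runnable check of that arithmetic, with illustrative values:

```python
import torch

# The scattered dimension must divide evenly by world_size; each rank
# receives exactly one chunk of it. Toy values, not from the diff.
world_size = 4
input_tensor = torch.randn(8, 16)  # dim 0 is the scattered dim after movedim

assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
assert output_shape == (2, 16)
```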
@@ -114,10 +114,26 @@ def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
     return group._all_reduce_out_place(tensor)


+def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
+                   group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group.reduce_scatter(tensor, dim)
+
+
 def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
     return torch.empty_like(tensor)


+def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
+                        group_name: str) -> torch.Tensor:
+    new_shape = list(tensor.shape)
+    new_shape[dim] = tensor.shape[dim] // world_size
+    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+
 if supports_custom_op():
     direct_register_custom_op(
         op_name="all_reduce",
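The fake implementation exists so the compiler can trace the op without launching a collective: it only has to reproduce the output shape, dtype, and device. A quick standalone check of the shape rule `reduce_scatter_fake` applies, with toy values:

```python
import torch

# reduce_scatter shrinks the scattered dim by world_size; the fake impl
# reproduces just that on an empty tensor. Illustrative values only.
tensor = torch.empty(4, 12)
dim, world_size = 1, 4

new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] // world_size
out = torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
assert out.shape == (4, 3)
```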
@@ -126,6 +142,13 @@ if supports_custom_op():
         fake_impl=all_reduce_fake,
     )

+    direct_register_custom_op(
+        op_name="reduce_scatter",
+        op_func=reduce_scatter,
+        mutates_args=[],
+        fake_impl=reduce_scatter_fake,
+    )
+

 class GroupCoordinator:
     """
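`direct_register_custom_op` is the project's own registration helper; for readers unfamiliar with it, here is a minimal sketch of a comparable registration using plain `torch.library` (assumes PyTorch 2.4+; the `demo` namespace and `halve_dim0` op are hypothetical, for illustration only, and not part of this diff):

```python
import torch

# Hypothetical custom op with a fake (meta) impl, registered via
# torch.library; the helper in the diff wraps similar machinery.
@torch.library.custom_op("demo::halve_dim0", mutates_args=())
def halve_dim0(x: torch.Tensor) -> torch.Tensor:
    # Clone so the op does not return a view of its input.
    return x[: x.shape[0] // 2].clone()

@halve_dim0.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation used during tracing.
    return x.new_empty((x.shape[0] // 2, ) + x.shape[1:])

y = torch.ops.demo.halve_dim0(torch.randn(8, 3))
assert y.shape == (4, 3)
```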
@@ -322,6 +345,18 @@ class GroupCoordinator:

         return self.device_communicator.all_gather(input_, dim)

+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        return self.device_communicator.reduce_scatter(input_, dim)
+
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
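A note on why the coordinator exposes this at all: reduce-scatter followed by an all-gather over the same dim reproduces an all-reduce, while each rank only holds a shard in between. A single-process sanity check of that identity, with illustrative shapes:

```python
import torch

# All-gathering the reduce-scatter shards reconstructs the all-reduce
# result. world_size and shapes are made up for the example.
world_size, dim = 4, 0
xs = [torch.randn(8, 3) for _ in range(world_size)]  # one tensor per rank

all_reduced = torch.stack(xs).sum(dim=0)
shards = all_reduced.chunk(world_size, dim=dim)  # reduce-scatter outputs
regathered = torch.cat(shards, dim=dim)          # all-gather of the shards
assert torch.equal(regathered, all_reduced)
```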