diff --git a/docs/serving/data_parallel_deployment.md b/docs/serving/data_parallel_deployment.md
index d12d1c9818..eff9c5d5e4 100644
--- a/docs/serving/data_parallel_deployment.md
+++ b/docs/serving/data_parallel_deployment.md
@@ -69,6 +69,7 @@ There are several notable differences when using Ray:
 - A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node
 - There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address`
 - There is no need to specify `--data-parallel-rpc-port`
+- When a single DP group spans multiple nodes, *e.g.* when a single model replica needs at least two nodes, make sure to set `VLLM_RAY_DP_PACK_STRATEGY="span"`, in which case `--data-parallel-size-local` is ignored and determined automatically
 - Remote DP ranks will be allocated based on node resources of the Ray cluster
 
 Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic.
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 917d0ec9f7..42bcd64ff1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1403,8 +1403,15 @@ class EngineArgs:
                     "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
                 )
 
-            # Local DP size defaults to global DP size if not set.
-            data_parallel_size_local = self.data_parallel_size
+            if self.data_parallel_backend == "ray" and (
+                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
+            ):
+                # Local DP size defaults to 1 when each DP rank spans
+                # multiple nodes
+                data_parallel_size_local = 1
+            else:
+                # Otherwise, local DP size defaults to global DP size if not set
+                data_parallel_size_local = self.data_parallel_size
 
         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
diff --git a/vllm/envs.py b/vllm/envs.py
index 7dcfabe3e0..3cf3444e20 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -139,7 +139,7 @@ if TYPE_CHECKING:
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
-    VLLM_RAY_DP_PACK_STRATEGY: str = "strict"
+    VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MXFP4_USE_MARLIN: bool | None = None
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -1039,6 +1039,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # for non-master nodes, allocate as many DP ranks as can fit;
     # - "strict":
    # allocate exactly data-parallel-size-local DP ranks to each picked node;
+    # - "span":
+    # allocate one DP rank across as many nodes as required by world_size;
+    # should be used only when a single DP rank requires multiple nodes;
     # This environment variable is ignored if data-parallel-backend is not Ray.
     "VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv(
         "VLLM_RAY_DP_PACK_STRATEGY", "strict"
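For intuition, the three strategies differ in how many DP ranks a single node can contribute. Below is a minimal sketch of that per-node rule; `ranks_for_node` is a hypothetical helper, not part of vLLM, and only mirrors the allocation logic that the `vllm/v1/engine/utils.py` changes below implement:

```python
def ranks_for_node(
    strategy: str, devices_on_node: int, world_size: int, dp_size_local: int
) -> float:
    """How many DP ranks one node contributes under each pack strategy."""
    whole_ranks = devices_on_node // world_size
    if strategy == "strict":
        # exactly dp_size_local whole ranks, otherwise the node is skipped
        return dp_size_local if whole_ranks >= dp_size_local else 0
    if strategy == "fill":
        # as many whole DP ranks as the node can hold
        return whole_ranks
    if strategy == "span":
        # a fractional contribution; ranks are completed across nodes
        return devices_on_node / world_size
    raise ValueError(f"unknown strategy: {strategy}")

# With world_size=16 on an 8-GPU node: "strict" and "fill" place 0 ranks,
# while "span" contributes 0.5 of a rank, so two such nodes form one DP rank.
assert ranks_for_node("span", 8, 16, 1) == 0.5
assert ranks_for_node("strict", 8, 16, 1) == 0
```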
"VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv( "VLLM_RAY_DP_PACK_STRATEGY", "strict" diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 159b779111..ff7af7311c 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -345,6 +345,7 @@ class CoreEngineActorManager: world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] + dp_master_ip_key = f"node:{dp_master_ip}" nodes = sorted( available_resources.values(), key=lambda x: dp_master_ip_key not in x @@ -355,9 +356,25 @@ class CoreEngineActorManager: dp_master_ip, ) device_str = current_platform.ray_device_key + n_node_devices: list[int] = [ + int(node_resources[device_str]) + for node_resources in nodes + if device_str in node_resources + ] + assert n_node_devices, f"No {device_str} found in Ray cluster." + max_device_per_node = max(n_node_devices) + + pack_strategy = envs.VLLM_RAY_DP_PACK_STRATEGY + _supported_pack_strategies = ("strict", "fill", "span") + if pack_strategy not in _supported_pack_strategies: + raise ValueError( + f"{envs.VLLM_RAY_DP_PACK_STRATEGY} is not supported. " + "Make sure to set `VLLM_RAY_DP_PACK_STRATEGY` " + f"to one of {_supported_pack_strategies}" + ) all2all_backend = vllm_config.parallel_config.all2all_backend - if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and ( + if pack_strategy == "fill" and ( all2all_backend == "deepep_high_throughput" or all2all_backend == "deepep_low_latency" ): @@ -367,12 +384,42 @@ class CoreEngineActorManager: "does not guarantee that. " "Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead." ) - logger.info( - "Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY", - envs.VLLM_RAY_DP_PACK_STRATEGY, - ) - strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict" + if pack_strategy in ("strict", "fill"): + placement_strategy = "STRICT_PACK" + else: + placement_strategy = "PACK" + assert world_size > max_device_per_node, ( + f"World size {world_size} is smaller than the " + "maximum number of devices per node " + f"{max_device_per_node}. Make sure to set " + "`VLLM_RAY_DP_PACK_STRATEGY` to `strict` or `fill`" + ) + + # if we need multiple nodes per dp group, we require for now that + # available nodes are homogenous + assert set(n_node_devices) == {max_device_per_node}, ( + f"Nodes are not homogenous, {nodes}" + ) + assert world_size % max_device_per_node == 0, ( + f"For multi-node data parallel groups, world_size ({world_size}) must " + f"be a multiple of number of devices per node ({max_device_per_node})." + ) + assert len(n_node_devices) * max_device_per_node >= world_size * dp_size, ( + f"Not enough total available nodes ({len(n_node_devices)}) " + f"and devices per node ({max_device_per_node}) " + f"to satisfy required world size {world_size} and data parallel size " + f"{dp_size}" + ) + assert dp_size_local == 1, ( + f"data-parallel-size-local {dp_size_local} should be set as the " + "default (1) for VLLM_RAY_DP_PACK_STRATEGY=span. " + "The actual data-parallel-size-local will be auto determined." 
+            )
+
+        # Bundles collected for a single DP rank from multiple nodes,
+        # for the "span" pack strategy
+        collected_bundles: list[dict[str, float]] = []
         for node_resources in nodes:
             node_ip_keys = [
                 key
@@ -386,14 +433,14 @@
             node_ip_key = node_ip_keys[0]
             node_ip = node_ip_key.split(":")[1]
 
-            # For now, each DP rank can only be assigned to one node
-            # TODO(rui): support allocating a single DP rank
-            # to multiple nodes
-            dp_size_available = (
-                int(node_resources[device_str]) // world_size
-                if device_str in node_resources
-                else 0
-            )
+            n_device_on_node = int(node_resources.get(device_str, 0))
+            if pack_strategy == "span" and n_device_on_node != 0:
+                # Strictly speaking,
+                # dp_size_available = n_device_on_node / world_size
+                # is a fraction, but we use 1 for easier processing
+                dp_size_available = 1
+            else:
+                dp_size_available = n_device_on_node // world_size
 
             if node_ip == dp_master_ip:
                 if dp_size_available < dp_size_local:
@@ -405,7 +452,7 @@
                         dp_size_available,
                     )
                 dp_size_to_allocate = dp_size_local
-            elif strict_local_size:
+            elif pack_strategy == "strict":
                 if dp_size_available < dp_size_local:
                     logger.info(
                         "Skipping node %s as %s DP ranks could not fit, "
@@ -417,15 +464,31 @@
                     continue
                 dp_size_to_allocate = dp_size_local
             else:
+                # For the "fill" and "span" pack strategies,
+                # we always take everything that's available
                 dp_size_to_allocate = dp_size_available
 
             for i in range(dp_size_to_allocate):
-                bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [
-                    {"CPU": 1.0}
-                ]
+                device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}]
+                if pack_strategy == "span":
+                    collected_bundles += device_bundle * n_device_on_node
+                    assert len(collected_bundles) <= world_size, (
+                        "collected_bundles should be <= world_size, "
+                        f"but got {len(collected_bundles)=} and {world_size=}"
+                    )
+
+                    # Only create a placement group once enough devices are collected
+                    if len(collected_bundles) < world_size:
+                        continue
+
+                    bundles = collected_bundles + [{"CPU": 1.0}]
+                    collected_bundles = []
+                else:
+                    bundles = device_bundle * world_size + [{"CPU": 1.0}]
+
                 pg = ray.util.placement_group(
                     name=f"dp_rank_{len(placement_groups)}",
-                    strategy="STRICT_PACK",
+                    strategy=placement_strategy,
                     bundles=bundles,
                 )
                 placement_groups.append(pg)
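The core of the "span" path above is the bundle accumulator: per-device bundles from consecutive nodes pile up until `world_size` of them form one multi-node placement group. A minimal standalone sketch of that arithmetic follows; `span_placement_groups` and the `10.0.0.x` node IPs are made up for illustration, assuming homogeneous 8-GPU nodes as the asserts in the patch require:

```python
def span_placement_groups(
    node_gpu_counts: list[int], world_size: int
) -> list[list[dict[str, float]]]:
    """Group per-device bundles from consecutive nodes into DP ranks."""
    groups: list[list[dict[str, float]]] = []
    collected: list[dict[str, float]] = []
    for node_id, n_gpus in enumerate(node_gpu_counts):
        # one bundle per device, pinned to its node via the "node:<ip>" resource
        collected += [{"GPU": 1.0, f"node:10.0.0.{node_id}": 0.001}] * n_gpus
        if len(collected) == world_size:
            # enough devices collected: this becomes one placement group
            groups.append(collected + [{"CPU": 1.0}])
            collected = []
    return groups

# Four 8-GPU nodes with world_size=16 yield two DP ranks, each spanning
# two nodes (17 bundles per group: 16 devices + 1 CPU).
groups = span_placement_groups([8, 8, 8, 8], world_size=16)
assert len(groups) == 2 and all(len(g) == 17 for g in groups)
```

Note that these groups use Ray's `PACK` strategy rather than `STRICT_PACK`: `STRICT_PACK` would force all bundles onto a single node, while `PACK` packs them onto as few nodes as possible, which is what allows one DP rank to span several nodes.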