Compare commits

...

4 Commits

Author SHA1 Message Date
8a8b40d417 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-16 20:40:35 +00:00
c3f7afa6a8 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-16 20:39:45 +00:00
6cd8dec23f updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-16 20:29:24 +00:00
723263fa23 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-15 22:06:34 +00:00

View File

@ -10,6 +10,7 @@ from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from importlib import metadata
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
@ -42,16 +43,19 @@ EngineId = str
ReqId = str ReqId = str
GET_META_MSG = b"get_meta_msg" GET_META_MSG = b"get_meta_msg"
import os
VLLM_DEBUG_NIXL_XFER_TIME = os.getenv("VLLM_DEBUG_NIXL_XFER_TIME", "1") == "1"
logger = init_logger(__name__) logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try: try:
from nixl._api import nixl_agent as NixlWrapper from nixl._api import nixl_agent as NixlWrapper, nixl_agent_config
logger.info("NIXL is available") NIXL_VERSION = metadata.version("nixl")
except ImportError: except ImportError:
logger.warning("NIXL is not available") logger.warning("NIXL is not available")
NixlWrapper = None NixlWrapper = None
NIXL_VERSION = None
class NixlAgentMetadata( class NixlAgentMetadata(
msgspec.Struct, msgspec.Struct,
@ -352,16 +356,20 @@ class NixlConnectorWorker:
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None: if NixlWrapper is None:
logger.error("NIXL is not available") logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available") raise RuntimeError("NIXL is not available.")
logger.info("Initializing NIXL wrapper") logger.info("Initializing NIXL v%s: worker %s", NIXL_VERSION, engine_id)
logger.info("Initializing NIXL worker %s", engine_id)
# Config. # Config.
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
# Agent. # Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) import os
NIXL_NUM_WORKERS = int(os.getenv("VLLM_NIXL_NUM_WORKERS", "8"))
logger.info(f"Using NIXL_NUM_WORKERS={NIXL_NUM_WORKERS} for NIXL agent.")
config = nixl_agent_config(enable_prog_thread=False, num_threads=NIXL_NUM_WORKERS)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -449,7 +457,8 @@ class NixlConnectorWorker:
def __del__(self): def __del__(self):
"""Cleanup background threads on destruction.""" """Cleanup background threads on destruction."""
self._handshake_initiation_executor.shutdown(wait=False) if t_ := getattr(self, "_handshake_initiation_executor", None):
t_.shutdown(wait=False)
if self._nixl_handshake_listener_t: if self._nixl_handshake_listener_t:
self._nixl_handshake_listener_t.join(timeout=0) self._nixl_handshake_listener_t.join(timeout=0)
@ -1019,10 +1028,16 @@ class NixlConnectorWorker:
remote_xfer_side_handle, remote_xfer_side_handle,
remote_block_descs_ids, remote_block_descs_ids,
notif_msg=notif_id, notif_msg=notif_id,
skip_desc_merge=True,
) )
# Begin async xfer. # Begin async xfer.
start = time.perf_counter()
self.nixl_wrapper.transfer(handle) self.nixl_wrapper.transfer(handle)
end = time.perf_counter()
if VLLM_DEBUG_NIXL_XFER_TIME:
# Log the time taken for the transfer.
logger.info(f"TIME: {end - start}")
# Use handle to check completion in future step(). # Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time # TODO (NickLucche) surface xfer elapsed time