Compare commits
7 Commits
amd_mori
...
debug-logg
| Author | SHA1 | Date | |
|---|---|---|---|
| f0945e311d | |||
| 4ec76caafa | |||
| 1588294a88 | |||
| e82e9afeb7 | |||
| 10abfaf309 | |||
| 9ff1a2b537 | |||
| 0abe10e4a7 ||||
@ -893,6 +893,7 @@ class NixlConnectorWorker:
|
|||||||
return done_req_ids
|
return done_req_ids
|
||||||
|
|
||||||
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
||||||
|
start = time.perf_counter()
|
||||||
"""
|
"""
|
||||||
Start loading by triggering non-blocking nixl_xfer.
|
Start loading by triggering non-blocking nixl_xfer.
|
||||||
We check for these transfers to complete in each step().
|
We check for these transfers to complete in each step().
|
||||||
@ -921,6 +922,11 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Add to requests that are waiting to be read and track expiration.
|
# Add to requests that are waiting to be read and track expiration.
|
||||||
self._reqs_to_send.update(metadata.reqs_to_send)
|
self._reqs_to_send.update(metadata.reqs_to_send)
|
||||||
|
end = time.perf_counter()
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
logger.info(
|
||||||
|
f"===== {len(metadata.reqs_to_recv)}: start_load_kv time: {end-start: 0.5f}s"
|
||||||
|
)
|
||||||
|
|
||||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -1019,10 +1025,15 @@ class NixlConnectorWorker:
|
|||||||
remote_xfer_side_handle,
|
remote_xfer_side_handle,
|
||||||
remote_block_descs_ids,
|
remote_block_descs_ids,
|
||||||
notif_msg=notif_id,
|
notif_msg=notif_id,
|
||||||
|
skip_desc_merge=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Begin async xfer.
|
# Begin async xfer.
|
||||||
|
start = time.perf_counter()
|
||||||
self.nixl_wrapper.transfer(handle)
|
self.nixl_wrapper.transfer(handle)
|
||||||
|
end = time.perf_counter()
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
logger.info(f"TRANSFER TIME: {end-start :0.4f}s")
|
||||||
|
|
||||||
# Use handle to check completion in future step().
|
# Use handle to check completion in future step().
|
||||||
# TODO (NickLucche) surface xfer elapsed time
|
# TODO (NickLucche) surface xfer elapsed time
|
||||||
|
|||||||
@ -15,7 +15,8 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce)
|
set_custom_all_reduce)
|
||||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
|
||||||
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -333,19 +334,20 @@ class Worker(WorkerBase):
|
|||||||
assert isinstance(output, IntermediateTensors)
|
assert isinstance(output, IntermediateTensors)
|
||||||
get_pp_group().send_tensor_dict(output.tensors,
|
get_pp_group().send_tensor_dict(output.tensors,
|
||||||
all_gather_group=get_tp_group())
|
all_gather_group=get_tp_group())
|
||||||
|
if not has_kv_transfer_group():
|
||||||
|
return None
|
||||||
|
|
||||||
# In case of PP with kv transfer, we need to pass through the
|
# In case of PP with kv transfer, we need to pass through the
|
||||||
# finished_sending and finished_recving buffers.
|
# finished_sending and finished_recving buffers.
|
||||||
empty_output = EMPTY_MODEL_RUNNER_OUTPUT
|
new_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
if output.finished_sending or output.finished_recving:
|
if output.finished_sending or output.finished_recving:
|
||||||
empty_output = copy.copy(empty_output)
|
new_output = copy.copy(new_output)
|
||||||
empty_output.finished_sending = output.finished_sending
|
new_output.finished_sending = output.finished_sending
|
||||||
empty_output.finished_recving = output.finished_recving
|
new_output.finished_recving = output.finished_recving
|
||||||
output = empty_output
|
output = new_output
|
||||||
|
|
||||||
assert isinstance(output, ModelRunnerOutput)
|
assert isinstance(output, ModelRunnerOutput)
|
||||||
# return output only from the driver worker
|
return output
|
||||||
return output if self.is_driver_worker else None
|
|
||||||
|
|
||||||
def profile(self, is_start: bool = True):
|
def profile(self, is_start: bool = True):
|
||||||
if self.profiler is None:
|
if self.profiler is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user