From 4917002523db90813a47ca5aed5cd22e2edb75f4 Mon Sep 17 00:00:00 2001
From: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Date: Thu, 30 Oct 2025 12:26:27 -0700
Subject: [PATCH] [Fix] Skip `record_sleep_state` logic in
 `PrometheusStatLogger` if not in dev mode (#27789)

Make `record_sleep_state` a no-op unless `VLLM_SERVER_DEV_MODE` is
enabled, matching the sleep/wake_up endpoints, which are only exposed
in dev mode. Also add an async variant of the deep-sleep test.

Signed-off-by: SumanthRH
---
 tests/basic_correctness/test_cumem.py | 43 ++++++++++++++++++++++++++++++++++++++++++-
 vllm/v1/metrics/loggers.py            |  3 ++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index 09f4ec03fb..0c037622f5 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import asyncio
+
 import pytest
 import torch
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.utils.mem_constants import GiB_bytes
 
@@ -201,3 +203,42 @@ def test_deep_sleep():
 
     # cmp output
     assert output[0].outputs[0].text == output2[0].outputs[0].text
+
+
+@create_new_process_for_each_test()
+def test_deep_sleep_async():
+    async def test():
+        model = "hmellor/tiny-random-LlamaForCausalLM"
+        free, total = torch.cuda.mem_get_info()
+        used_bytes_baseline = total - free  # in case another process is running
+        engine_args = AsyncEngineArgs(
+            model=model,
+            enable_sleep_mode=True,
+        )
+
+        llm = AsyncLLMEngine.from_engine_args(engine_args)
+        prompt = "How are you?"
+        sampling_params = SamplingParams(temperature=0, max_tokens=10)
+        outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
+        async for output in outputs:
+            pass
+
+        # Put the engine into deep sleep
+        await llm.sleep(level=2)
+
+        await llm.wake_up(tags=["weights"])
+        await llm.collective_rpc("reload_weights")
+        free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
+        used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
+        assert used_bytes < 4 * GiB_bytes
+
+        # Now allocate KV cache and CUDA graph memory
+        await llm.wake_up(tags=["kv_cache"])
+        outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
+        async for output2 in outputs2:
+            pass
+
+        # Compare output with the pre-sleep generation
+        assert output.outputs[0].text == output2.outputs[0].text
+
+    asyncio.run(test())
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 3772f07066..67b6ceaa84 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -1052,6 +1052,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
         self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
 
     def record_sleep_state(self, sleep: int = 0, level: int = 0):
+        if not envs.VLLM_SERVER_DEV_MODE:
+            return
+
         awake = 1
         discard_all = 0
         weights_offloaded = 0
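
Note (commentary only, not part of the commit): the sketch below is a
minimal, self-contained illustration of the behavior the new early
return gives `record_sleep_state`. It is not vLLM code:
`_DemoSleepLogger`, its `dev_mode` flag, and the `last` attribute are
hypothetical stand-ins for `PrometheusStatLogger`,
`envs.VLLM_SERVER_DEV_MODE`, and the real Prometheus gauges. The level
semantics assumed in the comments (level 1 offloads weights, level 2
discards weights and KV cache) follow vLLM's sleep-mode documentation,
not this diff.

# Illustrative sketch only; see the note above.
class _DemoSleepLogger:
    def __init__(self, dev_mode: bool) -> None:
        self.dev_mode = dev_mode  # stands in for envs.VLLM_SERVER_DEV_MODE
        self.last = None  # stands in for the Prometheus sleep-state gauges

    def record_sleep_state(self, sleep: int = 0, level: int = 0) -> None:
        if not self.dev_mode:
            # Same early return as the patch: skip all gauge updates.
            return
        awake = 0 if sleep else 1
        # Assumed semantics: level 1 offloads weights to CPU, level 2
        # discards weights and KV cache.
        discard_all = 1 if sleep and level == 2 else 0
        weights_offloaded = 1 if sleep and level == 1 else 0
        self.last = (awake, discard_all, weights_offloaded)


if __name__ == "__main__":
    prod = _DemoSleepLogger(dev_mode=False)
    prod.record_sleep_state(sleep=1, level=2)
    assert prod.last is None  # no-op outside dev mode

    dev = _DemoSleepLogger(dev_mode=True)
    dev.record_sleep_state(sleep=1, level=2)
    assert dev.last == (0, 1, 0)  # asleep, everything discarded

Running the file as a script exercises both branches of the guard.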