[Fix] Skip record_sleep_state logic in PrometheusStatLogger if not in dev mode (#27789)

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Sumanth R Hegde
2025-10-30 12:26:27 -07:00
committed by GitHub
parent a2981c4272
commit 4917002523
2 changed files with 45 additions and 1 deletion

View File

@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+
 import pytest
 import torch
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.utils.mem_constants import GiB_bytes
@@ -201,3 +203,42 @@ def test_deep_sleep():
 
     # cmp output
     assert output[0].outputs[0].text == output2[0].outputs[0].text
+
+
+@create_new_process_for_each_test()
+def test_deep_sleep_async():
+    async def test():
+        model = "hmellor/tiny-random-LlamaForCausalLM"
+        free, total = torch.cuda.mem_get_info()
+        used_bytes_baseline = total - free  # in case other process is running
+        engine_args = AsyncEngineArgs(
+            model=model,
+            enable_sleep_mode=True,
+        )
+        llm = AsyncLLMEngine.from_engine_args(engine_args)
+
+        prompt = "How are you?"
+        sampling_params = SamplingParams(temperature=0, max_tokens=10)
+        outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
+        async for output in outputs:
+            pass
+
+        # Put the engine to deep sleep
+        await llm.sleep(level=2)
+
+        await llm.wake_up(tags=["weights"])
+        await llm.collective_rpc("reload_weights")
+        free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
+        used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
+        assert used_bytes < 4 * GiB_bytes
+
+        # now allocate kv cache and cuda graph memory
+        await llm.wake_up(tags=["kv_cache"])
+        outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
+        async for output2 in outputs2:
+            pass
+
+        # cmp output
+        assert output.outputs[0].text == output2.outputs[0].text
+
+    asyncio.run(test())
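
For reference, a minimal sketch of the same sleep/wake cycle using the synchronous LLM API, mirroring the existing test_deep_sleep that this hunk extends. A level-2 sleep discards the weights instead of offloading them, which is why reload_weights must be invoked via collective_rpc after waking the weight allocator. The calls below are the ones used by the tests in this file; everything else (prompt, model choice) is illustrative only.

# Sketch, not part of this commit: synchronous equivalent of the async test above.
from vllm import LLM, SamplingParams

llm = LLM("hmellor/tiny-random-LlamaForCausalLM", enable_sleep_mode=True)
params = SamplingParams(temperature=0, max_tokens=10)
before = llm.generate("How are you?", params)

llm.sleep(level=2)                    # level 2 discards weights and KV cache
llm.wake_up(tags=["weights"])         # re-allocate weight buffers ...
llm.collective_rpc("reload_weights")  # ... and reload the discarded weights
llm.wake_up(tags=["kv_cache"])        # re-allocate KV cache / CUDA graph memory

after = llm.generate("How are you?", params)
assert before[0].outputs[0].text == after[0].outputs[0].text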

View File

@@ -1052,6 +1052,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
         self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
 
     def record_sleep_state(self, sleep: int = 0, level: int = 0):
+        if not envs.VLLM_SERVER_DEV_MODE:
+            return
+
         awake = 1
         discard_all = 0
         weights_offloaded = 0
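
For context on the fix itself: the sleep/wake controls that drive record_sleep_state are only meant to be used when VLLM_SERVER_DEV_MODE is set, so outside dev mode the logger can return early instead of updating the sleep-state gauges. Below is a rough sketch of exercising that path against a dev-mode server; the endpoint paths and parameters are assumptions based on vLLM's development-mode routes, not part of this commit.

# Sketch, not part of this commit. Assumes a server started with the env var set:
#   VLLM_SERVER_DEV_MODE=1 vllm serve <model> --enable-sleep-mode
import requests

base = "http://localhost:8000"

requests.post(f"{base}/sleep", params={"level": 2}).raise_for_status()
print(requests.get(f"{base}/is_sleeping").json())  # dev-mode-only status check

requests.post(f"{base}/wake_up", params={"tags": "weights"}).raise_for_status()
requests.post(f"{base}/wake_up", params={"tags": "kv_cache"}).raise_for_status()

# Without VLLM_SERVER_DEV_MODE, record_sleep_state() now returns early, so the
# sleep-state gauges are simply never updated in the /metrics output.
print(requests.get(f"{base}/metrics").text[:400])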