[Fix] Skip record_sleep_state logic in PrometheusStatsLogger if not in dev mode (#27789)
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
This commit is contained in:
@ -1,10 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
|
||||
@ -201,3 +203,42 @@ def test_deep_sleep():
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_deep_sleep_async():
|
||||
async def test():
|
||||
model = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
enable_sleep_mode=True,
|
||||
)
|
||||
|
||||
llm = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
|
||||
async for output in outputs:
|
||||
pass
|
||||
|
||||
# Put the engine to deep sleep
|
||||
await llm.sleep(level=2)
|
||||
|
||||
await llm.wake_up(tags=["weights"])
|
||||
await llm.collective_rpc("reload_weights")
|
||||
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
|
||||
assert used_bytes < 4 * GiB_bytes
|
||||
|
||||
# now allocate kv cache and cuda graph memory
|
||||
await llm.wake_up(tags=["kv_cache"])
|
||||
outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
|
||||
async for output2 in outputs2:
|
||||
pass
|
||||
|
||||
# cmp output
|
||||
assert output.outputs[0].text == output2.outputs[0].text
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
@ -1052,6 +1052,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
|
||||
|
||||
def record_sleep_state(self, sleep: int = 0, level: int = 0):
|
||||
if not envs.VLLM_SERVER_DEV_MODE:
|
||||
return
|
||||
|
||||
awake = 1
|
||||
discard_all = 0
|
||||
weights_offloaded = 0
|
||||
|
||||
Reference in New Issue
Block a user