[Fix] Skip record_sleep_state logic in PrometheusStatsLogger if not in dev mode (#27789)
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
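
Context for the change, as a minimal standalone sketch: the sleep/wake_up controls are assumed to be dev-mode-only features gated behind the VLLM_SERVER_DEV_MODE environment flag, so the sleep-state gauges only need updating in that mode. DummyGauge and SleepStateLogger below are illustrative names invented for this sketch (not vLLM classes), and the simple string check on the environment variable is a simplification of the real envs.VLLM_SERVER_DEV_MODE flag; only the early-return pattern mirrors the actual change to record_sleep_state.

import os


class DummyGauge:
    """Stand-in for a prometheus_client Gauge (illustrative only)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.value = 0.0

    def set(self, value: float) -> None:
        self.value = value


class SleepStateLogger:
    """Toy logger mimicking the guarded record_sleep_state behavior."""

    def __init__(self) -> None:
        self.gauge_is_sleeping = DummyGauge("sleep_state")

    def record_sleep_state(self, sleep: int = 0, level: int = 0) -> None:
        # Early return mirroring the commit: skip all sleep-state metric
        # updates unless the server is running in dev mode.
        if os.environ.get("VLLM_SERVER_DEV_MODE", "0") != "1":
            return
        self.gauge_is_sleeping.set(sleep)


if __name__ == "__main__":
    logger = SleepStateLogger()
    logger.record_sleep_state(sleep=1, level=2)
    # Without VLLM_SERVER_DEV_MODE=1 in the environment, the gauge is untouched.
    print(logger.gauge_is_sleeping.value)  # 0.0
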
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import asyncio
+
 import pytest
 import torch
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.utils.mem_constants import GiB_bytes
 
@@ -201,3 +203,42 @@ def test_deep_sleep():
 
     # cmp output
     assert output[0].outputs[0].text == output2[0].outputs[0].text
+
+
+@create_new_process_for_each_test()
+def test_deep_sleep_async():
+    async def test():
+        model = "hmellor/tiny-random-LlamaForCausalLM"
+        free, total = torch.cuda.mem_get_info()
+        used_bytes_baseline = total - free  # in case other process is running
+        engine_args = AsyncEngineArgs(
+            model=model,
+            enable_sleep_mode=True,
+        )
+
+        llm = AsyncLLMEngine.from_engine_args(engine_args)
+        prompt = "How are you?"
+        sampling_params = SamplingParams(temperature=0, max_tokens=10)
+        outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
+        async for output in outputs:
+            pass
+
+        # Put the engine to deep sleep
+        await llm.sleep(level=2)
+
+        await llm.wake_up(tags=["weights"])
+        await llm.collective_rpc("reload_weights")
+        free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
+        used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
+        assert used_bytes < 4 * GiB_bytes
+
+        # now allocate kv cache and cuda graph memory
+        await llm.wake_up(tags=["kv_cache"])
+        outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
+        async for output2 in outputs2:
+            pass
+
+        # cmp output
+        assert output.outputs[0].text == output2.outputs[0].text
+
+    asyncio.run(test())
@@ -1052,6 +1052,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
         self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
 
     def record_sleep_state(self, sleep: int = 0, level: int = 0):
+        if not envs.VLLM_SERVER_DEV_MODE:
+            return
+
         awake = 1
         discard_all = 0
         weights_offloaded = 0