[V1] Reuse V0's memory_profiling util for gpu worker memory profiling (#19312)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Author: Ye (Charlotte) Qi
Date: 2025-06-09 17:40:01 -07:00
Committed by: GitHub
Parent: 3a7cd627a8
Commit: cc867be19c
2 changed files with 51 additions and 53 deletions


@@ -2269,6 +2269,8 @@ def kill_process_tree(pid: int):
 class MemorySnapshot:
     """Memory snapshot."""
     torch_peak: int = 0
+    free_memory: int = 0
+    total_memory: int = 0
     cuda_memory: int = 0
     torch_memory: int = 0
     non_torch_memory: int = 0
@@ -2288,8 +2290,8 @@ class MemorySnapshot:
         self.torch_peak = torch.cuda.memory_stats().get(
             "allocated_bytes.all.peak", 0)
-        self.cuda_memory = torch.cuda.mem_get_info(
-        )[1] - torch.cuda.mem_get_info()[0]
+        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
+        self.cuda_memory = self.total_memory - self.free_memory
         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
@@ -2302,6 +2304,8 @@ class MemorySnapshot:
     def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
         return MemorySnapshot(
             torch_peak=self.torch_peak - other.torch_peak,
+            free_memory=self.free_memory - other.free_memory,
+            total_memory=self.total_memory - other.total_memory,
             cuda_memory=self.cuda_memory - other.cuda_memory,
             torch_memory=self.torch_memory - other.torch_memory,
             non_torch_memory=self.non_torch_memory - other.non_torch_memory,
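
For context, a minimal sketch of how the extended snapshot is meant to be used on a CUDA device. It assumes MemorySnapshot measures on construction (which the worker code below relies on); the 1 GiB allocation is purely illustrative.

import torch

from vllm.utils import GiB_bytes, MemorySnapshot

before = MemorySnapshot()  # records free/total device memory plus torch stats

# Illustrative workload: allocate roughly 1 GiB on the GPU.
x = torch.empty(256, 1024, 1024, dtype=torch.float32, device="cuda")

after = MemorySnapshot()
delta = after - before  # free_memory and total_memory now take part in the diff

print(f"cuda memory delta: {delta.cuda_memory / GiB_bytes:.2f} GiB")
print(f"free memory delta: {delta.free_memory / GiB_bytes:.2f} GiB")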
@@ -2323,6 +2327,16 @@ class MemoryProfilingResult:
     after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0
+
+    def __repr__(self) -> str:
+        return (f"Memory profiling takes {self.profile_time:.2f} seconds. "
+                f"Total non KV cache memory: "
+                f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
+                f"torch peak memory increase: "
+                f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
+                f"non-torch forward increase memory: "
+                f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
+                f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.")
 
 
 @contextlib.contextmanager
 def memory_profiling(
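
To show how the V0 utility the worker now reuses fits together, here is a sketch of the context-manager flow. Argument and field names are taken from the diff; the dummy allocation and the 2 GiB weights_memory placeholder (standing in for model_runner.model_memory_usage) are assumptions for illustration only.

import torch

from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling

baseline = MemorySnapshot()  # e.g. the snapshot taken in init_device()

with memory_profiling(
        baseline,
        weights_memory=int(2 * GiB_bytes)) as profile_result:
    # Stand-in for self.model_runner.profile_run(): any CUDA work whose
    # peak usage should be attributed to the dummy forward pass.
    activations = torch.randn(4096, 4096, device="cuda")
    del activations

# The __repr__ added above reports profile_time, non-KV-cache memory,
# torch peak increase, non-torch increase and weights memory in GiB.
print(profile_result)
print(f"non-KV-cache memory: "
      f"{profile_result.non_kv_cache_memory / GiB_bytes:.2f} GiB")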


@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import GiB_bytes
+from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
@@ -130,20 +130,22 @@ class Worker(WorkerBase):
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-            requested_memory = (total_gpu_memory *
-                                self.cache_config.gpu_memory_utilization)
-            if self.init_gpu_memory < requested_memory:
+
+            # take current memory snapshot
+            self.init_snapshot = MemorySnapshot()
+            self.requested_memory = (self.init_snapshot.total_memory *
+                                     self.cache_config.gpu_memory_utilization)
+            if self.init_snapshot.free_memory < self.requested_memory:
                 GiB = lambda b: round(b / GiB_bytes, 2)
                 raise ValueError(
-                    f"Free memory on device ({GiB(self.init_gpu_memory)}/"
-                    f"{GiB(total_gpu_memory)} GiB) on startup is less than "
-                    f"desired GPU memory utilization "
+                    f"Free memory on device "
+                    f"({GiB(self.init_snapshot.free_memory)}/"
+                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
+                    f"is less than desired GPU memory utilization "
                     f"({self.cache_config.gpu_memory_utilization}, "
-                    f"{GiB(requested_memory)} GiB). Decrease GPU memory "
+                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                     f"utilization or reduce GPU memory used by other processes."
                 )
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
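
To make the startup check above concrete, a small sketch with made-up numbers: on a hypothetical 80 GiB device with 70 GiB free and gpu_memory_utilization=0.9, the worker requests 72 GiB and would raise the ValueError.

from vllm.utils import GiB_bytes

# Hypothetical device: 80 GiB total, 70 GiB free (10 GiB used elsewhere).
total_memory = 80 * GiB_bytes
free_memory = 70 * GiB_bytes
gpu_memory_utilization = 0.9

requested_memory = total_memory * gpu_memory_utilization  # 72 GiB
if free_memory < requested_memory:
    # Mirrors the ValueError raised in init_device() above.
    print(f"Free memory {free_memory / GiB_bytes:.2f}/"
          f"{total_memory / GiB_bytes:.2f} GiB is less than the requested "
          f"{requested_memory / GiB_bytes:.2f} GiB")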
@@ -192,57 +194,39 @@ class Worker(WorkerBase):
         """
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
+        GiB = lambda b: b / GiB_bytes
 
-        _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        self.model_runner.profile_run()
+        with memory_profiling(
+                self.init_snapshot,
+                weights_memory=int(
+                    self.model_runner.model_memory_usage)) as profile_result:
+            self.model_runner.profile_run()
 
-        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
+        assert self.init_snapshot.free_memory > free_gpu_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory/GiB_bytes} GiB, "
-            f"current free memory {free_gpu_memory/GiB_bytes} GiB. "
-            f"This happens when the GPU memory was not properly cleaned up "
-            f"before initializing the vLLM instance.")
+            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
+            f"current free memory {GiB(free_gpu_memory)} GiB. "
+            "This happens when other processes sharing the same container "
+            "release GPU memory while vLLM is profiling during initialization. "
+            "To fix this, ensure consistent GPU memory allocation or "
+            "isolate vLLM in its own container.")
+        available_kv_cache_memory = self.requested_memory \
+            - profile_result.non_kv_cache_memory
 
-        # Get the peak memory allocation recorded by torch
-        peak_torch_memory = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        # Check for any memory left around that may have been allocated on the
-        # gpu outside of `torch`. NCCL operations, for example, can use a few
-        # GB during a forward pass.
-        torch.cuda.empty_cache()
-        torch_allocated_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
-        # Reset after emptying torch cache
-        free_gpu_memory = torch.cuda.mem_get_info()[0]
-        # Total forward allocation (current) is equal to the diff in free memory
-        fwd_alloc_bytes = self.init_gpu_memory - free_gpu_memory
-        # We assume current non-torch allocation is equal to peak
-        non_torch_alloc_bytes = max(0, fwd_alloc_bytes - torch_allocated_bytes)
-        # Total forward allocation (peak) is peak torch + non-torch
-        peak_memory = peak_torch_memory + non_torch_alloc_bytes
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
-        GiB = lambda b: b / GiB_bytes
         logger.debug(
             "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
-            "total GPU memory: %.2f GiB", GiB(self.init_gpu_memory),
-            GiB(free_gpu_memory), GiB(total_gpu_memory))
-        logger.debug(
-            "Peak torch memory: %.2f GiB, non-torch forward-pass memory: "
-            "%.2f GiB, available KVCache memory: %.2f GiB",
-            GiB(peak_torch_memory), GiB(non_torch_alloc_bytes),
-            GiB(available_kv_cache_memory))
+            "requested GPU memory: %.2f GiB",
+            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
+            GiB(self.requested_memory))
+        logger.debug(profile_result)
+        logger.info("Available KV cache memory: %.2f GiB",
+                    GiB(available_kv_cache_memory))
+        gc.collect()
 
         return int(available_kv_cache_memory)
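
A worked example of the new KV-cache budget, with hypothetical numbers and assuming (per the __repr__ above) that non_kv_cache_memory is the sum of weights memory, the torch peak increase and the non-torch increase: with 72 GiB requested, 14 GiB of weights, a 4 GiB torch peak and 2 GiB of non-torch allocations, the KV cache gets 72 - (14 + 4 + 2) = 52 GiB.

from vllm.utils import GiB_bytes

# Hypothetical profiling outcome.
requested_memory = 72 * GiB_bytes   # total_memory * gpu_memory_utilization
weights_memory = 14 * GiB_bytes
torch_peak_increase = 4 * GiB_bytes
non_torch_increase = 2 * GiB_bytes

# memory_profiling() reports these three terms together as non_kv_cache_memory.
non_kv_cache_memory = (weights_memory + torch_peak_increase +
                       non_torch_increase)

available_kv_cache_memory = requested_memory - non_kv_cache_memory
print(f"Available KV cache memory: "
      f"{available_kv_cache_memory / GiB_bytes:.2f} GiB")  # 52.00 GiB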