Files
vllm/vllm/v1/intermediates/intermediates_logging.py
2025-08-05 09:25:17 -07:00

406 lines
13 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.
This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
import torch
from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global step counter: read by get_step(), advanced by increment_step() and
# set back to zero by reset_step().
_CURRENT_STEP = 0
# Per-step call counter keyed by module name: 0 after a module's first
# pre-forward call in the step, 1 after the second, and so on. step_fwd()
# replaces it with an empty dict when a step finishes.
_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}
# Attribute names stamped onto hooked modules by set_il_module_name() and
# set_il_module_call_idx(), and read back by the matching getters.
IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"
# Utility functions for intermediate logging
def should_log_step(config):
    """Decide whether the current global step is selected for logging.

    Args:
        config: The IntermediateLoggingConfig instance.

    Returns:
        False when logging is disabled. Otherwise True if
        ``config.log_step_ids`` is empty (every step is logged) or if the
        current step is contained in ``config._step_id_set``.
    """
    if not is_log_enabled(config):
        return False
    # An empty list of step IDs means "log every step"; only consult the
    # step-ID set when specific steps were requested.
    return (not config.log_step_ids) or get_step() in config._step_id_set
def should_log_device(config, device_name):
    """Decide whether tensors living on ``device_name`` should be logged.

    Args:
        config: The IntermediateLoggingConfig instance.
        device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').

    Returns:
        False when logging is disabled. Otherwise True if
        ``config.device_names`` is empty (all devices are logged) or if
        ``device_name`` is one of the configured device names.
    """
    if not is_log_enabled(config):
        return False
    # No configured device names means no device filtering at all.
    return (not config.device_names) or device_name in config.device_names
def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
    """Decide whether a module is selected for logging and tag it if so.

    When the module is selected, its logging name and the call index to log
    are stored on the module itself so the forward hooks can read them back.

    Args:
        config: The IntermediateLoggingConfig instance.
        module_name: The name of the module to check.
        module: The module instance; tagged with ``module_name`` and the
            call index when it is selected.

    Returns:
        True if the module should be logged, False otherwise.
        If no patterns are defined, all modules are logged (call index -1).
        If patterns are defined, the module is logged if it matches ANY
        pattern; the call index of the first matching pattern is used.
    """
    if not is_log_enabled(config):
        return False

    compiled_calls = config._compiled_module_calls
    if not compiled_calls:
        # No filters configured: every module is logged.
        set_il_module_name(module, module_name)
        set_il_module_call_idx(module, -1)
        return True

    for regex, call_idx in compiled_calls.items():
        if not regex.search(module_name):
            continue
        logger.debug(
            "Module %s, %s matches pattern: '%s', call_idx=%s",
            module_name,
            module.__class__.__name__,
            regex.pattern,
            call_idx,
        )
        set_il_module_name(module, module_name)
        set_il_module_call_idx(module, call_idx)
        return True
    return False
def is_log_enabled(config):
if not config or not config.enabled:
return False
if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress")
return False
return True
def get_il_module_name(module: torch.nn.Module) -> str:
    """Return the logging name stored on ``module``.

    Falls back to the module's class name when no logging name was stored.
    """
    class_name = module.__class__.__name__
    return getattr(module, IL_MODULE_NAME, class_name)
def get_il_module_call_idx(module: torch.nn.Module) -> int:
    """Return the call index stored on ``module``, or -1 if none was stored."""
    default_idx = -1
    return getattr(module, IL_MODULE_CALL_IDX, default_idx)
def set_il_module_name(module: torch.nn.Module, name: str) -> None:
    """Store ``name`` on ``module`` as its intermediate-logging name."""
    setattr(module, IL_MODULE_NAME, name)
def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
    """Store ``idx`` on ``module`` as the call index selected for logging."""
    setattr(module, IL_MODULE_CALL_IDX, idx)
# Config installed by the intermediate_logging() context manager and read via
# get_current_il_config(); None while no config is installed.
_global_config: Optional[IntermediateLoggingConfig] = None
@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
    """
    Temporarily install ``config`` as the global intermediate-logging config.

    The previously installed config is restored when the context exits,
    including when the body raises.

    :param config: The config to install for the duration of the context,
        or None to run the context with no config installed.
    """
    global _global_config
    previous_config = _global_config
    _global_config = config
    try:
        yield
    finally:
        # Always put the earlier config back so nested contexts unwind
        # correctly.
        _global_config = previous_config
def get_current_il_config():
    """Return the currently installed global logging config (may be None)."""
    current = _global_config
    return current
def save_tensors(tensor: Any, file_path: str) -> Any:
    """Dump a tensor, or every tensor nested in a container, to disk.

    Lists/tuples and dicts are walked recursively; the element index or the
    dict key is appended to ``file_path`` for each nested value. A tensor is
    written to ``<file_path>_<device>.pt`` (with ':' in the device name
    replaced by '_') if its device passes the device filter. Save failures
    are logged as warnings and do not propagate.

    Args:
        tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
            tensors, or a dictionary containing tensors.
        file_path: Base path where to save the tensor (without extension).

    Returns:
        The input object for tensors, lists, tuples and dicts; None for any
        other type, which is ignored.
    """
    if isinstance(tensor, dict):
        for key, value in tensor.items():
            save_tensors(value, f"{file_path}_{key}")
        return tensor

    if isinstance(tensor, (list, tuple)):
        for index, element in enumerate(tensor):
            save_tensors(element, f"{file_path}_{index}")
        return tensor

    if not isinstance(tensor, torch.Tensor):
        # Unsupported leaf type: nothing to save.
        return None

    device_name = str(tensor.device)
    if not should_log_device(get_current_il_config(), device_name):
        return tensor

    device_tag = device_name.replace(':', '_')
    pt_path = f"{file_path}_{device_tag}.pt"
    try:
        torch.save(tensor, pt_path)
        logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
    except Exception as e:
        # Logging is best-effort: never let a failed dump break the forward
        # pass.
        logger.warning("Failed to save tensor to %s: %s", pt_path, e)
    return tensor
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
             outputs: Any) -> None:
    """Forward hook that closes out a step once a forward pass finishes.

    Does nothing when no global logging config is installed. Otherwise it
    increments the global step counter and starts a fresh per-step module
    call counter.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    global _CURRENT_STEP_MODULE_CALL_STEP
    if get_current_il_config() is not None:
        increment_step()
        # Call indices are per step, so start the next step from scratch.
        _CURRENT_STEP_MODULE_CALL_STEP = {}
def _prepare_module_log_dir(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
    is_pre_fwd: bool = False,
) -> Path:
    """Create and return the dump directory for a module call.

    The directory is ``<output_run_dir>/step_<step>/<module_name>``; when the
    module's current call index within the step is greater than zero,
    ``_<call index>`` is appended to the directory name.

    Args:
        intermediate_log_config: Config providing ``output_run_dir``.
        module_name: Logging name of the module.
        is_pre_fwd: When True, the directory name is also appended to the
            step's module-call list file.

    Returns:
        The path of the (now existing) module directory.
    """
    run_dir = Path(intermediate_log_config.output_run_dir)
    step_dir = run_dir / f"step_{get_step()}"
    step_dir.mkdir(exist_ok=True, parents=True)

    # Repeated calls within a step get their own suffixed directory; the
    # first call uses the bare module name.
    call_idx = get_current_step_module_call(module_name)
    dir_name = module_name + f"_{call_idx}" if call_idx > 0 else module_name

    if is_pre_fwd:
        _log_module_call(intermediate_log_config, dir_name)

    module_dir = step_dir / dir_name
    module_dir.mkdir(exist_ok=True, parents=True)
    logger.debug("Logging module %s inputs/outputs to %s", module_name,
                 module_dir)
    return module_dir
def _log_module_call(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
) -> None:
    """Append ``module_name`` as one line to the current step's call list.

    The list lives at ``<output_run_dir>/step_<step>/module_calls.txt``.

    Args:
        intermediate_log_config: Config providing ``output_run_dir``.
        module_name: The name to record.
    """
    step_dir = Path(
        intermediate_log_config.output_run_dir) / f"step_{get_step()}"
    calls_file = step_dir / "module_calls.txt"
    with open(calls_file, "a") as handle:
        handle.write(f"{module_name}\n")
def update_current_step_module_call(module_name: str) -> None:
    """Advance the per-step call counter for ``module_name``.

    The first call in a step sets the counter to 0; each later call adds one.

    Args:
        module_name: Logging name of the module being called.
    """
    logger.debug("Updating current step module call for %s", module_name)
    # Starting from -1 makes the first recorded call index 0.
    previous_idx = _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, -1)
    _CURRENT_STEP_MODULE_CALL_STEP[module_name] = previous_idx + 1
def get_current_step_module_call(module_name: str) -> int:
    """Return the current per-step call index recorded for ``module_name``.

    Returns 0 when the module has not been recorded in the current step.
    """
    calls_this_step = _CURRENT_STEP_MODULE_CALL_STEP
    return calls_this_step.get(module_name, 0)
def prepare_log_current_fwd(module,
                            is_pre_fwd: bool = False) -> Optional[Path]:
    """Decide whether this hook invocation should log and prepare its folder.

    Called from both the pre-forward and the post-forward hook of a module.
    The pre-forward invocation advances the module's per-step call counter.
    Both invocations then compare the resulting call index with the call
    index selected for the module (a negative selected index means "log
    every call").

    Args:
        module: The module whose forward hook is running.
        is_pre_fwd: True when called from the pre-forward hook.

    Returns:
        The directory to dump tensors into, or None when nothing should be
        logged (no config installed, logging disabled, step not selected, or
        this call index not selected).
    """
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None or not intermediate_log_config.enabled:
        return None
    if not should_log_step(intermediate_log_config):
        return None

    module_name = get_il_module_name(module)

    # Advance the counter BEFORE reading it. Reading first would make the
    # pre-forward hook compare against the previous call's index, so from the
    # second call onward the inputs and outputs of one call would be filtered
    # against different indices.
    if is_pre_fwd:
        update_current_step_module_call(module_name)
    current_call_idx = get_current_step_module_call(module_name)

    log_call_idx = get_il_module_call_idx(module)
    if log_call_idx >= 0 and current_call_idx != log_call_idx:
        return None

    return _prepare_module_log_dir(intermediate_log_config,
                                   module_name,
                                   is_pre_fwd=is_pre_fwd)
def log_pre_fwd_hook(module: torch.nn.Module,
                     inputs: tuple[Any, ...]) -> tuple[Any, ...]:
    """Forward pre-hook that dumps a module's inputs when logging applies.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.

    Returns:
        The unchanged inputs.
    """
    log_dir = prepare_log_current_fwd(module, is_pre_fwd=True)
    if log_dir:
        save_tensors(inputs, str(log_dir / "inputs"))
    return inputs
def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
                      outputs: Any) -> None:
    """Forward hook that dumps a module's outputs when logging applies.

    When the config's ``log_post_fwd_inputs`` flag is set, the inputs are
    dumped a second time after the forward pass as ``post_fwd_inputs``.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    log_dir = prepare_log_current_fwd(module, is_pre_fwd=False)
    if not log_dir:
        return

    save_tensors(outputs, str(log_dir / "outputs"))

    intermediate_log_config = get_current_il_config()
    assert intermediate_log_config is not None, \
        "IL config should not be None"
    if intermediate_log_config.log_post_fwd_inputs:
        save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
def get_step() -> int:
    """Read the global step counter.

    Returns:
        The current value of the global step counter.
    """
    current_step = _CURRENT_STEP
    return current_step
def increment_step() -> int:
    """Advance the global step counter by one.

    Returns:
        The step counter value after the increment.
    """
    global _CURRENT_STEP
    _CURRENT_STEP = _CURRENT_STEP + 1
    return _CURRENT_STEP
def reset_step() -> None:
    """Set the global step counter back to zero."""
    global _CURRENT_STEP
    _CURRENT_STEP = 0
class IntermediatesLogger:
    """Class to manage logging of intermediate tensors during model
    execution.

    An instance owns the forward hooks it registers on a model so that they
    can all be removed again with ``remove_hooks``.
    """

    def __init__(self, config: IntermediateLoggingConfig):
        """Store the config and create the output run directory.

        Args:
            config: Logging configuration; ``config.output_run_dir`` is
                created if it does not exist yet.
        """
        self.config = config
        # One entry per hooked module:
        # (module name, module, pre-forward handle, post-forward handle).
        # The root model's entry has an empty name and no pre-forward handle.
        self.hooks: list[tuple[str, torch.nn.Module,
                               Optional[RemovableHandle],
                               Optional[RemovableHandle]]] = []
        logger.debug("Created IntermediatesLogger with config: %s", config)
        path = Path(config.output_run_dir)
        path.mkdir(exist_ok=True, parents=True)
        # Log configuration
        logger.info("Intermediates will be logged in %s",
                    config.output_run_dir)

    def register_hooks(self, model: torch.nn.Module) -> None:
        """Register hooks for the model.

        Every named submodule selected by ``should_log_module`` gets a
        pre-forward and a post-forward logging hook. The root model itself
        gets a post-forward hook that advances the global step counter.

        Args:
            model: The PyTorch model to register hooks for.
        """
        for name, module in model.named_modules():
            # The root module is reported with an empty name; it only gets
            # the step-counter hook registered below.
            if name and should_log_module(self.config, name, module):
                pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
                logger.debug("Registered pre_fwd hook for %s",
                             module.__class__.__name__)
                post_hook = module.register_forward_hook(log_post_fwd_hook)
                logger.debug("Registered post_fwd hook for %s",
                             module.__class__.__name__)
                self.hooks.append((name, module, pre_hook, post_hook))
        # Register a step counter hook for the root model
        step_hook = model.register_forward_hook(step_fwd)
        self.hooks.append(("", model, None, step_hook))
        logger.info("Registered hooks for %s modules", len(self.hooks))

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for _, _, pre_hook, post_hook in self.hooks:
            if pre_hook is not None:
                pre_hook.remove()
            if post_hook is not None:
                post_hook.remove()
        logger.info("Removed %s hooks", len(self.hooks))
        self.hooks = []
def register_intermediate_hooks(
    model: torch.nn.Module,
    config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
    """Register hooks to log intermediate tensors for a model.

    Args:
        model: The PyTorch model to log intermediates for.
        config: Configuration for intermediate logging. Required in
            practice: the logger needs it for the output directory and the
            module filters.

    Returns:
        An IntermediatesLogger instance that can be used to manage the hooks.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        # Fail with a clear message instead of an AttributeError from deep
        # inside IntermediatesLogger.__init__.
        raise ValueError(
            "register_intermediate_hooks requires an "
            "IntermediateLoggingConfig; got None")
    logger_instance = IntermediatesLogger(config)
    logger_instance.register_hooks(model)
    return logger_instance