remove feature for metadata dump and input reload

Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang
2025-07-28 19:03:26 -07:00
parent d8bff253d7
commit 2af83ebdde
9 changed files with 278 additions and 544 deletions

View File

@ -49,7 +49,6 @@ The configuration file should be a JSON file with the following structure:
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json
```
#### Configuration Parameters
| Parameter | Type | Description | Default |

View File

@ -5,9 +5,8 @@ Tests for the intermediate tensor logging functionality.
"""
import json
from os.path import isdir
import shutil
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
@ -17,14 +16,10 @@ import torch
import torch.nn as nn
from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (get_current_il_config,
get_step, increment_step,
intermediate_logging,
register_intermediate_hooks,
reset_step,
should_log_device,
should_log_module,
should_log_step)
from vllm.v1.intermediates.intermediates_logging import (
get_current_il_config, get_step, increment_step, intermediate_logging,
register_intermediate_hooks, reset_step, should_log_device,
should_log_module, should_log_step)
class SimpleModel(nn.Module):
@ -237,7 +232,8 @@ def test_register_hooks(simple_model, il_config):
assert len(logger_instance.hooks) == 0
@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch(
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir):
@ -262,7 +258,6 @@ def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
# Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called
assert mock_save_tensors.called
# Remove hooks
logger_instance.remove_hooks()

View File

@ -7,27 +7,30 @@ This script compares the tensor outputs from two different intermediate logging
directories and generates a report of the differences.
Usage:
python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]
python compare_intermediate.py --dir1 /path/to/first/log/dir \
--dir2 /path/to/second/log/dir [options]
Options:
--dir1 DIR First intermediate logging directory
--dir2 DIR Second intermediate logging directory
--output FILE Output file for the report (default: stdout)
--format {md,json} Output format (default: md)
--rtol FLOAT Relative tolerance for tensor comparison (default: 1e-5)
--atol FLOAT Absolute tolerance for tensor comparison (default: 1e-8)
--rtol FLOAT Relative tolerance for tensor comparison
(default: 1e-5)
--atol FLOAT Absolute tolerance for tensor comparison
(default: 1e-8)
--steps STEPS Comma-separated list of steps to compare (default: all)
--modules MODULES Comma-separated list of module name patterns to compare (default: all)
--modules MODULES Comma-separated list of module name patterns to compare
(default: all)
--verbose Include detailed information about each tensor
"""
import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Optional
import regex as re
import torch
@ -40,34 +43,19 @@ def load_tensor(path: Path) -> torch.Tensor:
return None
def load_json(path: Path) -> Dict:
    """Load a JSON file and return its decoded content.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON document, or an empty dict when the file cannot be
        opened or parsed (the error is printed instead of raised).
    """
    try:
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's locale-derived default encoding.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        # Deliberately best-effort: report the problem and let the caller
        # continue with an empty document.
        print(f"Error loading JSON from {path}: {e}")
        return {}
def extract_diff_metatada(exception_str: str) -> Dict:
def extract_diff_metatada(exception_str: str) -> dict:
try:
num_diff_elements = int(
re.search(r"Mismatched elements: (\d+) /", exception_str).group(1)
)
re.search(r"Mismatched elements: (\d+) /", exception_str).group(1))
total_elements = int(
re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1)
)
re.search(r"Mismatched elements: \d+ / (\d+)",
exception_str).group(1))
max_abs_diff = float(
re.search(
r"Greatest absolute difference: ([\d\.e-]+)", exception_str
).group(1)
)
re.search(r"Greatest absolute difference: ([\d\.e-]+)",
exception_str).group(1))
max_rel_diff = float(
re.search(
r"Greatest relative difference: ([\d\.e-]+)", exception_str
).group(1)
)
re.search(r"Greatest relative difference: ([\d\.e-]+)",
exception_str).group(1))
return {
"num_diff_elements": num_diff_elements,
"total_elements": total_elements,
@ -78,9 +66,8 @@ def extract_diff_metatada(exception_str: str) -> Dict:
return {"error": exception_str}
def compare_tensors(
tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float
) -> Dict:
def compare_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float,
atol: float) -> dict:
"""Compare two tensors and return a dictionary with comparison results."""
if tensor1 is None or tensor2 is None:
return {"match": False, "error": "One or both tensors are None"}
@ -105,60 +92,8 @@ def compare_tensors(
return {"match": True}
def compare_json_values(value1: Any, value2: Any) -> Dict:
    """Recursively compare two decoded JSON values.

    Args:
        value1: Value taken from the first document.
        value2: Value taken from the second document.

    Returns:
        A dict whose ``"match"`` entry is True when the values are equal.
        On a mismatch it additionally carries one of:

        * ``"error"`` - type mismatch or list length mismatch,
        * ``"mismatches"`` - per-key (dict) or per-index (list) results for
          the entries that differ,
        * ``"value1"`` / ``"value2"`` - the differing primitive values.
    """
    # Exact type comparison: e.g. int vs float or bool vs int is reported as
    # a type mismatch rather than compared by value.
    if type(value1) is not type(value2):
        return {
            "match": False,
            "error": f"Type mismatch: {type(value1).__name__} "
                     f"vs {type(value2).__name__}",
        }

    if isinstance(value1, dict):
        mismatches = {}
        # Iterate keys in sorted order so the mismatch report is ordered the
        # same way on every run (plain set iteration order of strings
        # changes with hash randomization).
        for key in sorted(set(value1) | set(value2), key=str):
            if key not in value1:
                mismatches[key] = {"error": "Missing in first dict"}
            elif key not in value2:
                mismatches[key] = {"error": "Missing in second dict"}
            else:
                comparison = compare_json_values(value1[key], value2[key])
                if not comparison["match"]:
                    mismatches[key] = comparison

        if mismatches:
            return {"match": False, "mismatches": mismatches}
        return {"match": True}

    if isinstance(value1, list):
        if len(value1) != len(value2):
            return {
                "match": False,
                "error": f"Length mismatch: {len(value1)} vs {len(value2)}",
            }

        mismatches = {}
        for i, (item1, item2) in enumerate(zip(value1, value2)):
            comparison = compare_json_values(item1, item2)
            if not comparison["match"]:
                mismatches[i] = comparison

        if mismatches:
            return {"match": False, "mismatches": mismatches}
        return {"match": True}

    # Primitive values (str, int, float, bool, None).
    if value1 == value2:
        return {"match": True}
    return {"match": False, "value1": value1, "value2": value2}
def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
def find_tensor_files(
directory: Path) -> dict[str, dict[str, dict[str, list[Path]]]]:
"""
Find all tensor files in the given directory.
@ -198,23 +133,14 @@ def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Pat
if output_tensors:
result[step_name][module_name]["outputs"] = output_tensors
# Find JSON metadata files
inputs_json = module_dir / "inputs.json"
if inputs_json.exists():
result[step_name][module_name]["inputs_json"] = [inputs_json]
outputs_json = module_dir / "outputs.json"
if outputs_json.exists():
result[step_name][module_name]["outputs_json"] = [outputs_json]
return result
def filter_steps_and_modules(
tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]],
steps: Optional[List[str]] = None,
module_patterns: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
tensor_files: dict[str, dict[str, dict[str, list[Path]]]],
steps: Optional[list[str]] = None,
module_patterns: Optional[list[str]] = None,
) -> dict[str, dict[str, dict[str, list[Path]]]]:
"""Filter tensor files by steps and module patterns."""
result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
@ -223,11 +149,13 @@ def filter_steps_and_modules(
step_names = [f"step_{step}" for step in steps]
steps_to_include = {step: True for step in step_names}
else:
steps_to_include = {step: True for step in tensor_files.keys()}
steps_to_include = {step: True for step in tensor_files}
# Compile module patterns
if module_patterns:
compiled_patterns = [re.compile(pattern) for pattern in module_patterns]
compiled_patterns = [
re.compile(pattern) for pattern in module_patterns
]
else:
compiled_patterns = None
@ -237,11 +165,10 @@ def filter_steps_and_modules(
for module_name, file_types in modules.items():
# Check if module matches any pattern
if compiled_patterns:
if not any(
pattern.search(module_name) for pattern in compiled_patterns
):
continue
if compiled_patterns and not any(
pattern.search(module_name)
for pattern in compiled_patterns):
continue
result[step_name][module_name] = file_types
@ -253,9 +180,9 @@ def compare_directories(
dir2: Path,
rtol: Optional[float] = None,
atol: Optional[float] = None,
steps: Optional[List[str]] = None,
module_patterns: Optional[List[str]] = None,
) -> Dict:
steps: Optional[list[str]] = None,
module_patterns: Optional[list[str]] = None,
) -> dict:
"""Compare two intermediate logging directories and return a report."""
# Find tensor files in both directories
tensor_files1 = find_tensor_files(dir1)
@ -263,8 +190,10 @@ def compare_directories(
# Filter by steps and modules
if steps or module_patterns:
tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns)
tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns)
tensor_files1 = filter_steps_and_modules(tensor_files1, steps,
module_patterns)
tensor_files2 = filter_steps_and_modules(tensor_files2, steps,
module_patterns)
# Get all steps and modules from both directories
all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys())
@ -296,12 +225,12 @@ def compare_directories(
# TODO: check if module_calls.txt exists
dir1_module_call_file = dir1 / step / "module_calls.txt"
if dir1_module_call_file.exists():
with open(dir1 / step / "module_calls.txt", "r") as f:
with open(dir1 / step / "module_calls.txt") as f:
all_modules = f.read().splitlines()
else:
print(
"Warnings: the module call orders are missed, ordering using module alphbetics"
)
"Warnings: the module call orders are missed, ordering using "
"module alphbetics")
all_modules = sorted(set(modules1.keys()) | set(modules2.keys()))
step_report["module_call_list"] = []
for module in all_modules:
@ -329,32 +258,17 @@ def compare_directories(
step_report["modules"][module] = module_report
continue
# Compare JSON metadata
for json_type in ["inputs_json", "outputs_json"]:
json_files1 = modules1[module].get(json_type, [])
json_files2 = modules2[module].get(json_type, [])
if json_files1 and json_files2:
json1 = load_json(json_files1[0])
json2 = load_json(json_files2[0])
json_comparison = compare_json_values(json1, json2)
json_name = json_type.replace("_json", "")
module_report[f"{json_name}_metadata"] = json_comparison
# Add file paths for manual checking when there's a mismatch
if not json_comparison.get("match", True):
module_report[f"{json_name}_metadata"]["file1"] = str(
json_files1[0]
)
module_report[f"{json_name}_metadata"]["file2"] = str(
json_files2[0]
)
# Compare input tensors
input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])}
input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])}
all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys())
input_tensors1 = {
p.name: p
for p in modules1[module].get("inputs", [])
}
input_tensors2 = {
p.name: p
for p in modules2[module].get("inputs", [])
}
all_input_names = set(input_tensors1.keys()) | set(
input_tensors2.keys())
for tensor_name in sorted(all_input_names):
if tensor_name not in input_tensors1:
@ -389,9 +303,16 @@ def compare_directories(
module_report["summary"]["total_tensors"] += 1
# Compare output tensors
output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])}
output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])}
all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys())
output_tensors1 = {
p.name: p
for p in modules1[module].get("outputs", [])
}
output_tensors2 = {
p.name: p
for p in modules2[module].get("outputs", [])
}
all_output_names = set(output_tensors1.keys()) | set(
output_tensors2.keys())
for tensor_name in sorted(all_output_names):
if tensor_name not in output_tensors1:
@ -439,64 +360,58 @@ def compare_directories(
# Add overall summary
report["summary"] = {
"total_steps": len(all_steps),
"total_modules": sum(
step_report["summary"]["total_modules"]
for step_report in report["steps"].values()
),
"matching_modules": sum(
step_report["summary"]["matching_modules"]
for step_report in report["steps"].values()
),
"mismatched_modules": sum(
step_report["summary"]["mismatched_modules"]
for step_report in report["steps"].values()
),
"missing_modules": sum(
step_report["summary"]["missing_modules"]
for step_report in report["steps"].values()
),
"total_tensors": sum(
module_report["summary"]["total_tensors"]
"total_steps":
len(all_steps),
"total_modules":
sum(step_report["summary"]["total_modules"]
for step_report in report["steps"].values()),
"matching_modules":
sum(step_report["summary"]["matching_modules"]
for step_report in report["steps"].values()),
"mismatched_modules":
sum(step_report["summary"]["mismatched_modules"]
for step_report in report["steps"].values()),
"missing_modules":
sum(step_report["summary"]["missing_modules"]
for step_report in report["steps"].values()),
"total_tensors":
sum(module_report["summary"]["total_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
"matching_tensors": sum(
module_report["summary"]["matching_tensors"]
if "summary" in module_report),
"matching_tensors":
sum(module_report["summary"]["matching_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
"mismatched_tensors": sum(
module_report["summary"]["mismatched_tensors"]
if "summary" in module_report),
"mismatched_tensors":
sum(module_report["summary"]["mismatched_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
"missing_tensors": sum(
module_report["summary"]["missing_tensors"]
if "summary" in module_report),
"missing_tensors":
sum(module_report["summary"]["missing_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
if "summary" in module_report),
}
return report
def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
def generate_markdown_report(report: dict, verbose: bool = False) -> str:
"""Generate a markdown report from the comparison results."""
lines = []
# Add header
lines.append("# Intermediate Logging Comparison Report")
lines.append("")
lines.append("Comparing intermediate logging outputs between:")
lines.append("Comparing intermediate logging outputs "
"between:")
lines.append(f"- **Directory 1**: `{report['dir1']}`")
lines.append(f"- **Directory 2**: `{report['dir2']}`")
lines.append("")
lines.append(f"Comparison parameters:")
lines.append("Comparison parameters:")
lines.append(f"- Relative tolerance (rtol): {report['rtol']}")
lines.append(f"- Absolute tolerance (atol): {report['atol']}")
lines.append("")
@ -509,11 +424,13 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|----------|-------|----------|------------|---------|")
lines.append(f"| Steps | {summary['total_steps']} | - | - | - |")
lines.append(
f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |"
)
f"| Modules | {summary['total_modules']} | "
f"{summary['matching_modules']} | {summary['mismatched_modules']} | "
f"{summary['missing_modules']} |")
lines.append(
f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |"
)
f"| Tensors | {summary['total_tensors']} | "
f"{summary['matching_tensors']} | {summary['mismatched_tensors']} | "
f"{summary['missing_tensors']} |")
lines.append("")
# Add step details
@ -523,8 +440,9 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append(f"## {step_name}")
lines.append("")
lines.append(
f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules"
)
f"**Summary**: {step_summary['matching_modules']} matching "
f"modules, {step_summary['mismatched_modules']} mismatched "
f"modules, {step_summary['missing_modules']} missing modules")
lines.append("")
# Add module details
@ -540,15 +458,14 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
module_summary = module_report["summary"]
# Determine module status
if module_summary["mismatched_tensors"] > 0:
status = ""
else:
status = ""
status = "" if module_summary["mismatched_tensors"] > 0 else ""
lines.append(f"### {status} {module_name}")
lines.append("")
lines.append(
f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors"
f"**Summary**: {module_summary['matching_tensors']} matching "
f"tensors, {module_summary['mismatched_tensors']} mismatched "
f"tensors, {module_summary['missing_tensors']} missing tensors"
)
lines.append("")
@ -558,20 +475,21 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
metadata_comparison = module_report[metadata_type]
if not metadata_comparison.get("match", True):
file_paths = ""
if (
"file1" in metadata_comparison
and "file2" in metadata_comparison
):
file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`"
if ("file1" in metadata_comparison
and "file2" in metadata_comparison):
file_paths = (
f" - Files: "
f"`{metadata_comparison['file1']}` "
f"vs `{metadata_comparison['file2']}`")
lines.append(
f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}"
)
f"**{metadata_type.capitalize()}**: Mismatch "
f"detected{file_paths}")
if verbose and "mismatches" in metadata_comparison:
lines.append("```json")
lines.append(
json.dumps(metadata_comparison["mismatches"], indent=2)
)
json.dumps(metadata_comparison["mismatches"],
indent=2))
lines.append("```")
lines.append("")
@ -585,8 +503,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted(
module_report["inputs"].items()
):
module_report["inputs"].items()):
if comparison.get("match", False):
status = ""
details = "Tensors match"
@ -595,13 +512,23 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
details = comparison["error"]
else:
status = ""
details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A'):.2e}, "
details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A'):.2e}, "
details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
details = (
f"Max abs diff: "
f"{comparison.get('max_abs_diff', 'N/A')}, ")
details += (
f"Max relative diff: "
f"{comparison.get('max_rel_diff', 'N/A')}, ")
details += (
f"Diff elements: "
f"{comparison.get('num_diff_elements', 'N/A')}/"
f"{comparison.get('total_elements', 'N/A')}")
if "file1" in comparison and "file2" in comparison:
details += f"<br>Files: `{comparison['file1']}` vs `{comparison['file2']}`"
details += (
f"<br>Files: `{comparison['file1']}` vs "
f"`{comparison['file2']}`")
lines.append(f"| {tensor_name} | {status} | {details} |")
lines.append(
f"| {tensor_name} | {status} | {details} |")
lines.append("")
@ -613,8 +540,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted(
module_report["outputs"].items()
):
module_report["outputs"].items()):
if comparison.get("match", False):
status = ""
details = "Tensors match"
@ -623,11 +549,19 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
details = comparison["error"]
else:
status = ""
details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
details = (
f"Max abs diff: "
f"{comparison.get('max_abs_diff', 'N/A')}, ")
details += (
f"Max relative diff: "
f"{comparison.get('max_rel_diff', 'N/A')}, ")
details += (
f"Diff elements: "
f"{comparison.get('num_diff_elements', 'N/A')}/"
f"{comparison.get('total_elements', 'N/A')}")
lines.append(f"| {tensor_name} | {status} | {details} |")
lines.append(
f"| {tensor_name} | {status} | {details} |")
lines.append("")
@ -636,15 +570,16 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
def main():
parser = argparse.ArgumentParser(
description="Compare intermediate logging outputs from two different runs."
)
parser.add_argument(
"--dir1", required=True, help="First intermediate logging directory"
)
parser.add_argument(
"--dir2", required=True, help="Second intermediate logging directory"
)
parser.add_argument("--output", help="Output file for the report (default: stdout)")
description=
"Compare intermediate logging outputs from two different runs.")
parser.add_argument("--dir1",
required=True,
help="First intermediate logging directory")
parser.add_argument("--dir2",
required=True,
help="Second intermediate logging directory")
parser.add_argument("--output",
help="Output file for the report (default: stdout)")
parser.add_argument(
"--rtol",
type=float,
@ -658,11 +593,12 @@ def main():
help="Absolute tolerance for tensor comparison (default: 1e-8)",
)
parser.add_argument(
"--steps", help="Comma-separated list of steps to compare (default: all)"
)
"--steps",
help="Comma-separated list of steps to compare (default: all)")
parser.add_argument(
"--modules",
help="Comma-separated list of module name patterns to compare (default: all)",
help="Comma-separated list of module name patterns to compare "
"(default: all)",
)
parser.add_argument(
"--verbose",

View File

@ -17,8 +17,7 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
from functools import cached_property
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, List, Set)
from re import Pattern
Protocol, TypeVar, Union, cast, get_args)
import regex as re
import torch
@ -4026,63 +4025,65 @@ class KVEventsConfig:
@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig:
"""Configuration for intermediate tensor logging."""
output_dir: str = "/tmp/vllm_intermediates"
"""Directory where to save the intermediate tensors."""
reload_input_dir: Optional[str] = None
"""Directory where to load the inputs for the steps/modules.
This is used when we want to check per module numerical gaps instead
of accumulated gap to further dive into the actual numerical issues."""
module_call_match: Optional[List[str]] = None
module_call_match: Optional[list[str]] = None
"""Match modules by name regex and call index (
a module can be called multiple times in a step)
List of regex:call_idx, call_idx is -1 for default for all calls """
log_step_ids: List[int] = field(default_factory=lambda: [0])
log_step_ids: list[int] = field(default_factory=lambda: [0])
"""List of step IDs to log (empty list means log all steps)."""
log_post_fwd_inputs: bool = False
"""Whether logging inputs after forwards for each module"""
max_tensor_size: Optional[int] = None
"""Maximum number of elements in tensors to log (None = no limit)."""
enabled: bool = True
"""Whether logging is enabled."""
device_names: List[str] = field(default_factory=list)
device_names: list[str] = field(default_factory=list)
"""List of device names to log (empty list means log all devices)."""
_compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False)
_compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
init=False)
"""Compiled regex patterns for module filtering."""
_module_call: dict[str, int] = field(default_factory=dict, init=False)
_step_id_set: Set[int] = field(default_factory=set, init=False)
_step_id_set: set[int] = field(default_factory=set, init=False)
"""Set of step IDs for faster lookup."""
_output_run_dir: str = "/tmp/vllm_intermediates"
"""Unique directory to save single run/serve logging result."""
def __post_init__(self):
"""Initialize derived fields after instance creation."""
self._compile_regex_patterns()
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
self._step_id_set = set(self.log_step_ids)
def _compile_regex_patterns(self):
"""Compile regex patterns for module name filtering."""
from vllm.logger import init_logger
logger = init_logger(__name__)
self._compiled_module_matches = []
if self.module_call_match is None:
logger.info("No module name regex patterns provided, will log all modules")
logger.info(
"No module name regex patterns provided, will log all modules")
return
# Compile all patterns
for regex_pattern_call_idx in self.module_call_match:
try:
@ -4091,15 +4092,16 @@ class IntermediateLoggingConfig:
call_idx = -1
if len(splits) > 1:
call_idx = int(splits[1])
compiled_pattern: Pattern[str] = re.compile(regex_pattern)
compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
self._compiled_module_calls[compiled_pattern] = call_idx
logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'")
logger.info("Successfully compiled regex pattern: '%s'",
regex_pattern)
except Exception as e:
logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}")
raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e
logger.error("Failed to parse module_call_match '%s': %s",
regex_pattern_call_idx, e)
logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns")
logger.info("Compiled %d regex patterns",
len(self._compiled_module_calls))
def to_dict(self) -> dict:
"""Convert the config to a dictionary for serialization."""
@ -4111,12 +4113,12 @@ class IntermediateLoggingConfig:
"enabled": self.enabled,
"device_names": self.device_names
}
@classmethod
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
    """Build an IntermediateLoggingConfig from a dict of constructor kwargs.

    Args:
        dict_value: Mapping of field names to values; it is unpacked as
            keyword arguments to the class constructor.

    Returns:
        The newly constructed config instance.
    """
    return cls(**dict_value)
@property
def output_run_dir(self) -> str:
return self._output_run_dir
@ -4138,7 +4140,6 @@ class IntermediateLoggingConfig:
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
class CompilationLevel:

View File

@ -27,13 +27,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, IntermediateLoggingConfig,
KVEventsConfig, KVTransferConfig,
LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerMode, VllmConfig, get_attr_docs, get_field)
KVEventsConfig, KVTransferConfig, LoadConfig,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -400,7 +400,7 @@ class EngineArgs:
str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None
show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
@ -773,10 +773,13 @@ class EngineArgs:
default=None,
help="The configurations for intermediate loggings. Should be a "
"JSON string.")
intermediate_log_group.add_argument("--intermediate-log-config-path", type=str,
help="The path to the configurations for intermediate loggings. Should be a string.")
intermediate_log_group.add_argument(
"--intermediate-log-config-path",
type=str,
help="The path to the configurations for intermediate loggings. "
"Should be a string.")
# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
observability_group = parser.add_argument_group(
@ -865,9 +868,6 @@ class EngineArgs:
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--disable-log-stats',
action='store_true',
@ -979,11 +979,9 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location,
)
def create_intermediate_log_config(
self,
) -> Optional[IntermediateLoggingConfig]:
self, ) -> Optional[IntermediateLoggingConfig]:
"""Initializes and returns an IntermediateLoggingConfig object based on
`intermediate_log_config` or `intermediate_log_config_path`.
"""
@ -991,7 +989,7 @@ class EngineArgs:
return IntermediateLoggingConfig.from_dict(
self.intermediate_log_config)
if self.intermediate_log_config_path is not None:
with open(self.intermediate_log_config_path, "r") as f:
with open(self.intermediate_log_config_path) as f:
return IntermediateLoggingConfig.from_dict(json.load(f))
return None
@ -1235,8 +1233,7 @@ class EngineArgs:
disable_log_stats=self.disable_log_stats,
)
intermediate_log_config = self.create_intermediate_log_config(
)
intermediate_log_config = self.create_intermediate_log_config()
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid

View File

@ -7,8 +7,6 @@ This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""
import json
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
@ -17,8 +15,6 @@ import torch
from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig
# Import logger from vllm
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -91,7 +87,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
return False
# If no patterns are defined, log all modules
if not config._compiled_module_calls:
logger.debug("No patterns defined, will log module: %s", module_name)
set_il_module_name(module, module_name)
set_il_module_call_idx(module, -1)
return True
@ -115,7 +110,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
def is_log_enabled(config):
if not config or not config.enabled:
logger.debug("Not logging because config not enabled")
return False
if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress")
@ -161,151 +155,7 @@ def get_current_il_config():
return _global_config
def dump_intermediates_to_json(intermediates: Any, path: Path) -> None:
    """Write a JSON description of ``intermediates`` to ``path``.

    The value is first converted with ``convert_intermediates_to_json`` and
    then written with a 2-space indent. This is best-effort: any failure
    during conversion or writing is logged as a warning and not raised.

    Args:
        intermediates: The captured value(s) to describe.
        path: Destination file; it is opened for writing.
    """
    try:
        # Convert inputs to JSON-serializable format
        intermediates_json = convert_intermediates_to_json(intermediates)
        with open(path, "w") as f:
            json.dump(intermediates_json, f, indent=2)
        logger.debug("Saved all intermediates as JSON to %s", path)
    except Exception as e:
        # exc_info attaches the traceback to this same record, so the message
        # and its stack trace cannot be separated by other log output.
        logger.warning("Failed to save intermediates as JSON: %s",
                       e,
                       exc_info=True)
def convert_intermediates_to_json(tensor: Any) -> Any:
    """Convert intermediates (tensors and containers of them) to JSON data.

    Tensors are summarized by metadata only (shape, dtype, element count);
    their values are never emitted. Lists, tuples and dicts are converted
    recursively, and a container with more than 20 entries is reduced to a
    sample of its first 20 entries plus a note saying so.

    Args:
        tensor: The value to convert. Despite the name, it may be a tensor,
            a list/tuple/dict, a primitive, None, or any other object.

    Returns:
        A structure built from dicts, lists and primitives that describes
        ``tensor``. Dictionary keys are passed through unchanged.
    """
    # Containers longer than this are summarized by their first entries.
    # The slice and the "note" text both use this value so that they
    # cannot disagree (the list branch previously sliced 100 items while
    # reporting 20).
    sample_limit = 20

    if isinstance(tensor, torch.Tensor):
        try:
            return {
                "type": "tensor",
                "shape": list(tensor.shape),
                "dtype": str(tensor.dtype),
                "numel": tensor.numel(),
            }
        except Exception as e:
            # Report the failure in-band instead of aborting the dump.
            return {
                "type": "tensor_error",
                "error": str(e),
                "tensor_type": str(type(tensor)),
            }

    if isinstance(tensor, (list, tuple)):
        container_type = "list" if isinstance(tensor, list) else "tuple"
        if len(tensor) > sample_limit:
            return {
                "type": container_type,
                "length": len(tensor),
                "sample": [
                    convert_intermediates_to_json(item)
                    for item in tensor[:sample_limit]
                ],
                "note": (f"Showing only first {sample_limit} of "
                         f"{len(tensor)} items"),
            }
        return {
            "type": container_type,
            "items": [convert_intermediates_to_json(item) for item in tensor],
        }

    if isinstance(tensor, dict):
        if len(tensor) > sample_limit:
            # Large dicts keep every key but only a sample of the values.
            keys = list(tensor.keys())
            return {
                "type": "dict",
                "length": len(tensor),
                "keys": keys,
                "sample": {
                    k: convert_intermediates_to_json(tensor[k])
                    for k in keys[:sample_limit]
                },
                "note": (f"Showing only first {sample_limit} of "
                         f"{len(tensor)} items"),
            }
        return {
            "type": "dict",
            "items": {
                k: convert_intermediates_to_json(v)
                for k, v in tensor.items()
            },
        }

    if tensor is None:
        return None

    if isinstance(tensor, (int, float, bool, str)):
        # Primitive types can be serialized directly.
        return tensor

    # Anything else is described by its type name and string form.
    return {"type": str(type(tensor).__name__), "string_repr": str(tensor)}
def save_tensors_metadata_if_too_large(tensor: torch.Tensor,
                                       file_path: str) -> bool:
    """Write tensor metadata as JSON when the tensor exceeds the size limit.

    The limit is ``max_tensor_size`` of the config returned by
    ``get_current_il_config()``. If there is no config, no limit, or the
    tensor's element count does not exceed the limit, nothing is written.

    Args:
        tensor: The tensor to inspect.
        file_path: Base output path without extension; ``.json`` is appended.

    Returns:
        True if a metadata file was written in place of the tensor,
        False otherwise.
    """
    il_config = get_current_il_config()
    if il_config is None:
        return False

    size_limit = il_config.max_tensor_size
    if size_limit is None:
        return False
    if not tensor.numel() > size_limit:
        return False

    # Too large: record a description of the tensor rather than its data.
    metadata = {
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype),
        "device": str(tensor.device),
        "numel": tensor.numel(),
        "skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size "
        f"{il_config.max_tensor_size}",
    }
    json_path = f"{file_path}.json"
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    with open(json_path, "w") as out_file:
        json.dump(metadata, out_file, indent=2)
    return True
def safe_reload_tensor(save_path: str, tensor: Any,
                       reload_dir: Optional[str]) -> Any:
    """Load a previously saved tensor from the matching path in reload_dir.

    The reload path is ``save_path`` with the current config's
    ``output_run_dir`` substring replaced by ``reload_dir``.

    Args:
        save_path: Path the tensor was just saved to.
        tensor: The live tensor; its device is used as ``map_location`` and
            it is the fallback result when loading fails.
        reload_dir: Directory to reload from, or None to disable reloading.

    Returns:
        None when ``reload_dir`` is None, the loaded object on success, or
        ``tensor`` unchanged if anything goes wrong (a warning is logged).
    """
    if reload_dir is None:
        return None

    try:
        il_config = get_current_il_config()
        assert il_config is not None
        run_dir = str(il_config.output_run_dir)
        source_path = save_path.replace(run_dir, reload_dir)
        logger.debug("reload tensor of shape %s from %s", tensor.shape,
                     source_path)
        # NOTE(review): torch.load unpickles the file; reload_dir must only
        # ever point at trusted dumps.
        return torch.load(source_path, map_location=tensor.device)
    except Exception as err:
        logger.warning("Failed to load tensor from %s: %s", reload_dir, err)

    # Loading failed: fall back to the tensor the caller already has.
    return tensor
def save_tensors(
tensor: Any, file_path: str, reload_input_dir: Optional[str] = None
) -> Any:
def save_tensors(tensor: Any, file_path: str) -> Any:
"""Utility function to dump tensor to a file.
Args:
@ -314,52 +164,32 @@ def save_tensors(
file_path: Base path where to save the tensor (without extension).
"""
# Also save the actual tensor data for tensors
if isinstance(tensor, torch.Tensor):
# Check if tensor is too large
if save_tensors_metadata_if_too_large(tensor, file_path):
return
# Get device name
device_name = str(tensor.device)
# Skip if device filtering is enabled and this device should not be
# logged
intermediate_log_config = get_current_il_config()
if not should_log_device(intermediate_log_config, device_name):
logger.debug(
"Skipping tensor on device %s due to device filter", device_name
)
return tensor
# Append device name to file path
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
try:
# Save tensor directly without detaching or moving to CPU
torch.save(tensor, pt_path)
reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir)
if reloaded_tensor is not None:
return reloaded_tensor
logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
logger.debug("Saved tensor of shape %s to %s", tensor.shape,
pt_path)
except Exception as e:
logger.warning("Failed to save tensor to %s: %s", pt_path, e)
return tensor
if isinstance(tensor, (list, tuple)):
# For collections, also save each item individually
reloaded_inputs = []
for i, item in enumerate(tensor):
reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir)
reloaded_inputs.append(reloaded)
return tuple(reloaded_inputs) if reloaded_inputs else tensor
save_tensors(item, f"{file_path}_{i}")
return tensor
if isinstance(tensor, dict):
reloaded_inputs = {}
# For dictionaries, also save each value individually
for k, v in tensor.items():
reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir)
reloaded_inputs[k] = reloaded
return reloaded_inputs if reloaded_inputs else tensor
save_tensors(v, f"{file_path}_{k}")
return tensor
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None:
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to increment the global step counter after a forward pass.
Args:
@ -381,7 +211,8 @@ def _prepare_module_log_dir(
is_pre_fwd: bool = False,
) -> Path:
# Create a unique directory for this step if not
dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir = Path(
intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir.mkdir(exist_ok=True, parents=True)
# Create module directory
@ -393,7 +224,8 @@ def _prepare_module_log_dir(
if is_pre_fwd:
_log_module_call(intermediate_log_config, module_name + suffix)
module_dir.mkdir(exist_ok=True, parents=True)
logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir)
logger.debug("Logging module %s inputs/outputs to %s", module_name,
module_dir)
return module_dir
@ -401,13 +233,8 @@ def _log_module_call(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
) -> None:
logger.debug("Logging module call for %s", module_name)
# write module name and call to step:
file = (
Path(intermediate_log_config.output_run_dir)
/ f"step_{get_step()}"
/ "module_calls.txt"
)
file = (Path(intermediate_log_config.output_run_dir) /
f"step_{get_step()}" / "module_calls.txt")
with open(file, "a") as f:
f.write(f"{module_name}\n")
@ -425,7 +252,8 @@ def get_current_step_module_call(module_name: str) -> int:
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
def prepare_log_current_fwd(module,
is_pre_fwd: bool = False) -> Optional[Path]:
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None or not intermediate_log_config.enabled:
return None
@ -443,15 +271,14 @@ def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
if is_pre_fwd:
update_current_step_module_call(module_name)
if should_log:
log_dir = _prepare_module_log_dir(
intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd
)
log_dir = _prepare_module_log_dir(intermediate_log_config,
module_name,
is_pre_fwd=is_pre_fwd)
return log_dir
def log_pre_fwd_hook(
module: torch.nn.Module, inputs: tuple[Any, ...]
) -> tuple[Any, ...]:
def log_pre_fwd_hook(module: torch.nn.Module,
inputs: tuple[Any, ...]) -> tuple[Any, ...]:
"""Hook to capture module inputs before forward pass.
Args:
@ -462,27 +289,12 @@ def log_pre_fwd_hook(
The unchanged inputs.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
dump_intermediates_to_json(inputs, log_dir / "inputs.json")
intermediate_log_config = get_current_il_config()
if intermediate_log_config is not None:
reload_input_dir = getattr(
intermediate_log_config,
"reload_input_dir",
"/tmp/vllm_intermediates/57f4a3b2-9c4c-4afe-be71-0e95369d74b5",
)
else:
reload_input_dir = None
reloaded_inputs = save_tensors(
inputs, str(log_dir / "inputs"), reload_input_dir
)
if reloaded_inputs is not None:
return reloaded_inputs
save_tensors(inputs, str(log_dir / "inputs"))
return inputs
def log_post_fwd_hook(
module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any
) -> None:
def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to capture module outputs after forward pass.
Args:
@ -491,12 +303,11 @@ def log_post_fwd_hook(
outputs: The outputs from the module's forward function.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
dump_intermediates_to_json(outputs, log_dir / "outputs.json")
save_tensors(outputs, str(log_dir / "outputs"))
intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None, "IL config should not be None"
assert intermediate_log_config is not None, \
"IL config should not be None"
if intermediate_log_config.log_post_fwd_inputs:
dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json")
save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
@ -532,14 +343,14 @@ class IntermediatesLogger:
def __init__(self, config: IntermediateLoggingConfig):
self.config = config
self.hooks: list[
tuple[str, str, Optional[RemovableHandle], Optional[RemovableHandle]]
] = []
self.hooks: list[tuple[str, str, Optional[RemovableHandle],
Optional[RemovableHandle]]] = []
logger.debug("Created IntermediatesLogger with config: %s", config)
path = Path(config.output_run_dir)
path.mkdir(exist_ok=True, parents=True)
# Log configuration
logger.info("Intermediates will be logged in %s", config.output_run_dir)
logger.info("Intermediates will be logged in %s",
config.output_run_dir)
def register_hooks(self, model: torch.nn.Module) -> None:
"""Register hooks for the model.
@ -551,13 +362,11 @@ class IntermediatesLogger:
for name, module in model.named_modules():
if name and should_log_module(self.config, name, module):
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
logger.debug(
"Registered pre_fwd hook for %s", module.__class__.__name__
)
logger.debug("Registered pre_fwd hook for %s",
module.__class__.__name__)
post_hook = module.register_forward_hook(log_post_fwd_hook)
logger.debug(
"Registered post_fwd hook for %s", module.__class__.__name__
)
logger.debug("Registered post_fwd hook for %s",
module.__class__.__name__)
self.hooks.append((name, module, pre_hook, post_hook))
# Register a step counter hook for the root model
@ -578,7 +387,8 @@ class IntermediatesLogger:
def register_intermediate_hooks(
model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs
model: torch.nn.Module,
config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
"""Register hooks to log intermediate tensors for a model.
@ -590,10 +400,6 @@ def register_intermediate_hooks(
Returns:
An IntermediatesLogger instance that can be used to manage the hooks.
"""
if config is None:
# Create config from kwargs
config = IntermediateLoggingConfig.from_dict(kwargs)
logger_instance = IntermediatesLogger(config)
logger_instance.register_hooks(model)
return logger_instance

View File

@ -27,12 +27,12 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
logger = init_logger(__name__)

View File

@ -6,10 +6,11 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig, IntermediateLoggingConfig
from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.intermediates.intermediates_logging import (
register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__)
@ -64,27 +65,26 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
def register_intermediate_hooks(self,
config: Optional[IntermediateLoggingConfig] = None,
**kwargs) -> None:
def register_intermediate_hooks(
self, config: Optional[IntermediateLoggingConfig] = None) -> None:
"""Register hooks for intermediate tensor logging.
This method is called via collective_rpc from the engine core.
It registers hooks on the model to dump intermediate tensors during execution.
It registers hooks on the model to dump intermediate tensors during
execution.
Args:
config: Configuration for intermediate logging. If provided, this takes precedence over kwargs.
config: Configuration for intermediate logging; it is forwarded
as-is to register_intermediate_hooks.
"""
if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None:
logger.error("Could not register intermediate hooks: model_runner.model is not accessible")
if self.model_runner is None or not hasattr(
self.model_runner, "model") or self.model_runner.model is None:
logger.error("Could not register intermediate hooks: "
"model_runner.model is not accessible")
return
model = self.model_runner.model
try:
# Register hooks
register_intermediate_hooks(model, config, **kwargs)
# Store the logger instance for potential later hook removal
except Exception as e:
logger.info("Successfully registered intermediate hooks")
logger.error("Error registering intermediate hooks", exc_info=True)
register_intermediate_hooks(model, config)
except Exception:
logger.exception("Error registering intermediate hooks")

View File

@ -128,21 +128,21 @@ class WorkerBase:
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def register_intermediate_hooks(self, config=None, **kwargs) -> None:
def register_intermediate_hooks(self, config=None) -> None:
"""Register hooks for intermediate tensor logging.
This method is a stub for v0 workers. The actual implementation is in v1 workers.
It's included here for compatibility with the collective_rpc mechanism.
This method is a stub for v0 workers. The actual implementation is
in v1 workers. It's included here for compatibility with the
collective_rpc mechanism.
Args:
config: Configuration for intermediate logging.
**kwargs: Configuration parameters for intermediate logging.
These are ignored in v0 workers.
"""
logger.warning(
"register_intermediate_hooks is not implemented in v0 workers. "
"This is only available in v1 workers. No hooks will be registered.")
"This is only available in v1 workers. No hooks will be registered."
)
return None