remove feature for metadata dump and input reload
Signed-off-by: Lucia Fang <fanglu@fb.com>
@ -49,7 +49,6 @@ The configuration file should be a JSON file with the following structure:

python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json
```

#### Configuration Parameters

| Parameter | Type | Description | Default |
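For reference, a minimal sketch of what `$HOME/intermediate_logging_config.json` could contain, using the `IntermediateLoggingConfig` fields that appear later in this diff (the values themselves are illustrative, not documented defaults):

```python
import json
import os

# Hypothetical config; key names mirror the IntermediateLoggingConfig fields.
config = {
    "output_dir": "/tmp/vllm_intermediates",
    "module_call_match": ["model\\.layers\\.0\\..*:-1"],  # regex:call_idx
    "log_step_ids": [0],          # empty list means log all steps
    "log_post_fwd_inputs": False,
    "max_tensor_size": None,      # None = no element-count limit
    "enabled": True,
    "device_names": [],           # empty list means log all devices
}

with open(os.path.expanduser("~/intermediate_logging_config.json"), "w") as f:
    json.dump(config, f, indent=2)
```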
@ -5,9 +5,8 @@ Tests for the intermediate tensor logging functionality.
"""

import json
from os.path import isdir
import shutil
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
@ -17,14 +16,10 @@ import torch
import torch.nn as nn

from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (get_current_il_config,
                                                         get_step, increment_step,
                                                         intermediate_logging,
                                                         register_intermediate_hooks,
                                                         reset_step,
                                                         should_log_device,
                                                         should_log_module,
                                                         should_log_step)
from vllm.v1.intermediates.intermediates_logging import (
    get_current_il_config, get_step, increment_step, intermediate_logging,
    register_intermediate_hooks, reset_step, should_log_device,
    should_log_module, should_log_step)


class SimpleModel(nn.Module):
@ -237,7 +232,8 @@ def test_register_hooks(simple_model, il_config):
    assert len(logger_instance.hooks) == 0


@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch(
    'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
                       il_config, temp_output_dir):
@ -262,7 +258,6 @@ def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
    # Check that dump_intermediates_to_json and save_tensors were called
    assert mock_dump_json.called
    assert mock_save_tensors.called

    # Remove hooks
    logger_instance.remove_hooks()
@ -7,27 +7,30 @@ This script compares the tensor outputs from two different intermediate logging
directories and generates a report of the differences.

Usage:
    python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]
    python compare_intermediate.py --dir1 /path/to/first/log/dir \
        --dir2 /path/to/second/log/dir [options]

Options:
    --dir1 DIR          First intermediate logging directory
    --dir2 DIR          Second intermediate logging directory
    --output FILE       Output file for the report (default: stdout)
    --format {md,json}  Output format (default: md)
    --rtol FLOAT        Relative tolerance for tensor comparison (default: 1e-5)
    --atol FLOAT        Absolute tolerance for tensor comparison (default: 1e-8)
    --rtol FLOAT        Relative tolerance for tensor comparison
                        (default: 1e-5)
    --atol FLOAT        Absolute tolerance for tensor comparison
                        (default: 1e-8)
    --steps STEPS       Comma-separated list of steps to compare (default: all)
    --modules MODULES   Comma-separated list of module name patterns to compare (default: all)
    --modules MODULES   Comma-separated list of module name patterns to compare
                        (default: all)
    --verbose           Include detailed information about each tensor
"""

import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Optional

import regex as re
import torch
@ -40,34 +43,19 @@ def load_tensor(path: Path) -> torch.Tensor:
    return None


def load_json(path: Path) -> Dict:
    """Load a JSON file."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading JSON from {path}: {e}")
        return {}


def extract_diff_metatada(exception_str: str) -> Dict:
def extract_diff_metatada(exception_str: str) -> dict:
    try:
        num_diff_elements = int(
            re.search(r"Mismatched elements: (\d+) /", exception_str).group(1)
        )
            re.search(r"Mismatched elements: (\d+) /", exception_str).group(1))
        total_elements = int(
            re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1)
        )
            re.search(r"Mismatched elements: \d+ / (\d+)",
                      exception_str).group(1))
        max_abs_diff = float(
            re.search(
                r"Greatest absolute difference: ([\d\.e-]+)", exception_str
            ).group(1)
        )
            re.search(r"Greatest absolute difference: ([\d\.e-]+)",
                      exception_str).group(1))
        max_rel_diff = float(
            re.search(
                r"Greatest relative difference: ([\d\.e-]+)", exception_str
            ).group(1)
        )
            re.search(r"Greatest relative difference: ([\d\.e-]+)",
                      exception_str).group(1))
        return {
            "num_diff_elements": num_diff_elements,
            "total_elements": total_elements,
@ -78,9 +66,8 @@ def extract_diff_metatada(exception_str: str) -> Dict:
        return {"error": exception_str}


def compare_tensors(
    tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float
) -> Dict:
def compare_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float,
                    atol: float) -> dict:
    """Compare two tensors and return a dictionary with comparison results."""
    if tensor1 is None or tensor2 is None:
        return {"match": False, "error": "One or both tensors are None"}
@ -105,60 +92,8 @@ def compare_tensors(
    return {"match": True}


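The comparison body elided by the hunk above presumably drives `torch.testing.assert_close` and parses its failure message with `extract_diff_metatada`; a minimal sketch of that pattern (the exact wiring is an assumption, not the commit's code):

```python
def compare_tensors_sketch(tensor1: torch.Tensor, tensor2: torch.Tensor,
                           rtol: float, atol: float) -> dict:
    # Shapes must agree before an element-wise comparison makes sense.
    if tensor1.shape != tensor2.shape:
        return {"match": False,
                "error": f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}"}
    try:
        torch.testing.assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
        return {"match": True}
    except AssertionError as e:
        # The message contains "Mismatched elements: ...",
        # "Greatest absolute difference: ..." and
        # "Greatest relative difference: ...", which the regexes in
        # extract_diff_metatada above pick apart.
        return {"match": False, **extract_diff_metatada(str(e))}
```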
def compare_json_values(value1: Any, value2: Any) -> Dict:
    """Compare two JSON values and return a dictionary with comparison results."""
    if type(value1) is not type(value2):
        return {
            "match": False,
            "error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}",
        }

    if isinstance(value1, dict):
        # Compare dictionaries
        all_keys = set(value1.keys()) | set(value2.keys())
        mismatches = {}

        for key in all_keys:
            if key not in value1:
                mismatches[key] = {"error": "Missing in first dict"}
            elif key not in value2:
                mismatches[key] = {"error": "Missing in second dict"}
            else:
                comparison = compare_json_values(value1[key], value2[key])
                if not comparison["match"]:
                    mismatches[key] = comparison

        if mismatches:
            return {"match": False, "mismatches": mismatches}
        return {"match": True}

    elif isinstance(value1, list):
        # Compare lists
        if len(value1) != len(value2):
            return {
                "match": False,
                "error": f"Length mismatch: {len(value1)} vs {len(value2)}",
            }

        mismatches = {}
        for i, (item1, item2) in enumerate(zip(value1, value2)):
            comparison = compare_json_values(item1, item2)
            if not comparison["match"]:
                mismatches[i] = comparison

        if mismatches:
            return {"match": False, "mismatches": mismatches}
        return {"match": True}

    else:
        # Compare primitive values
        if value1 == value2:
            return {"match": True}
        else:
            return {"match": False, "value1": value1, "value2": value2}


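A quick usage sketch of the recursive comparison above (values are illustrative):

```python
result = compare_json_values(
    {"shape": [2, 4], "dtype": "torch.float16"},
    {"shape": [2, 8], "dtype": "torch.float16"},
)
# result == {"match": False,
#            "mismatches": {"shape": {"match": False,
#                                     "mismatches": {1: {"match": False,
#                                                        "value1": 4,
#                                                        "value2": 8}}}}}
```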
def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
def find_tensor_files(
        directory: Path) -> dict[str, dict[str, dict[str, list[Path]]]]:
    """
    Find all tensor files in the given directory.

@ -198,23 +133,14 @@ def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Pat
        if output_tensors:
            result[step_name][module_name]["outputs"] = output_tensors

        # Find JSON metadata files
        inputs_json = module_dir / "inputs.json"
        if inputs_json.exists():
            result[step_name][module_name]["inputs_json"] = [inputs_json]

        outputs_json = module_dir / "outputs.json"
        if outputs_json.exists():
            result[step_name][module_name]["outputs_json"] = [outputs_json]

    return result


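Pulling together the paths used throughout this diff (`step_{N}` directories, per-module subdirectories, `module_calls.txt`, and `*_<device>.pt` tensor files), the layout `find_tensor_files` walks is expected to look roughly like this (the run UUID and module name are illustrative):

```
/tmp/vllm_intermediates/<run-uuid>/
└── step_0/
    ├── module_calls.txt
    └── model.layers.0.self_attn/
        ├── inputs.json           # metadata dump, removed by this commit
        ├── inputs_0_cuda_0.pt
        ├── outputs.json          # metadata dump, removed by this commit
        └── outputs_cuda_0.pt
```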
def filter_steps_and_modules(
    tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]],
    steps: Optional[List[str]] = None,
    module_patterns: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
    tensor_files: dict[str, dict[str, dict[str, list[Path]]]],
    steps: Optional[list[str]] = None,
    module_patterns: Optional[list[str]] = None,
) -> dict[str, dict[str, dict[str, list[Path]]]]:
    """Filter tensor files by steps and module patterns."""
    result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

@ -223,11 +149,13 @@ def filter_steps_and_modules(
        step_names = [f"step_{step}" for step in steps]
        steps_to_include = {step: True for step in step_names}
    else:
        steps_to_include = {step: True for step in tensor_files.keys()}
        steps_to_include = {step: True for step in tensor_files}

    # Compile module patterns
    if module_patterns:
        compiled_patterns = [re.compile(pattern) for pattern in module_patterns]
        compiled_patterns = [
            re.compile(pattern) for pattern in module_patterns
        ]
    else:
        compiled_patterns = None

@ -237,11 +165,10 @@ def filter_steps_and_modules(

        for module_name, file_types in modules.items():
            # Check if module matches any pattern
            if compiled_patterns:
                if not any(
                    pattern.search(module_name) for pattern in compiled_patterns
                ):
                    continue
            if compiled_patterns and not any(
                    pattern.search(module_name)
                    for pattern in compiled_patterns):
                continue

            result[step_name][module_name] = file_types

@ -253,9 +180,9 @@ def compare_directories(
    dir2: Path,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    steps: Optional[List[str]] = None,
    module_patterns: Optional[List[str]] = None,
) -> Dict:
    steps: Optional[list[str]] = None,
    module_patterns: Optional[list[str]] = None,
) -> dict:
    """Compare two intermediate logging directories and return a report."""
    # Find tensor files in both directories
    tensor_files1 = find_tensor_files(dir1)
@ -263,8 +190,10 @@ def compare_directories(

    # Filter by steps and modules
    if steps or module_patterns:
        tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns)
        tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns)
        tensor_files1 = filter_steps_and_modules(tensor_files1, steps,
                                                 module_patterns)
        tensor_files2 = filter_steps_and_modules(tensor_files2, steps,
                                                 module_patterns)

    # Get all steps and modules from both directories
    all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys())

@ -296,12 +225,12 @@ def compare_directories(
        # TODO: check if module calls txt exists
        dir1_module_call_file = dir1 / step / "module_calls.txt"
        if dir1_module_call_file.exists():
            with open(dir1 / step / "module_calls.txt", "r") as f:
            with open(dir1 / step / "module_calls.txt") as f:
                all_modules = f.read().splitlines()
        else:
            print(
                "Warning: the module call order file is missing; ordering modules alphabetically"
            )
                "Warning: the module call order file is missing; "
                "ordering modules alphabetically")
            all_modules = sorted(set(modules1.keys()) | set(modules2.keys()))
        step_report["module_call_list"] = []
        for module in all_modules:
@ -329,32 +258,17 @@ def compare_directories(
                step_report["modules"][module] = module_report
                continue

            # Compare JSON metadata
            for json_type in ["inputs_json", "outputs_json"]:
                json_files1 = modules1[module].get(json_type, [])
                json_files2 = modules2[module].get(json_type, [])

                if json_files1 and json_files2:
                    json1 = load_json(json_files1[0])
                    json2 = load_json(json_files2[0])

                    json_comparison = compare_json_values(json1, json2)
                    json_name = json_type.replace("_json", "")
                    module_report[f"{json_name}_metadata"] = json_comparison

                    # Add file paths for manual checking when there's a mismatch
                    if not json_comparison.get("match", True):
                        module_report[f"{json_name}_metadata"]["file1"] = str(
                            json_files1[0]
                        )
                        module_report[f"{json_name}_metadata"]["file2"] = str(
                            json_files2[0]
                        )

            # Compare input tensors
            input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])}
            input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])}
            all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys())
            input_tensors1 = {
                p.name: p
                for p in modules1[module].get("inputs", [])
            }
            input_tensors2 = {
                p.name: p
                for p in modules2[module].get("inputs", [])
            }
            all_input_names = set(input_tensors1.keys()) | set(
                input_tensors2.keys())

            for tensor_name in sorted(all_input_names):
                if tensor_name not in input_tensors1:
@ -389,9 +303,16 @@ def compare_directories(
                module_report["summary"]["total_tensors"] += 1

            # Compare output tensors
            output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])}
            output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])}
            all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys())
            output_tensors1 = {
                p.name: p
                for p in modules1[module].get("outputs", [])
            }
            output_tensors2 = {
                p.name: p
                for p in modules2[module].get("outputs", [])
            }
            all_output_names = set(output_tensors1.keys()) | set(
                output_tensors2.keys())

            for tensor_name in sorted(all_output_names):
                if tensor_name not in output_tensors1:
@ -439,64 +360,58 @@ def compare_directories(

    # Add overall summary
    report["summary"] = {
        "total_steps": len(all_steps),
        "total_modules": sum(
            step_report["summary"]["total_modules"]
            for step_report in report["steps"].values()
        ),
        "matching_modules": sum(
            step_report["summary"]["matching_modules"]
            for step_report in report["steps"].values()
        ),
        "mismatched_modules": sum(
            step_report["summary"]["mismatched_modules"]
            for step_report in report["steps"].values()
        ),
        "missing_modules": sum(
            step_report["summary"]["missing_modules"]
            for step_report in report["steps"].values()
        ),
        "total_tensors": sum(
            module_report["summary"]["total_tensors"]
        "total_steps":
        len(all_steps),
        "total_modules":
        sum(step_report["summary"]["total_modules"]
            for step_report in report["steps"].values()),
        "matching_modules":
        sum(step_report["summary"]["matching_modules"]
            for step_report in report["steps"].values()),
        "mismatched_modules":
        sum(step_report["summary"]["mismatched_modules"]
            for step_report in report["steps"].values()),
        "missing_modules":
        sum(step_report["summary"]["missing_modules"]
            for step_report in report["steps"].values()),
        "total_tensors":
        sum(module_report["summary"]["total_tensors"]
            for step_report in report["steps"].values()
            for module_name, module_report in step_report["modules"].items()
            if "summary" in module_report
        ),
        "matching_tensors": sum(
            module_report["summary"]["matching_tensors"]
            if "summary" in module_report),
        "matching_tensors":
        sum(module_report["summary"]["matching_tensors"]
            for step_report in report["steps"].values()
            for module_name, module_report in step_report["modules"].items()
            if "summary" in module_report
        ),
        "mismatched_tensors": sum(
            module_report["summary"]["mismatched_tensors"]
            if "summary" in module_report),
        "mismatched_tensors":
        sum(module_report["summary"]["mismatched_tensors"]
            for step_report in report["steps"].values()
            for module_name, module_report in step_report["modules"].items()
            if "summary" in module_report
        ),
        "missing_tensors": sum(
            module_report["summary"]["missing_tensors"]
            if "summary" in module_report),
        "missing_tensors":
        sum(module_report["summary"]["missing_tensors"]
            for step_report in report["steps"].values()
            for module_name, module_report in step_report["modules"].items()
            if "summary" in module_report
        ),
            if "summary" in module_report),
    }

    return report


def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
def generate_markdown_report(report: dict, verbose: bool = False) -> str:
    """Generate a markdown report from the comparison results."""
    lines = []

    # Add header
    lines.append("# Intermediate Logging Comparison Report")
    lines.append("")
    lines.append("Comparing intermediate logging outputs between:")
    lines.append("Comparing intermediate logging outputs "
                 "between:")
    lines.append(f"- **Directory 1**: `{report['dir1']}`")
    lines.append(f"- **Directory 2**: `{report['dir2']}`")
    lines.append("")
    lines.append(f"Comparison parameters:")
    lines.append("Comparison parameters:")
    lines.append(f"- Relative tolerance (rtol): {report['rtol']}")
    lines.append(f"- Absolute tolerance (atol): {report['atol']}")
    lines.append("")
@ -509,11 +424,13 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
    lines.append("|----------|-------|----------|------------|---------|")
    lines.append(f"| Steps | {summary['total_steps']} | - | - | - |")
    lines.append(
        f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |"
    )
        f"| Modules | {summary['total_modules']} | "
        f"{summary['matching_modules']} | {summary['mismatched_modules']} | "
        f"{summary['missing_modules']} |")
    lines.append(
        f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |"
    )
        f"| Tensors | {summary['total_tensors']} | "
        f"{summary['matching_tensors']} | {summary['mismatched_tensors']} | "
        f"{summary['missing_tensors']} |")
    lines.append("")

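For orientation, the summary section these appends produce looks roughly like the following (the header row is elided by the hunk and assumed here; counts are illustrative):

```
# Intermediate Logging Comparison Report

| Category | Total | Matching | Mismatched | Missing |
|----------|-------|----------|------------|---------|
| Steps | 2 | - | - | - |
| Modules | 12 | 11 | 1 | 0 |
| Tensors | 48 | 46 | 2 | 0 |
```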
    # Add step details
@ -523,8 +440,9 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
        lines.append(f"## {step_name}")
        lines.append("")
        lines.append(
            f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules"
        )
            f"**Summary**: {step_summary['matching_modules']} matching "
            f"modules, {step_summary['mismatched_modules']} mismatched "
            f"modules, {step_summary['missing_modules']} missing modules")
        lines.append("")

        # Add module details
@ -540,15 +458,14 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
            module_summary = module_report["summary"]

            # Determine module status
            if module_summary["mismatched_tensors"] > 0:
                status = "❌"
            else:
                status = "✅"
            status = "❌" if module_summary["mismatched_tensors"] > 0 else "✅"

            lines.append(f"### {status} {module_name}")
            lines.append("")
            lines.append(
                f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors"
                f"**Summary**: {module_summary['matching_tensors']} matching "
                f"tensors, {module_summary['mismatched_tensors']} mismatched "
                f"tensors, {module_summary['missing_tensors']} missing tensors"
            )
            lines.append("")

@ -558,20 +475,21 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
                metadata_comparison = module_report[metadata_type]
                if not metadata_comparison.get("match", True):
                    file_paths = ""
                    if (
                        "file1" in metadata_comparison
                        and "file2" in metadata_comparison
                    ):
                        file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`"
                    if ("file1" in metadata_comparison
                            and "file2" in metadata_comparison):
                        file_paths = (
                            f" - Files: "
                            f"`{metadata_comparison['file1']}` "
                            f"vs `{metadata_comparison['file2']}`")

                    lines.append(
                        f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}"
                    )
                        f"**{metadata_type.capitalize()}**: Mismatch "
                        f"detected{file_paths}")
                    if verbose and "mismatches" in metadata_comparison:
                        lines.append("```json")
                        lines.append(
                            json.dumps(metadata_comparison["mismatches"], indent=2)
                        )
                            json.dumps(metadata_comparison["mismatches"],
                                       indent=2))
                        lines.append("```")
                        lines.append("")

@ -585,8 +503,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
                lines.append("|--------|--------|---------|")

                for tensor_name, comparison in sorted(
                    module_report["inputs"].items()
                ):
                        module_report["inputs"].items()):
                    if comparison.get("match", False):
                        status = "✅"
                        details = "Tensors match"
@ -595,13 +512,23 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
                        details = comparison["error"]
                    else:
                        status = "❌"
                        details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A'):.2e}, "
                        details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A'):.2e}, "
                        details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
                        details = (
                            f"Max abs diff: "
                            f"{comparison.get('max_abs_diff', 'N/A')}, ")
                        details += (
                            f"Max relative diff: "
                            f"{comparison.get('max_rel_diff', 'N/A')}, ")
                        details += (
                            f"Diff elements: "
                            f"{comparison.get('num_diff_elements', 'N/A')}/"
                            f"{comparison.get('total_elements', 'N/A')}")
                        if "file1" in comparison and "file2" in comparison:
                            details += f"<br>Files: `{comparison['file1']}` vs `{comparison['file2']}`"
                            details += (
                                f"<br>Files: `{comparison['file1']}` vs "
                                f"`{comparison['file2']}`")

                    lines.append(f"| {tensor_name} | {status} | {details} |")
                    lines.append(
                        f"| {tensor_name} | {status} | {details} |")

                lines.append("")

@ -613,8 +540,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
                lines.append("|--------|--------|---------|")

                for tensor_name, comparison in sorted(
                    module_report["outputs"].items()
                ):
                        module_report["outputs"].items()):
                    if comparison.get("match", False):
                        status = "✅"
                        details = "Tensors match"
@ -623,11 +549,19 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
                        details = comparison["error"]
                    else:
                        status = "❌"
                        details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
                        details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
                        details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
                        details = (
                            f"Max abs diff: "
                            f"{comparison.get('max_abs_diff', 'N/A')}, ")
                        details += (
                            f"Max relative diff: "
                            f"{comparison.get('max_rel_diff', 'N/A')}, ")
                        details += (
                            f"Diff elements: "
                            f"{comparison.get('num_diff_elements', 'N/A')}/"
                            f"{comparison.get('total_elements', 'N/A')}")

                    lines.append(f"| {tensor_name} | {status} | {details} |")
                    lines.append(
                        f"| {tensor_name} | {status} | {details} |")

                lines.append("")

@ -636,15 +570,16 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:

def main():
    parser = argparse.ArgumentParser(
        description="Compare intermediate logging outputs from two different runs."
    )
    parser.add_argument(
        "--dir1", required=True, help="First intermediate logging directory"
    )
    parser.add_argument(
        "--dir2", required=True, help="Second intermediate logging directory"
    )
    parser.add_argument("--output", help="Output file for the report (default: stdout)")
        description=
        "Compare intermediate logging outputs from two different runs.")
    parser.add_argument("--dir1",
                        required=True,
                        help="First intermediate logging directory")
    parser.add_argument("--dir2",
                        required=True,
                        help="Second intermediate logging directory")
    parser.add_argument("--output",
                        help="Output file for the report (default: stdout)")
    parser.add_argument(
        "--rtol",
        type=float,
@ -658,11 +593,12 @@ def main():
        help="Absolute tolerance for tensor comparison (default: 1e-8)",
    )
    parser.add_argument(
        "--steps", help="Comma-separated list of steps to compare (default: all)"
    )
        "--steps",
        help="Comma-separated list of steps to compare (default: all)")
    parser.add_argument(
        "--modules",
        help="Comma-separated list of module name patterns to compare (default: all)",
        help="Comma-separated list of module name patterns to compare "
        "(default: all)",
    )
    parser.add_argument(
        "--verbose",

@ -17,8 +17,7 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
from functools import cached_property
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
                    Protocol, TypeVar, Union, cast, get_args, List, Set)
from re import Pattern
                    Protocol, TypeVar, Union, cast, get_args)

import regex as re
import torch
@ -4026,63 +4025,65 @@ class KVEventsConfig:


@config
@dataclass
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig:
    """Configuration for intermediate tensor logging."""

    output_dir: str = "/tmp/vllm_intermediates"
    """Directory where to save the intermediate tensors."""

    reload_input_dir: Optional[str] = None
    """Directory from which to load the inputs for the steps/modules.
    This is used to check per-module numerical gaps instead of the
    accumulated gap, to dive further into the actual numerical issues."""

    module_call_match: Optional[List[str]] = None
    module_call_match: Optional[list[str]] = None
    """Match modules by name regex and call index (a module can be called
    multiple times in a step). List of regex:call_idx entries; call_idx
    defaults to -1, which matches all calls."""

    log_step_ids: List[int] = field(default_factory=lambda: [0])
    log_step_ids: list[int] = field(default_factory=lambda: [0])
    """List of step IDs to log (empty list means log all steps)."""

    log_post_fwd_inputs: bool = False
    """Whether to log each module's inputs again after its forward pass."""

    max_tensor_size: Optional[int] = None
    """Maximum number of elements in tensors to log (None = no limit)."""

    enabled: bool = True
    """Whether logging is enabled."""

    device_names: List[str] = field(default_factory=list)
    device_names: list[str] = field(default_factory=list)
    """List of device names to log (empty list means log all devices)."""

    _compiled_module_calls: dict[Pattern, int] = field(default_factory=dict, init=False)
    _compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
                                                          init=False)
    """Compiled regex patterns for module filtering."""

    _module_call: dict[str, int] = field(default_factory=dict, init=False)
    _step_id_set: Set[int] = field(default_factory=set, init=False)
    _step_id_set: set[int] = field(default_factory=set, init=False)
    """Set of step IDs for faster lookup."""

    _output_run_dir: str = "/tmp/vllm_intermediates"
    """Unique directory to save a single run/serve logging result."""

    def __post_init__(self):
        """Initialize derived fields after instance creation."""
        self._compile_regex_patterns()
        self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
        self._step_id_set = set(self.log_step_ids)

    def _compile_regex_patterns(self):
        """Compile regex patterns for module name filtering."""
        from vllm.logger import init_logger
        logger = init_logger(__name__)

        self._compiled_module_matches = []

        if self.module_call_match is None:
            logger.info("No module name regex patterns provided, will log all modules")
            logger.info(
                "No module name regex patterns provided, will log all modules")
            return

        # Compile all patterns
        for regex_pattern_call_idx in self.module_call_match:
            try:
@ -4091,15 +4092,16 @@ class IntermediateLoggingConfig:
                call_idx = -1
                if len(splits) > 1:
                    call_idx = int(splits[1])
                compiled_pattern: Pattern[str] = re.compile(regex_pattern)
                compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
                self._compiled_module_calls[compiled_pattern] = call_idx
                logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'")
                logger.info("Successfully compiled regex pattern: '%s'",
                            regex_pattern)
            except Exception as e:
                logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}")
                raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e
                logger.error("Failed to parse module_call_match '%s': %s",
                             regex_pattern_call_idx, e)

        logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns")
        logger.info("Compiled %d regex patterns",
                    len(self._compiled_module_calls))

    def to_dict(self) -> dict:
        """Convert the config to a dictionary for serialization."""
@ -4111,12 +4113,12 @@ class IntermediateLoggingConfig:
            "enabled": self.enabled,
            "device_names": self.device_names
        }

    @classmethod
    def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
        """Build the config from a parsed CLI/JSON dictionary."""
        return cls(**dict_value)

    @property
    def output_run_dir(self) -> str:
        return self._output_run_dir
@ -4138,7 +4140,6 @@ class IntermediateLoggingConfig:
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str

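A hedged sketch of constructing this config programmatically; the regex, step IDs, and device name are illustrative assumptions about a particular model and run:

```python
from vllm.config import IntermediateLoggingConfig

il_config = IntermediateLoggingConfig(
    output_dir="/tmp/vllm_intermediates",
    # Log only the first call of each module under layer 0; the
    # "regex:call_idx" string is split on ":" by _compile_regex_patterns.
    module_call_match=[r"model\.layers\.0\..*:0"],
    log_step_ids=[0, 1],      # log the first two engine steps
    device_names=["cuda:0"],  # only tensors already on cuda:0
)
# Equivalent to what EngineArgs builds from --intermediate-log-config-path:
#   IntermediateLoggingConfig.from_dict(json.load(open(path)))
```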


class CompilationLevel:
@ -27,13 +27,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         DeviceConfig, DistributedExecutorBackend,
                         GuidedDecodingBackend, GuidedDecodingBackendV1,
                         HfOverrides, IntermediateLoggingConfig,
                         KVEventsConfig, KVTransferConfig,
                         LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
                         ModelDType, ModelImpl, MultiModalConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
                         SchedulerPolicy, SpeculativeConfig, TaskOption,
                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
                         KVEventsConfig, KVTransferConfig, LoadConfig,
                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         RunnerOption, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -400,7 +400,7 @@ class EngineArgs:
                                 str] = ModelConfig.logits_processor_pattern

    speculative_config: Optional[Dict[str, Any]] = None

    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
@ -773,10 +773,13 @@ class EngineArgs:
            default=None,
            help="The configurations for intermediate loggings. Should be a "
            "JSON string.")
        intermediate_log_group.add_argument("--intermediate-log-config-path", type=str,
            help="The path to the configurations for intermediate loggings. Should be a string.")
        intermediate_log_group.add_argument(
            "--intermediate-log-config-path",
            type=str,
            help="The path to the configurations for intermediate loggings. "
            "Should be a string.")

        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
@ -865,9 +868,6 @@ class EngineArgs:
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])

        # Other arguments
        parser.add_argument('--disable-log-stats',
                            action='store_true',
@ -979,11 +979,9 @@ class EngineArgs:
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_intermediate_log_config(
        self,
    ) -> Optional[IntermediateLoggingConfig]:
            self, ) -> Optional[IntermediateLoggingConfig]:
        """Initializes and returns an IntermediateLoggingConfig object based on
        `intermediate_log_config` or `intermediate_log_config_path`.
        """
@ -991,7 +989,7 @@ class EngineArgs:
            return IntermediateLoggingConfig.from_dict(
                self.intermediate_log_config)
        if self.intermediate_log_config_path is not None:
            with open(self.intermediate_log_config_path, "r") as f:
            with open(self.intermediate_log_config_path) as f:
                return IntermediateLoggingConfig.from_dict(json.load(f))
        return None

@ -1235,8 +1233,7 @@ class EngineArgs:
            disable_log_stats=self.disable_log_stats,
        )

        intermediate_log_config = self.create_intermediate_log_config(
        )
        intermediate_log_config = self.create_intermediate_log_config()

        # Reminder: Please update docs/features/compatibility_matrix.md
        # If the feature combo becomes valid
@ -7,8 +7,6 @@ This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""

import json
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
@ -17,8 +15,6 @@ import torch
from torch.utils.hooks import RemovableHandle

from vllm.config import IntermediateLoggingConfig

# Import logger from vllm
from vllm.logger import init_logger

logger = init_logger(__name__)
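The capture mechanism this module builds on is the standard PyTorch hook API; a minimal self-contained sketch, independent of vLLM:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

def pre_hook(module, inputs):
    # Runs before forward(); returning a value substitutes the inputs.
    print(f"{module.__class__.__name__} input shape: {inputs[0].shape}")
    return inputs

def post_hook(module, inputs, outputs):
    # Runs after forward() with both the inputs and the produced outputs.
    print(f"{module.__class__.__name__} output shape: {outputs.shape}")

pre_handle = model.register_forward_pre_hook(pre_hook)
post_handle = model.register_forward_hook(post_hook)
model(torch.randn(3, 4))
pre_handle.remove()   # RemovableHandle, as imported above
post_handle.remove()
```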
@ -91,7 +87,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
        return False
    # If no patterns are defined, log all modules
    if not config._compiled_module_calls:
        logger.debug("No patterns defined, will log module: %s", module_name)
        set_il_module_name(module, module_name)
        set_il_module_call_idx(module, -1)
        return True
@ -115,7 +110,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:

def is_log_enabled(config):
    if not config or not config.enabled:
        logger.debug("Not logging because config not enabled")
        return False
    if torch.compiler.is_compiling():
        logger.debug("Not logging because torch.compile is in progress")
@ -161,151 +155,7 @@ def get_current_il_config():
    return _global_config


def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any:
    try:
        # Convert inputs to JSON-serializable format
        intermediates_json = convert_intermediates_to_json(intermediates)
        with open(path, "w") as f:
            json.dump(intermediates_json, f, indent=2)
        logger.debug("Saved all intermediates as JSON to %s", path)
    except Exception as e:
        logger.warning("Failed to save intermediates as JSON: %s", e)
        import traceback

        logger.warning(traceback.format_exc())


def convert_intermediates_to_json(tensor: Any) -> Any:
    """Convert intermediates (including tensors) to a JSON-serializable
    representation.

    Args:
        intermediates: The intermediates to convert.

    Returns:
        A JSON-serializable representation of the tensor.
    """
    if isinstance(tensor, torch.Tensor):
        try:
            result = {
                "type": "tensor",
                "shape": list(tensor.shape),
                "dtype": str(tensor.dtype),
                "numel": tensor.numel(),
            }

            return result
        except Exception as e:
            # Handle any errors in tensor conversion
            return {
                "type": "tensor_error",
                "error": str(e),
                "tensor_type": str(type(tensor)),
            }

    elif isinstance(tensor, (list, tuple)):
        # For lists/tuples, recursively convert each element
        container_type = "list" if isinstance(tensor, list) else "tuple"

        # If it's a large list, only include a sample
        if len(tensor) > 20:
            return {
                "type": container_type,
                "length": len(tensor),
                "sample": [
                    convert_intermediates_to_json(item) for item in tensor[:20]
                ],
                "note": f"Showing only first 20 of {len(tensor)} items",
            }
        else:
            return {
                "type": container_type,
                "items": [convert_intermediates_to_json(item) for item in tensor],
            }

    elif isinstance(tensor, dict):
        # For dictionaries, recursively convert each value
        if len(tensor) > 20:
            # For large dicts, only include keys and a sample of values
            keys = list(tensor.keys())
            sample_keys = keys[:20]
            return {
                "type": "dict",
                "length": len(tensor),
                "keys": keys,
                "sample": {
                    k: convert_intermediates_to_json(tensor[k]) for k in sample_keys
                },
                "note": f"Showing only first 20 of {len(tensor)} items",
            }
        else:
            return {
                "type": "dict",
                "items": {
                    k: convert_intermediates_to_json(v) for k, v in tensor.items()
                },
            }

    elif tensor is None:
        return None

    elif isinstance(tensor, (int, float, bool, str)):
        # Primitive types can be directly serialized
        return tensor

    else:
        # For other types, use string representation
        return {"type": str(type(tensor).__name__), "string_repr": str(tensor)}


def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool:
    """Utility function to dump tensor metadata to a file.

    Args:
        tensor: The tensor to dump.
        file_path: Base path where to save the tensor (without extension).
    """
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None:
        return False
    if (
        intermediate_log_config.max_tensor_size is not None
        and tensor.numel() > intermediate_log_config.max_tensor_size
    ):
        # Save tensor metadata instead of full tensor
        tensor_info = {
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "device": str(tensor.device),
            "numel": tensor.numel(),
            "skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size "
            f"{intermediate_log_config.max_tensor_size}",
        }
        os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True)
        with open(f"{file_path}.json", "w") as f:
            json.dump(tensor_info, f, indent=2)
        return True
    return False


def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any:
    if reload_dir is None:
        return None
    try:
        intermediate_log_config = get_current_il_config()
        assert intermediate_log_config is not None
        replace_dir = str(intermediate_log_config.output_run_dir)
        reload_path = save_path.replace(replace_dir, reload_dir)
        logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path)
        return torch.load(reload_path, map_location=tensor.device)
    except Exception as e:
        logger.warning("Failed to load tensor from %s: %s", reload_dir, e)
        return tensor


def save_tensors(
    tensor: Any, file_path: str, reload_input_dir: Optional[str] = None
) -> Any:
def save_tensors(tensor: Any, file_path: str) -> Any:
    """Utility function to dump tensor to a file.

    Args:
@ -314,52 +164,32 @@ def save_tensors(
        file_path: Base path where to save the tensor (without extension).
    """
    # Also save the actual tensor data for tensors
    if isinstance(tensor, torch.Tensor):
        # Check if tensor is too large
        if save_tensors_metadata_if_too_large(tensor, file_path):
            return
        # Get device name
        device_name = str(tensor.device)
        # Skip if device filtering is enabled and this device should not be
        # logged
        intermediate_log_config = get_current_il_config()
        if not should_log_device(intermediate_log_config, device_name):
            logger.debug(
                "Skipping tensor on device %s due to device filter", device_name
            )
            return tensor
        # Append device name to file path
        pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
        try:
            # Save tensor directly without detaching or moving to CPU
            torch.save(tensor, pt_path)
            reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir)
            if reloaded_tensor is not None:
                return reloaded_tensor
            logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
            logger.debug("Saved tensor of shape %s to %s", tensor.shape,
                         pt_path)
        except Exception as e:
            logger.warning("Failed to save tensor to %s: %s", pt_path, e)
        return tensor

    if isinstance(tensor, (list, tuple)):
        # For collections, also save each item individually
        reloaded_inputs = []
        for i, item in enumerate(tensor):
            reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir)
            reloaded_inputs.append(reloaded)
        return tuple(reloaded_inputs) if reloaded_inputs else tensor
            save_tensors(item, f"{file_path}_{i}")
        return tensor
    if isinstance(tensor, dict):
        reloaded_inputs = {}
        # For dictionaries, also save each value individually
        for k, v in tensor.items():
            reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir)
            reloaded_inputs[k] = reloaded
        return reloaded_inputs if reloaded_inputs else tensor
            save_tensors(v, f"{file_path}_{k}")
        return tensor


def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None:
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
             outputs: Any) -> None:
    """Hook to increment the global step counter after a forward pass.

    Args:
@ -381,7 +211,8 @@ def _prepare_module_log_dir(
    is_pre_fwd: bool = False,
) -> Path:
    # Create a unique directory for this step if needed
    dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}"
    dump_dir = Path(
        intermediate_log_config.output_run_dir) / f"step_{get_step()}"
    dump_dir.mkdir(exist_ok=True, parents=True)

    # Create module directory
@ -393,7 +224,8 @@ def _prepare_module_log_dir(
    if is_pre_fwd:
        _log_module_call(intermediate_log_config, module_name + suffix)
    module_dir.mkdir(exist_ok=True, parents=True)
    logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir)
    logger.debug("Logging module %s inputs/outputs to %s", module_name,
                 module_dir)
    return module_dir

@ -401,13 +233,8 @@ def _log_module_call(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
) -> None:
    logger.debug("Logging module call for %s", module_name)
    # Append the module name to this step's call log:
    file = (
        Path(intermediate_log_config.output_run_dir)
        / f"step_{get_step()}"
        / "module_calls.txt"
    )
    file = (Path(intermediate_log_config.output_run_dir) /
            f"step_{get_step()}" / "module_calls.txt")
    with open(file, "a") as f:
        f.write(f"{module_name}\n")

@ -425,7 +252,8 @@ def get_current_step_module_call(module_name: str) -> int:
    return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)


def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
def prepare_log_current_fwd(module,
                            is_pre_fwd: bool = False) -> Optional[Path]:
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None or not intermediate_log_config.enabled:
        return None
@ -443,15 +271,14 @@ def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
    if is_pre_fwd:
        update_current_step_module_call(module_name)
    if should_log:
        log_dir = _prepare_module_log_dir(
            intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd
        )
        log_dir = _prepare_module_log_dir(intermediate_log_config,
                                          module_name,
                                          is_pre_fwd=is_pre_fwd)
    return log_dir


def log_pre_fwd_hook(
    module: torch.nn.Module, inputs: tuple[Any, ...]
) -> tuple[Any, ...]:
def log_pre_fwd_hook(module: torch.nn.Module,
                     inputs: tuple[Any, ...]) -> tuple[Any, ...]:
    """Hook to capture module inputs before forward pass.

    Args:
@ -462,27 +289,12 @@ def log_pre_fwd_hook(
        The unchanged inputs.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
        dump_intermediates_to_json(inputs, log_dir / "inputs.json")
        intermediate_log_config = get_current_il_config()
        if intermediate_log_config is not None:
            reload_input_dir = getattr(
                intermediate_log_config,
                "reload_input_dir",
                "/tmp/vllm_intermediates/57f4a3b2-9c4c-4afe-be71-0e95369d74b5",
            )
        else:
            reload_input_dir = None
        reloaded_inputs = save_tensors(
            inputs, str(log_dir / "inputs"), reload_input_dir
        )
        if reloaded_inputs is not None:
            return reloaded_inputs
        save_tensors(inputs, str(log_dir / "inputs"))
    return inputs


def log_post_fwd_hook(
    module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any
) -> None:
def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
                      outputs: Any) -> None:
    """Hook to capture module outputs after forward pass.

    Args:
@ -491,12 +303,11 @@ def log_post_fwd_hook(
        outputs: The outputs from the module's forward function.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
        dump_intermediates_to_json(outputs, log_dir / "outputs.json")
        save_tensors(outputs, str(log_dir / "outputs"))
        intermediate_log_config = get_current_il_config()
        assert intermediate_log_config is not None, "IL config should not be None"
        assert intermediate_log_config is not None, \
            "IL config should not be None"
        if intermediate_log_config.log_post_fwd_inputs:
            dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json")
            save_tensors(inputs, str(log_dir / "post_fwd_inputs"))


@ -532,14 +343,14 @@ class IntermediatesLogger:

    def __init__(self, config: IntermediateLoggingConfig):
        self.config = config
        self.hooks: list[
            tuple[str, str, Optional[RemovableHandle], Optional[RemovableHandle]]
        ] = []
        self.hooks: list[tuple[str, str, Optional[RemovableHandle],
                               Optional[RemovableHandle]]] = []
        logger.debug("Created IntermediatesLogger with config: %s", config)
        path = Path(config.output_run_dir)
        path.mkdir(exist_ok=True, parents=True)
        # Log configuration
        logger.info("Intermediates will be logged in %s", config.output_run_dir)
        logger.info("Intermediates will be logged in %s",
                    config.output_run_dir)

    def register_hooks(self, model: torch.nn.Module) -> None:
        """Register hooks for the model.
@ -551,13 +362,11 @@ class IntermediatesLogger:
        for name, module in model.named_modules():
            if name and should_log_module(self.config, name, module):
                pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
                logger.debug(
                    "Registered pre_fwd hook for %s", module.__class__.__name__
                )
                logger.debug("Registered pre_fwd hook for %s",
                             module.__class__.__name__)
                post_hook = module.register_forward_hook(log_post_fwd_hook)
                logger.debug(
                    "Registered post_fwd hook for %s", module.__class__.__name__
                )
                logger.debug("Registered post_fwd hook for %s",
                             module.__class__.__name__)
                self.hooks.append((name, module, pre_hook, post_hook))

        # Register a step counter hook for the root model
@ -578,7 +387,8 @@ class IntermediatesLogger:


def register_intermediate_hooks(
    model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs
    model: torch.nn.Module,
    config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
    """Register hooks to log intermediate tensors for a model.

@ -590,10 +400,6 @@ def register_intermediate_hooks(
    Returns:
        An IntermediatesLogger instance that can be used to manage the hooks.
    """
    if config is None:
        # Create config from kwargs
        config = IntermediateLoggingConfig.from_dict(kwargs)

    logger_instance = IntermediatesLogger(config)
    logger_instance.register_hooks(model)
    return logger_instance

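A usage sketch of the public entry point above, assuming `model` is any `torch.nn.Module` and `il_config` is an `IntermediateLoggingConfig` like the one built earlier:

```python
logger_instance = register_intermediate_hooks(model, il_config)

# Run forward passes; logged steps land under il_config.output_run_dir.
model(example_batch)  # example_batch is a placeholder input

# Detach all pre/post-forward hooks when done.
logger_instance.remove_hooks()
```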
@ -27,12 +27,12 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.intermediates.intermediates_logging import intermediate_logging

logger = init_logger(__name__)

@ -6,10 +6,11 @@ from typing import Optional
import torch
import torch.nn as nn

from vllm.config import VllmConfig, IntermediateLoggingConfig
from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.intermediates.intermediates_logging import (
    register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

logger = init_logger(__name__)
@ -64,27 +65,26 @@ class WorkerBase(WorkerBaseV0):
    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return

    def register_intermediate_hooks(self,
                                    config: Optional[IntermediateLoggingConfig] = None,
                                    **kwargs) -> None:
    def register_intermediate_hooks(
            self, config: Optional[IntermediateLoggingConfig] = None) -> None:
        """Register hooks for intermediate tensor logging.

        This method is called via collective_rpc from the engine core.
        It registers hooks on the model to dump intermediate tensors during execution.
        It registers hooks on the model to dump intermediate tensors during
        execution.

        Args:
            config: Configuration for intermediate logging. If provided, this takes precedence over kwargs.
            config: Configuration for intermediate logging. If provided, this
                takes precedence over kwargs.
        """
        if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None:
            logger.error("Could not register intermediate hooks: model_runner.model is not accessible")
        if self.model_runner is None or not hasattr(
                self.model_runner, "model") or self.model_runner.model is None:
            logger.error("Could not register intermediate hooks: "
                         "model_runner.model is not accessible")
            return
        model = self.model_runner.model
        try:
            # Register hooks
            register_intermediate_hooks(model, config, **kwargs)
            # Store the logger instance for potential later hook removal
            logger.info("Successfully registered intermediate hooks")
        except Exception as e:
            logger.error("Error registering intermediate hooks", exc_info=True)
            register_intermediate_hooks(model, config)
        except Exception:
            logger.exception("Error registering intermediate hooks")

@ -128,21 +128,21 @@ class WorkerBase:
    def vocab_size(self) -> int:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()

    def register_intermediate_hooks(self, config=None, **kwargs) -> None:
    def register_intermediate_hooks(self, config=None) -> None:
        """Register hooks for intermediate tensor logging.

        This method is a stub for v0 workers. The actual implementation is in v1 workers.
        It's included here for compatibility with the collective_rpc mechanism.
        This method is a stub for v0 workers. The actual implementation is
        in v1 workers. It's included here for compatibility with the
        collective_rpc mechanism.

        Args:
            config: Configuration for intermediate logging.
            **kwargs: Configuration parameters for intermediate logging.
                These are ignored in v0 workers.
        """
        logger.warning(
            "register_intermediate_hooks is not implemented in v0 workers. "
            "This is only available in v1 workers. No hooks will be registered.")
            "This is only available in v1 workers. No hooks will be registered."
        )
        return None