diff --git a/docs/contributing/intermediate_logging.md b/docs/contributing/intermediate_logging.md
index 4b1dc2aca8..fba4f439b6 100644
--- a/docs/contributing/intermediate_logging.md
+++ b/docs/contributing/intermediate_logging.md
@@ -49,7 +49,6 @@ The configuration file should be a JSON file with the following structure:
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json
```
-
#### Configuration Parameters
| Parameter | Type | Description | Default |
diff --git a/tests/v1/test_intermediates_logging.py b/tests/v1/test_intermediates_logging.py
index a25a70e910..9d2d0110e9 100644
--- a/tests/v1/test_intermediates_logging.py
+++ b/tests/v1/test_intermediates_logging.py
@@ -5,9 +5,8 @@ Tests for the intermediate tensor logging functionality.
"""
import json
-from os.path import isdir
-import shutil
import os
+import shutil
import tempfile
from pathlib import Path
from unittest import mock
@@ -17,14 +16,10 @@ import torch
import torch.nn as nn
from vllm.config import IntermediateLoggingConfig
-from vllm.v1.intermediates.intermediates_logging import (get_current_il_config,
- get_step, increment_step,
- intermediate_logging,
- register_intermediate_hooks,
- reset_step,
- should_log_device,
- should_log_module,
- should_log_step)
+from vllm.v1.intermediates.intermediates_logging import (
+ get_current_il_config, get_step, increment_step, intermediate_logging,
+ register_intermediate_hooks, reset_step, should_log_device,
+ should_log_module, should_log_step)
class SimpleModel(nn.Module):
@@ -237,7 +232,8 @@ def test_register_hooks(simple_model, il_config):
assert len(logger_instance.hooks) == 0
-@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
+@mock.patch(
+ 'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir):
@@ -262,7 +258,6 @@ def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
# Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called
assert mock_save_tensors.called
-
# Remove hooks
logger_instance.remove_hooks()
diff --git a/tools/compare_intermediate.py b/tools/compare_intermediate.py
index 984c604503..833ac38aad 100755
--- a/tools/compare_intermediate.py
+++ b/tools/compare_intermediate.py
@@ -7,27 +7,30 @@ This script compares the tensor outputs from two different intermediate logging
directories and generates a report of the differences.
Usage:
- python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]
+ python compare_intermediate.py --dir1 /path/to/first/log/dir \
+ --dir2 /path/to/second/log/dir [options]
Options:
--dir1 DIR First intermediate logging directory
--dir2 DIR Second intermediate logging directory
--output FILE Output file for the report (default: stdout)
- --format {md,json} Output format (default: md)
- --rtol FLOAT Relative tolerance for tensor comparison (default: 1e-5)
- --atol FLOAT Absolute tolerance for tensor comparison (default: 1e-8)
+ --rtol FLOAT Relative tolerance for tensor comparison
+ (default: 1e-5)
+ --atol FLOAT Absolute tolerance for tensor comparison
+ (default: 1e-8)
--steps STEPS Comma-separated list of steps to compare (default: all)
- --modules MODULES Comma-separated list of module name patterns to compare (default: all)
+ --modules MODULES Comma-separated list of module name patterns to compare
+ (default: all)
--verbose Include detailed information about each tensor
"""
import argparse
import json
-import re
from collections import defaultdict
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Optional
+import regex as re
import torch
@@ -40,34 +43,19 @@ def load_tensor(path: Path) -> torch.Tensor:
return None
-def load_json(path: Path) -> Dict:
- """Load a JSON file."""
- try:
- with open(path, "r") as f:
- return json.load(f)
- except Exception as e:
- print(f"Error loading JSON from {path}: {e}")
- return {}
-
-
-def extract_diff_metatada(exception_str: str) -> Dict:
+def extract_diff_metatada(exception_str: str) -> dict:
try:
num_diff_elements = int(
- re.search(r"Mismatched elements: (\d+) /", exception_str).group(1)
- )
+ re.search(r"Mismatched elements: (\d+) /", exception_str).group(1))
total_elements = int(
- re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1)
- )
+ re.search(r"Mismatched elements: \d+ / (\d+)",
+ exception_str).group(1))
max_abs_diff = float(
- re.search(
- r"Greatest absolute difference: ([\d\.e-]+)", exception_str
- ).group(1)
- )
+ re.search(r"Greatest absolute difference: ([\d\.e-]+)",
+ exception_str).group(1))
max_rel_diff = float(
- re.search(
- r"Greatest relative difference: ([\d\.e-]+)", exception_str
- ).group(1)
- )
+ re.search(r"Greatest relative difference: ([\d\.e-]+)",
+ exception_str).group(1))
return {
"num_diff_elements": num_diff_elements,
"total_elements": total_elements,
@@ -78,9 +66,8 @@ def extract_diff_metatada(exception_str: str) -> Dict:
return {"error": exception_str}
-def compare_tensors(
- tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float
-) -> Dict:
+def compare_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float,
+ atol: float) -> dict:
"""Compare two tensors and return a dictionary with comparison results."""
if tensor1 is None or tensor2 is None:
return {"match": False, "error": "One or both tensors are None"}
@@ -105,60 +92,8 @@ def compare_tensors(
return {"match": True}
-def compare_json_values(value1: Any, value2: Any) -> Dict:
- """Compare two JSON values and return a dictionary with comparison results."""
- if type(value1) is not type(value2):
- return {
- "match": False,
- "error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}",
- }
-
- if isinstance(value1, dict):
- # Compare dictionaries
- all_keys = set(value1.keys()) | set(value2.keys())
- mismatches = {}
-
- for key in all_keys:
- if key not in value1:
- mismatches[key] = {"error": "Missing in first dict"}
- elif key not in value2:
- mismatches[key] = {"error": "Missing in second dict"}
- else:
- comparison = compare_json_values(value1[key], value2[key])
- if not comparison["match"]:
- mismatches[key] = comparison
-
- if mismatches:
- return {"match": False, "mismatches": mismatches}
- return {"match": True}
-
- elif isinstance(value1, list):
- # Compare lists
- if len(value1) != len(value2):
- return {
- "match": False,
- "error": f"Length mismatch: {len(value1)} vs {len(value2)}",
- }
-
- mismatches = {}
- for i, (item1, item2) in enumerate(zip(value1, value2)):
- comparison = compare_json_values(item1, item2)
- if not comparison["match"]:
- mismatches[i] = comparison
-
- if mismatches:
- return {"match": False, "mismatches": mismatches}
- return {"match": True}
-
- else:
- # Compare primitive values
- if value1 == value2:
- return {"match": True}
- else:
- return {"match": False, "value1": value1, "value2": value2}
-
-
-def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
+def find_tensor_files(
+ directory: Path) -> dict[str, dict[str, dict[str, list[Path]]]]:
"""
Find all tensor files in the given directory.
@@ -198,23 +133,14 @@ def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Pat
if output_tensors:
result[step_name][module_name]["outputs"] = output_tensors
- # Find JSON metadata files
- inputs_json = module_dir / "inputs.json"
- if inputs_json.exists():
- result[step_name][module_name]["inputs_json"] = [inputs_json]
-
- outputs_json = module_dir / "outputs.json"
- if outputs_json.exists():
- result[step_name][module_name]["outputs_json"] = [outputs_json]
-
return result
def filter_steps_and_modules(
- tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]],
- steps: Optional[List[str]] = None,
- module_patterns: Optional[List[str]] = None,
-) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
+ tensor_files: dict[str, dict[str, dict[str, list[Path]]]],
+ steps: Optional[list[str]] = None,
+ module_patterns: Optional[list[str]] = None,
+) -> dict[str, dict[str, dict[str, list[Path]]]]:
"""Filter tensor files by steps and module patterns."""
result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
@@ -223,11 +149,13 @@ def filter_steps_and_modules(
step_names = [f"step_{step}" for step in steps]
steps_to_include = {step: True for step in step_names}
else:
- steps_to_include = {step: True for step in tensor_files.keys()}
+ steps_to_include = {step: True for step in tensor_files}
# Compile module patterns
if module_patterns:
- compiled_patterns = [re.compile(pattern) for pattern in module_patterns]
+ compiled_patterns = [
+ re.compile(pattern) for pattern in module_patterns
+ ]
else:
compiled_patterns = None
@@ -237,11 +165,10 @@ def filter_steps_and_modules(
for module_name, file_types in modules.items():
# Check if module matches any pattern
- if compiled_patterns:
- if not any(
- pattern.search(module_name) for pattern in compiled_patterns
- ):
- continue
+ if compiled_patterns and not any(
+ pattern.search(module_name)
+ for pattern in compiled_patterns):
+ continue
result[step_name][module_name] = file_types
@@ -253,9 +180,9 @@ def compare_directories(
dir2: Path,
rtol: Optional[float] = None,
atol: Optional[float] = None,
- steps: Optional[List[str]] = None,
- module_patterns: Optional[List[str]] = None,
-) -> Dict:
+ steps: Optional[list[str]] = None,
+ module_patterns: Optional[list[str]] = None,
+) -> dict:
"""Compare two intermediate logging directories and return a report."""
# Find tensor files in both directories
tensor_files1 = find_tensor_files(dir1)
@@ -263,8 +190,10 @@ def compare_directories(
# Filter by steps and modules
if steps or module_patterns:
- tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns)
- tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns)
+ tensor_files1 = filter_steps_and_modules(tensor_files1, steps,
+ module_patterns)
+ tensor_files2 = filter_steps_and_modules(tensor_files2, steps,
+ module_patterns)
# Get all steps and modules from both directories
all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys())
@@ -296,12 +225,12 @@ def compare_directories(
# TODO: check if module calls txt exsits
dir1_module_call_file = dir1 / step / "module_calls.txt"
if dir1_module_call_file.exists():
- with open(dir1 / step / "module_calls.txt", "r") as f:
+ with open(dir1 / step / "module_calls.txt") as f:
all_modules = f.read().splitlines()
else:
print(
- "Warnings: the module call orders are missed, ordering using module alphbetics"
- )
+ "Warnings: the module call orders are missed, ordering using "
+ "module alphbetics")
all_modules = sorted(set(modules1.keys()) | set(modules2.keys()))
step_report["module_call_list"] = []
for module in all_modules:
@@ -329,32 +258,17 @@ def compare_directories(
step_report["modules"][module] = module_report
continue
- # Compare JSON metadata
- for json_type in ["inputs_json", "outputs_json"]:
- json_files1 = modules1[module].get(json_type, [])
- json_files2 = modules2[module].get(json_type, [])
-
- if json_files1 and json_files2:
- json1 = load_json(json_files1[0])
- json2 = load_json(json_files2[0])
-
- json_comparison = compare_json_values(json1, json2)
- json_name = json_type.replace("_json", "")
- module_report[f"{json_name}_metadata"] = json_comparison
-
- # Add file paths for manual checking when there's a mismatch
- if not json_comparison.get("match", True):
- module_report[f"{json_name}_metadata"]["file1"] = str(
- json_files1[0]
- )
- module_report[f"{json_name}_metadata"]["file2"] = str(
- json_files2[0]
- )
-
# Compare input tensors
- input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])}
- input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])}
- all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys())
+ input_tensors1 = {
+ p.name: p
+ for p in modules1[module].get("inputs", [])
+ }
+ input_tensors2 = {
+ p.name: p
+ for p in modules2[module].get("inputs", [])
+ }
+ all_input_names = set(input_tensors1.keys()) | set(
+ input_tensors2.keys())
for tensor_name in sorted(all_input_names):
if tensor_name not in input_tensors1:
@@ -389,9 +303,16 @@ def compare_directories(
module_report["summary"]["total_tensors"] += 1
# Compare output tensors
- output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])}
- output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])}
- all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys())
+ output_tensors1 = {
+ p.name: p
+ for p in modules1[module].get("outputs", [])
+ }
+ output_tensors2 = {
+ p.name: p
+ for p in modules2[module].get("outputs", [])
+ }
+ all_output_names = set(output_tensors1.keys()) | set(
+ output_tensors2.keys())
for tensor_name in sorted(all_output_names):
if tensor_name not in output_tensors1:
@@ -439,64 +360,58 @@ def compare_directories(
# Add overall summary
report["summary"] = {
- "total_steps": len(all_steps),
- "total_modules": sum(
- step_report["summary"]["total_modules"]
- for step_report in report["steps"].values()
- ),
- "matching_modules": sum(
- step_report["summary"]["matching_modules"]
- for step_report in report["steps"].values()
- ),
- "mismatched_modules": sum(
- step_report["summary"]["mismatched_modules"]
- for step_report in report["steps"].values()
- ),
- "missing_modules": sum(
- step_report["summary"]["missing_modules"]
- for step_report in report["steps"].values()
- ),
- "total_tensors": sum(
- module_report["summary"]["total_tensors"]
+ "total_steps":
+ len(all_steps),
+ "total_modules":
+ sum(step_report["summary"]["total_modules"]
+ for step_report in report["steps"].values()),
+ "matching_modules":
+ sum(step_report["summary"]["matching_modules"]
+ for step_report in report["steps"].values()),
+ "mismatched_modules":
+ sum(step_report["summary"]["mismatched_modules"]
+ for step_report in report["steps"].values()),
+ "missing_modules":
+ sum(step_report["summary"]["missing_modules"]
+ for step_report in report["steps"].values()),
+ "total_tensors":
+ sum(module_report["summary"]["total_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
- if "summary" in module_report
- ),
- "matching_tensors": sum(
- module_report["summary"]["matching_tensors"]
+ if "summary" in module_report),
+ "matching_tensors":
+ sum(module_report["summary"]["matching_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
- if "summary" in module_report
- ),
- "mismatched_tensors": sum(
- module_report["summary"]["mismatched_tensors"]
+ if "summary" in module_report),
+ "mismatched_tensors":
+ sum(module_report["summary"]["mismatched_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
- if "summary" in module_report
- ),
- "missing_tensors": sum(
- module_report["summary"]["missing_tensors"]
+ if "summary" in module_report),
+ "missing_tensors":
+ sum(module_report["summary"]["missing_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
- if "summary" in module_report
- ),
+ if "summary" in module_report),
}
return report
-def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
+def generate_markdown_report(report: dict, verbose: bool = False) -> str:
"""Generate a markdown report from the comparison results."""
lines = []
# Add header
lines.append("# Intermediate Logging Comparison Report")
lines.append("")
- lines.append("Comparing intermediate logging outputs between:")
+ lines.append("Comparing intermediate logging outputs "
+ "between:")
lines.append(f"- **Directory 1**: `{report['dir1']}`")
lines.append(f"- **Directory 2**: `{report['dir2']}`")
lines.append("")
- lines.append(f"Comparison parameters:")
+ lines.append("Comparison parameters:")
lines.append(f"- Relative tolerance (rtol): {report['rtol']}")
lines.append(f"- Absolute tolerance (atol): {report['atol']}")
lines.append("")
@@ -509,11 +424,13 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|----------|-------|----------|------------|---------|")
lines.append(f"| Steps | {summary['total_steps']} | - | - | - |")
lines.append(
- f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |"
- )
+ f"| Modules | {summary['total_modules']} | "
+ f"{summary['matching_modules']} | {summary['mismatched_modules']} | "
+ f"{summary['missing_modules']} |")
lines.append(
- f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |"
- )
+ f"| Tensors | {summary['total_tensors']} | "
+ f"{summary['matching_tensors']} | {summary['mismatched_tensors']} | "
+ f"{summary['missing_tensors']} |")
lines.append("")
# Add step details
@@ -523,8 +440,9 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append(f"## {step_name}")
lines.append("")
lines.append(
- f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules"
- )
+ f"**Summary**: {step_summary['matching_modules']} matching "
+ f"modules, {step_summary['mismatched_modules']} mismatched "
+ f"modules, {step_summary['missing_modules']} missing modules")
lines.append("")
# Add module details
@@ -540,15 +458,14 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
module_summary = module_report["summary"]
# Determine module status
- if module_summary["mismatched_tensors"] > 0:
- status = "❌"
- else:
- status = "✅"
+ status = "❌" if module_summary["mismatched_tensors"] > 0 else "✅"
lines.append(f"### {status} {module_name}")
lines.append("")
lines.append(
- f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors"
+ f"**Summary**: {module_summary['matching_tensors']} matching "
+ f"tensors, {module_summary['mismatched_tensors']} mismatched "
+ f"tensors, {module_summary['missing_tensors']} missing tensors"
)
lines.append("")
@@ -558,20 +475,21 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
metadata_comparison = module_report[metadata_type]
if not metadata_comparison.get("match", True):
file_paths = ""
- if (
- "file1" in metadata_comparison
- and "file2" in metadata_comparison
- ):
- file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`"
+ if ("file1" in metadata_comparison
+ and "file2" in metadata_comparison):
+ file_paths = (
+ f" - Files: "
+ f"`{metadata_comparison['file1']}` "
+ f"vs `{metadata_comparison['file2']}`")
lines.append(
- f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}"
- )
+ f"**{metadata_type.capitalize()}**: Mismatch "
+ f"detected{file_paths}")
if verbose and "mismatches" in metadata_comparison:
lines.append("```json")
lines.append(
- json.dumps(metadata_comparison["mismatches"], indent=2)
- )
+ json.dumps(metadata_comparison["mismatches"],
+ indent=2))
lines.append("```")
lines.append("")
@@ -585,8 +503,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted(
- module_report["inputs"].items()
- ):
+ module_report["inputs"].items()):
if comparison.get("match", False):
status = "✅"
details = "Tensors match"
@@ -595,13 +512,23 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
details = comparison["error"]
else:
status = "❌"
- details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A'):.2e}, "
- details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A'):.2e}, "
- details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
+ details = (
+ f"Max abs diff: "
+ f"{comparison.get('max_abs_diff', 'N/A')}, ")
+ details += (
+ f"Max relative diff: "
+ f"{comparison.get('max_rel_diff', 'N/A')}, ")
+ details += (
+ f"Diff elements: "
+ f"{comparison.get('num_diff_elements', 'N/A')}/"
+ f"{comparison.get('total_elements', 'N/A')}")
if "file1" in comparison and "file2" in comparison:
- details += f"
Files: `{comparison['file1']}` vs `{comparison['file2']}`"
+ details += (
+ f"
Files: `{comparison['file1']}` vs "
+ f"`{comparison['file2']}`")
- lines.append(f"| {tensor_name} | {status} | {details} |")
+ lines.append(
+ f"| {tensor_name} | {status} | {details} |")
lines.append("")
@@ -613,8 +540,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted(
- module_report["outputs"].items()
- ):
+ module_report["outputs"].items()):
if comparison.get("match", False):
status = "✅"
details = "Tensors match"
@@ -623,11 +549,19 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
details = comparison["error"]
else:
status = "❌"
- details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
- details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
- details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
+ details = (
+ f"Max abs diff: "
+ f"{comparison.get('max_abs_diff', 'N/A')}, ")
+ details += (
+ f"Max relative diff: "
+ f"{comparison.get('max_rel_diff', 'N/A')}, ")
+ details += (
+ f"Diff elements: "
+ f"{comparison.get('num_diff_elements', 'N/A')}/"
+ f"{comparison.get('total_elements', 'N/A')}")
- lines.append(f"| {tensor_name} | {status} | {details} |")
+ lines.append(
+ f"| {tensor_name} | {status} | {details} |")
lines.append("")
@@ -636,15 +570,16 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
def main():
parser = argparse.ArgumentParser(
- description="Compare intermediate logging outputs from two different runs."
- )
- parser.add_argument(
- "--dir1", required=True, help="First intermediate logging directory"
- )
- parser.add_argument(
- "--dir2", required=True, help="Second intermediate logging directory"
- )
- parser.add_argument("--output", help="Output file for the report (default: stdout)")
+ description=
+ "Compare intermediate logging outputs from two different runs.")
+ parser.add_argument("--dir1",
+ required=True,
+ help="First intermediate logging directory")
+ parser.add_argument("--dir2",
+ required=True,
+ help="Second intermediate logging directory")
+ parser.add_argument("--output",
+ help="Output file for the report (default: stdout)")
parser.add_argument(
"--rtol",
type=float,
@@ -658,11 +593,12 @@ def main():
help="Absolute tolerance for tensor comparison (default: 1e-8)",
)
parser.add_argument(
- "--steps", help="Comma-separated list of steps to compare (default: all)"
- )
+ "--steps",
+ help="Comma-separated list of steps to compare (default: all)")
parser.add_argument(
"--modules",
- help="Comma-separated list of module name patterns to compare (default: all)",
+ help="Comma-separated list of module name patterns to compare "
+ "(default: all)",
)
parser.add_argument(
"--verbose",
diff --git a/vllm/config.py b/vllm/config.py
index ab5676686d..1c6d4a87d0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -17,8 +17,7 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
from functools import cached_property
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
- Protocol, TypeVar, Union, cast, get_args, List, Set)
-from re import Pattern
+ Protocol, TypeVar, Union, cast, get_args)
import regex as re
import torch
@@ -4026,63 +4025,65 @@ class KVEventsConfig:
@config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig:
"""Configuration for intermediate tensor logging."""
-
+
output_dir: str = "/tmp/vllm_intermediates"
"""Directory where to save the intermediate tensors."""
-
+
reload_input_dir: Optional[str] = None
"""Directory where to load the inputs for the steps/modules.
This is used when we want to check per module numerical gaps instead
of accumulated gap to further dive into the actual numerical issues."""
- module_call_match: Optional[List[str]] = None
+ module_call_match: Optional[list[str]] = None
"""Match modules by name regex and call index (
a module can be called multiple times in a step)
List of regex:call_idx, call_idx is -1 for default for all calls """
-
- log_step_ids: List[int] = field(default_factory=lambda: [0])
+
+ log_step_ids: list[int] = field(default_factory=lambda: [0])
"""List of step IDs to log (empty list means log all steps)."""
-
+
log_post_fwd_inputs: bool = False
"""Whether logging inputs after forwards for each module"""
max_tensor_size: Optional[int] = None
"""Maximum number of elements in tensors to log (None = no limit)."""
-
+
enabled: bool = True
"""Whether logging is enabled."""
- device_names: List[str] = field(default_factory=list)
+ device_names: list[str] = field(default_factory=list)
"""List of device names to log (empty list means log all devices)."""
-
- _compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False)
+
+ _compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
+ init=False)
"""Compiled regex patterns for module filtering."""
-
+
_module_call: dict[str, int] = field(default_factory=dict, init=False)
- _step_id_set: Set[int] = field(default_factory=set, init=False)
+ _step_id_set: set[int] = field(default_factory=set, init=False)
"""Set of step IDs for faster lookup."""
_output_run_dir: str = "/tmp/vllm_intermediates"
"""Unique directory to save single run/serve logging result."""
-
+
def __post_init__(self):
"""Initialize derived fields after instance creation."""
self._compile_regex_patterns()
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
self._step_id_set = set(self.log_step_ids)
-
+
def _compile_regex_patterns(self):
"""Compile regex patterns for module name filtering."""
from vllm.logger import init_logger
logger = init_logger(__name__)
-
+
self._compiled_module_matches = []
-
+
if self.module_call_match is None:
- logger.info("No module name regex patterns provided, will log all modules")
+ logger.info(
+ "No module name regex patterns provided, will log all modules")
return
-
+
# Compile all patterns
for regex_pattern_call_idx in self.module_call_match:
try:
@@ -4091,15 +4092,16 @@ class IntermediateLoggingConfig:
call_idx = -1
if len(splits) > 1:
call_idx = int(splits[1])
- compiled_pattern: Pattern[str] = re.compile(regex_pattern)
+ compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
self._compiled_module_calls[compiled_pattern] = call_idx
- logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'")
+ logger.info("Successfully compiled regex pattern: '%s'",
+ regex_pattern)
except Exception as e:
- logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}")
- raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e
+ logger.error("Failed to parse module_call_match '%s': %s",
+ regex_pattern_call_idx, e)
-
- logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns")
+ logger.info("Compiled %d regex patterns",
+ len(self._compiled_module_calls))
def to_dict(self) -> dict:
"""Convert the config to a dictionary for serialization."""
@@ -4111,12 +4113,12 @@ class IntermediateLoggingConfig:
"enabled": self.enabled,
"device_names": self.device_names
}
-
+
@classmethod
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
-
+
@property
def output_run_dir(self) -> str:
return self._output_run_dir
@@ -4138,7 +4140,6 @@ class IntermediateLoggingConfig:
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
-
class CompilationLevel:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 49e331e183..547f1bb810 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -27,13 +27,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, IntermediateLoggingConfig,
- KVEventsConfig, KVTransferConfig,
- LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
- ModelDType, ModelImpl, MultiModalConfig,
- ObservabilityConfig, ParallelConfig, PoolerConfig,
- PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
- SchedulerPolicy, SpeculativeConfig, TaskOption,
- TokenizerMode, VllmConfig, get_attr_docs, get_field)
+ KVEventsConfig, KVTransferConfig, LoadConfig,
+ LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
+ ModelImpl, MultiModalConfig, ObservabilityConfig,
+ ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+ RunnerOption, SchedulerConfig, SchedulerPolicy,
+ SpeculativeConfig, TaskOption, TokenizerMode,
+ VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -400,7 +400,7 @@ class EngineArgs:
str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None
-
+
show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
@@ -773,10 +773,13 @@ class EngineArgs:
default=None,
help="The configurations for intermediate loggings. Should be a "
"JSON string.")
-
- intermediate_log_group.add_argument("--intermediate-log-config-path", type=str,
- help="The path to the configurations for intermediate loggings. Should be a string.")
-
+
+ intermediate_log_group.add_argument(
+ "--intermediate-log-config-path",
+ type=str,
+ help="The path to the configurations for intermediate loggings. "
+ "Should be a string.")
+
# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
observability_group = parser.add_argument_group(
@@ -865,9 +868,6 @@ class EngineArgs:
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
-
-
-
# Other arguments
parser.add_argument('--disable-log-stats',
action='store_true',
@@ -979,11 +979,9 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location,
)
-
def create_intermediate_log_config(
- self,
- ) -> Optional[IntermediateLoggingConfig]:
+ self, ) -> Optional[IntermediateLoggingConfig]:
"""Initializes and returns an IntermediateLoggingConfig object based on
`intermediate_log_config` or `intermediate_log_config_path`.
"""
@@ -991,7 +989,7 @@ class EngineArgs:
return IntermediateLoggingConfig.from_dict(
self.intermediate_log_config)
if self.intermediate_log_config_path is not None:
- with open(self.intermediate_log_config_path, "r") as f:
+ with open(self.intermediate_log_config_path) as f:
return IntermediateLoggingConfig.from_dict(json.load(f))
return None
@@ -1235,8 +1233,7 @@ class EngineArgs:
disable_log_stats=self.disable_log_stats,
)
- intermediate_log_config = self.create_intermediate_log_config(
- )
+ intermediate_log_config = self.create_intermediate_log_config()
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
diff --git a/vllm/v1/intermediates/intermediates_logging.py b/vllm/v1/intermediates/intermediates_logging.py
index 0024523b7b..ffa19ae221 100644
--- a/vllm/v1/intermediates/intermediates_logging.py
+++ b/vllm/v1/intermediates/intermediates_logging.py
@@ -7,8 +7,6 @@ This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""
-import json
-import os
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
@@ -17,8 +15,6 @@ import torch
from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig
-
-# Import logger from vllm
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -91,7 +87,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
return False
# If no patterns are defined, log all modules
if not config._compiled_module_calls:
- logger.debug("No patterns defined, will log module: %s", module_name)
set_il_module_name(module, module_name)
set_il_module_call_idx(module, -1)
return True
@@ -115,7 +110,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
def is_log_enabled(config):
if not config or not config.enabled:
- logger.debug("Not logging because config not enabled")
return False
if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress")
@@ -161,151 +155,7 @@ def get_current_il_config():
return _global_config
-def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any:
- try:
- # Convert inputs to JSON-serializable format
- intermediates_json = convert_intermediates_to_json(intermediates)
- with open(path, "w") as f:
- json.dump(intermediates_json, f, indent=2)
- logger.debug("Saved all intermediates as JSON to %s", path)
- except Exception as e:
- logger.warning("Failed to save intermediates as JSON: %s", e)
- import traceback
-
- logger.warning(traceback.format_exc())
-
-
-def convert_intermediates_to_json(tensor: Any) -> Any:
- """Convert a intermediates(including tensor) to a JSON-serializable
- representation.
-
- Args:
- intermediates: The intermediates to convert.
-
- Returns:
- A JSON-serializable representation of the tensor.
- """
- if isinstance(tensor, torch.Tensor):
- try:
- result = {
- "type": "tensor",
- "shape": list(tensor.shape),
- "dtype": str(tensor.dtype),
- "numel": tensor.numel(),
- }
-
- return result
- except Exception as e:
- # Handle any errors in tensor conversion
- return {
- "type": "tensor_error",
- "error": str(e),
- "tensor_type": str(type(tensor)),
- }
-
- elif isinstance(tensor, (list, tuple)):
- # For lists/tuples, recursively convert each element
- container_type = "list" if isinstance(tensor, list) else "tuple"
-
- # If it's a large list, only include a sample
- if len(tensor) > 20:
- return {
- "type": container_type,
- "length": len(tensor),
- "sample": [
- convert_intermediates_to_json(item) for item in tensor[:100]
- ],
- "note": f"Showing only first 20 of {len(tensor)} items",
- }
- else:
- return {
- "type": container_type,
- "items": [convert_intermediates_to_json(item) for item in tensor],
- }
-
- elif isinstance(tensor, dict):
- # For dictionaries, recursively convert each value
- if len(tensor) > 20:
- # For large dicts, only include keys and a sample of values
- keys = list(tensor.keys())
- sample_keys = keys[:20]
- return {
- "type": "dict",
- "length": len(tensor),
- "keys": keys,
- "sample": {
- k: convert_intermediates_to_json(tensor[k]) for k in sample_keys
- },
- "note": f"Showing only first 20 of {len(tensor)} items",
- }
- else:
- return {
- "type": "dict",
- "items": {
- k: convert_intermediates_to_json(v) for k, v in tensor.items()
- },
- }
-
- elif tensor is None:
- return None
-
- elif isinstance(tensor, (int, float, bool, str)):
- # Primitive types can be directly serialized
- return tensor
-
- else:
- # For other types, use string representation
- return {"type": str(type(tensor).__name__), "string_repr": str(tensor)}
-
-
-def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool:
- """Utility function to dump tensor metadata to a file.
-
- Args:
- tensor: The tensor to dump.
- file_path: Base path where to save the tensor (without extension).
- """
- intermediate_log_config = get_current_il_config()
- if intermediate_log_config is None:
- return False
- if (
- intermediate_log_config.max_tensor_size is not None
- and tensor.numel() > intermediate_log_config.max_tensor_size
- ):
- # Save tensor metadata instead of full tensor
- tensor_info = {
- "shape": list(tensor.shape),
- "dtype": str(tensor.dtype),
- "device": str(tensor.device),
- "numel": tensor.numel(),
- "skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size "
- f"{intermediate_log_config.max_tensor_size}",
- }
- os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True)
- with open(f"{file_path}.json", "w") as f:
- json.dump(tensor_info, f, indent=2)
- return True
- return False
-
-
-def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any:
- if reload_dir is None:
- return None
- try:
- intermediate_log_config = get_current_il_config()
- assert intermediate_log_config is not None
- replace_dir = str(intermediate_log_config.output_run_dir)
- reload_path = save_path.replace(replace_dir, reload_dir)
- logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path)
- return torch.load(reload_path, map_location=tensor.device)
- except Exception as e:
- logger.warning("Failed to load tensor from %s: %s", reload_dir, e)
- return tensor
-
-
-def save_tensors(
- tensor: Any, file_path: str, reload_input_dir: Optional[str] = None
-) -> Any:
+def save_tensors(tensor: Any, file_path: str) -> Any:
"""Utility function to dump tensor to a file.
Args:
@@ -314,52 +164,32 @@ def save_tensors(
file_path: Base path where to save the tensor (without extension).
"""
- # Also save the actual tensor data for tensors
if isinstance(tensor, torch.Tensor):
- # Check if tensor is too large
- if save_tensors_metadata_if_too_large(tensor, file_path):
- return
- # Get device name
device_name = str(tensor.device)
- # Skip if device filtering is enabled and this device should not be
- # logged
intermediate_log_config = get_current_il_config()
if not should_log_device(intermediate_log_config, device_name):
- logger.debug(
- "Skipping tensor on device %s due to device filter", device_name
- )
return tensor
- # Append device name to file path
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
try:
- # Save tensor directly without detaching or moving to CPU
torch.save(tensor, pt_path)
- reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir)
- if reloaded_tensor is not None:
- return reloaded_tensor
- logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
+ logger.debug("Saved tensor of shape %s to %s", tensor.shape,
+ pt_path)
except Exception as e:
logger.warning("Failed to save tensor to %s: %s", pt_path, e)
return tensor
if isinstance(tensor, (list, tuple)):
- # For collections, also save each item individually
-
- reloaded_inputs = []
for i, item in enumerate(tensor):
- reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir)
- reloaded_inputs.append(reloaded)
- return tuple(reloaded_inputs) if reloaded_inputs else tensor
+ save_tensors(item, f"{file_path}_{i}")
+ return tensor
if isinstance(tensor, dict):
- reloaded_inputs = {}
- # For dictionaries, also save each value individually
for k, v in tensor.items():
- reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir)
- reloaded_inputs[k] = reloaded
- return reloaded_inputs if reloaded_inputs else tensor
+ save_tensors(v, f"{file_path}_{k}")
+ return tensor
-def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None:
+def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
+ outputs: Any) -> None:
"""Hook to increment the global step counter after a forward pass.
Args:
@@ -381,7 +211,8 @@ def _prepare_module_log_dir(
is_pre_fwd: bool = False,
) -> Path:
# Create a unique directory for this step if not
- dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}"
+ dump_dir = Path(
+ intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir.mkdir(exist_ok=True, parents=True)
# Create module directory
@@ -393,7 +224,8 @@ def _prepare_module_log_dir(
if is_pre_fwd:
_log_module_call(intermediate_log_config, module_name + suffix)
module_dir.mkdir(exist_ok=True, parents=True)
- logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir)
+ logger.debug("Logging module %s inputs/outputs to %s", module_name,
+ module_dir)
return module_dir
@@ -401,13 +233,8 @@ def _log_module_call(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
) -> None:
- logger.debug("Logging module call for %s", module_name)
- # write module name and call to step:
- file = (
- Path(intermediate_log_config.output_run_dir)
- / f"step_{get_step()}"
- / "module_calls.txt"
- )
+ file = (Path(intermediate_log_config.output_run_dir) /
+ f"step_{get_step()}" / "module_calls.txt")
with open(file, "a") as f:
f.write(f"{module_name}\n")
@@ -425,7 +252,8 @@ def get_current_step_module_call(module_name: str) -> int:
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
-def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
+def prepare_log_current_fwd(module,
+ is_pre_fwd: bool = False) -> Optional[Path]:
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None or not intermediate_log_config.enabled:
return None
@@ -443,15 +271,14 @@ def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
if is_pre_fwd:
update_current_step_module_call(module_name)
if should_log:
- log_dir = _prepare_module_log_dir(
- intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd
- )
+ log_dir = _prepare_module_log_dir(intermediate_log_config,
+ module_name,
+ is_pre_fwd=is_pre_fwd)
return log_dir
-def log_pre_fwd_hook(
- module: torch.nn.Module, inputs: tuple[Any, ...]
-) -> tuple[Any, ...]:
+def log_pre_fwd_hook(module: torch.nn.Module,
+ inputs: tuple[Any, ...]) -> tuple[Any, ...]:
"""Hook to capture module inputs before forward pass.
Args:
@@ -462,27 +289,12 @@ def log_pre_fwd_hook(
The unchanged inputs.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
- dump_intermediates_to_json(inputs, log_dir / "inputs.json")
- intermediate_log_config = get_current_il_config()
- if intermediate_log_config is not None:
- reload_input_dir = getattr(
- intermediate_log_config,
- "reload_input_dir",
- "/tmp/vllm_intermediates/57f4a3b2-9c4c-4afe-be71-0e95369d74b5",
- )
- else:
- reload_input_dir = None
- reloaded_inputs = save_tensors(
- inputs, str(log_dir / "inputs"), reload_input_dir
- )
- if reloaded_inputs is not None:
- return reloaded_inputs
+ save_tensors(inputs, str(log_dir / "inputs"))
return inputs
-def log_post_fwd_hook(
- module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any
-) -> None:
+def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
+ outputs: Any) -> None:
"""Hook to capture module outputs after forward pass.
Args:
@@ -491,12 +303,11 @@ def log_post_fwd_hook(
outputs: The outputs from the module's forward function.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
- dump_intermediates_to_json(outputs, log_dir / "outputs.json")
save_tensors(outputs, str(log_dir / "outputs"))
intermediate_log_config = get_current_il_config()
- assert intermediate_log_config is not None, "IL config should not be None"
+ assert intermediate_log_config is not None, \
+ "IL config should not be None"
if intermediate_log_config.log_post_fwd_inputs:
- dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json")
save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
@@ -532,14 +343,14 @@ class IntermediatesLogger:
def __init__(self, config: IntermediateLoggingConfig):
self.config = config
- self.hooks: list[
- tuple[str, str, Optional[RemovableHandle], Optional[RemovableHandle]]
- ] = []
+ self.hooks: list[tuple[str, str, Optional[RemovableHandle],
+ Optional[RemovableHandle]]] = []
logger.debug("Created IntermediatesLogger with config: %s", config)
path = Path(config.output_run_dir)
path.mkdir(exist_ok=True, parents=True)
# Log configuration
- logger.info("Intermediates will be logged in %s", config.output_run_dir)
+ logger.info("Intermediates will be logged in %s",
+ config.output_run_dir)
def register_hooks(self, model: torch.nn.Module) -> None:
"""Register hooks for the model.
@@ -551,13 +362,11 @@ class IntermediatesLogger:
for name, module in model.named_modules():
if name and should_log_module(self.config, name, module):
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
- logger.debug(
- "Registered pre_fwd hook for %s", module.__class__.__name__
- )
+ logger.debug("Registered pre_fwd hook for %s",
+ module.__class__.__name__)
post_hook = module.register_forward_hook(log_post_fwd_hook)
- logger.debug(
- "Registered post_fwd hook for %s", module.__class__.__name__
- )
+ logger.debug("Registered post_fwd hook for %s",
+ module.__class__.__name__)
self.hooks.append((name, module, pre_hook, post_hook))
# Register a step counter hook for the root model
@@ -578,7 +387,8 @@ class IntermediatesLogger:
def register_intermediate_hooks(
- model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs
+ model: torch.nn.Module,
+ config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
"""Register hooks to log intermediate tensors for a model.
@@ -590,10 +400,6 @@ def register_intermediate_hooks(
Returns:
An IntermediatesLogger instance that can be used to manage the hooks.
"""
- if config is None:
- # Create config from kwargs
- config = IntermediateLoggingConfig.from_dict(kwargs)
-
logger_instance = IntermediatesLogger(config)
logger_instance.register_hooks(model)
return logger_instance
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index a0d5cd2e40..ec4261aa5a 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -27,12 +27,12 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
+from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
-from vllm.v1.intermediates.intermediates_logging import intermediate_logging
logger = init_logger(__name__)
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index bbe601afa9..3c17a51be1 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -6,10 +6,11 @@ from typing import Optional
import torch
import torch.nn as nn
-from vllm.config import VllmConfig, IntermediateLoggingConfig
+from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger
+from vllm.v1.intermediates.intermediates_logging import (
+ register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec
-from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__)
@@ -64,27 +65,26 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
-
- def register_intermediate_hooks(self,
- config: Optional[IntermediateLoggingConfig] = None,
- **kwargs) -> None:
+
+ def register_intermediate_hooks(
+ self, config: Optional[IntermediateLoggingConfig] = None) -> None:
"""Register hooks for intermediate tensor logging.
This method is called via collective_rpc from the engine core.
- It registers hooks on the model to dump intermediate tensors during execution.
+ It registers hooks on the model to dump intermediate tensors during
+ execution.
Args:
- config: Configuration for intermediate logging. If provided, this takes precedence over kwargs.
+ config: Configuration for intermediate logging; passed as-is to
+ the module-level register_intermediate_hooks helper.
"""
- if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None:
- logger.error("Could not register intermediate hooks: model_runner.model is not accessible")
+ if self.model_runner is None or not hasattr(
+ self.model_runner, "model") or self.model_runner.model is None:
+ logger.error("Could not register intermediate hooks: "
+ "model_runner.model is not accessible")
return
model = self.model_runner.model
try:
- # Register hooks
- register_intermediate_hooks(model, config, **kwargs)
- # Store the logger instance for potential later hook removal
- except Exception as e:
- logger.info("Successfully registered intermediate hooks")
- logger.error("Error registering intermediate hooks", exc_info=True)
-
+ register_intermediate_hooks(model, config)
+ except Exception:
+ logger.exception("Error registering intermediate hooks")
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 065bbd26c0..d7f50f713e 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -128,21 +128,21 @@ class WorkerBase:
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
-
- def register_intermediate_hooks(self, config=None, **kwargs) -> None:
+
+ def register_intermediate_hooks(self, config=None) -> None:
"""Register hooks for intermediate tensor logging.
- This method is a stub for v0 workers. The actual implementation is in v1 workers.
- It's included here for compatibility with the collective_rpc mechanism.
+ This method is a stub for v0 workers. The actual implementation is
+ in v1 workers. It's included here for compatibility with the
+ collective_rpc mechanism.
Args:
config: Configuration for intermediate logging.
- **kwargs: Configuration parameters for intermediate logging.
- These are ignored in v0 workers.
"""
logger.warning(
"register_intermediate_hooks is not implemented in v0 workers. "
- "This is only available in v1 workers. No hooks will be registered.")
+ "This is only available in v1 workers. No hooks will be registered."
+ )
return None