From 2af83ebdde1cc03b147650cdbcf66a8e3814010b Mon Sep 17 00:00:00 2001 From: Lucia Fang Date: Mon, 28 Jul 2025 19:03:26 -0700 Subject: [PATCH] remove feature for metadata dump and input reload Signed-off-by: Lucia Fang --- docs/contributing/intermediate_logging.md | 1 - tests/v1/test_intermediates_logging.py | 19 +- tools/compare_intermediate.py | 384 ++++++++---------- vllm/config.py | 61 +-- vllm/engine/arg_utils.py | 39 +- .../v1/intermediates/intermediates_logging.py | 268 ++---------- vllm/v1/worker/gpu_worker.py | 2 +- vllm/v1/worker/worker_base.py | 34 +- vllm/worker/worker_base.py | 14 +- 9 files changed, 278 insertions(+), 544 deletions(-) diff --git a/docs/contributing/intermediate_logging.md b/docs/contributing/intermediate_logging.md index 4b1dc2aca8..fba4f439b6 100644 --- a/docs/contributing/intermediate_logging.md +++ b/docs/contributing/intermediate_logging.md @@ -49,7 +49,6 @@ The configuration file should be a JSON file with the following structure: python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json ``` - #### Configuration Parameters | Parameter | Type | Description | Default | diff --git a/tests/v1/test_intermediates_logging.py b/tests/v1/test_intermediates_logging.py index a25a70e910..9d2d0110e9 100644 --- a/tests/v1/test_intermediates_logging.py +++ b/tests/v1/test_intermediates_logging.py @@ -5,9 +5,8 @@ Tests for the intermediate tensor logging functionality. """ import json -from os.path import isdir -import shutil import os +import shutil import tempfile from pathlib import Path from unittest import mock @@ -17,14 +16,10 @@ import torch import torch.nn as nn from vllm.config import IntermediateLoggingConfig -from vllm.v1.intermediates.intermediates_logging import (get_current_il_config, - get_step, increment_step, - intermediate_logging, - register_intermediate_hooks, - reset_step, - should_log_device, - should_log_module, - should_log_step) +from vllm.v1.intermediates.intermediates_logging import ( + get_current_il_config, get_step, increment_step, intermediate_logging, + register_intermediate_hooks, reset_step, should_log_device, + should_log_module, should_log_step) class SimpleModel(nn.Module): @@ -237,7 +232,8 @@ def test_register_hooks(simple_model, il_config): assert len(logger_instance.hooks) == 0 -@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json') +@mock.patch( + 'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json') @mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors') def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model, il_config, temp_output_dir): @@ -262,7 +258,6 @@ def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model, # Check that dump_intermediates_to_json and save_tensors were called assert mock_dump_json.called assert mock_save_tensors.called - # Remove hooks logger_instance.remove_hooks() diff --git a/tools/compare_intermediate.py b/tools/compare_intermediate.py index 984c604503..833ac38aad 100755 --- a/tools/compare_intermediate.py +++ b/tools/compare_intermediate.py @@ -7,27 +7,30 @@ This script compares the tensor outputs from two different intermediate logging directories and generates a report of the differences. Usage: - python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options] + python compare_intermediate.py --dir1 /path/to/first/log/dir \ + --dir2 /path/to/second/log/dir [options] Options: --dir1 DIR First intermediate logging directory --dir2 DIR Second intermediate logging directory --output FILE Output file for the report (default: stdout) - --format {md,json} Output format (default: md) - --rtol FLOAT Relative tolerance for tensor comparison (default: 1e-5) - --atol FLOAT Absolute tolerance for tensor comparison (default: 1e-8) + --rtol FLOAT Relative tolerance for tensor comparison + (default: 1e-5) + --atol FLOAT Absolute tolerance for tensor comparison + (default: 1e-8) --steps STEPS Comma-separated list of steps to compare (default: all) - --modules MODULES Comma-separated list of module name patterns to compare (default: all) + --modules MODULES Comma-separated list of module name patterns to compare + (default: all) --verbose Include detailed information about each tensor """ import argparse import json -import re from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Optional +import regex as re import torch @@ -40,34 +43,19 @@ def load_tensor(path: Path) -> torch.Tensor: return None -def load_json(path: Path) -> Dict: - """Load a JSON file.""" - try: - with open(path, "r") as f: - return json.load(f) - except Exception as e: - print(f"Error loading JSON from {path}: {e}") - return {} - - -def extract_diff_metatada(exception_str: str) -> Dict: +def extract_diff_metatada(exception_str: str) -> dict: try: num_diff_elements = int( - re.search(r"Mismatched elements: (\d+) /", exception_str).group(1) - ) + re.search(r"Mismatched elements: (\d+) /", exception_str).group(1)) total_elements = int( - re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1) - ) + re.search(r"Mismatched elements: \d+ / (\d+)", + exception_str).group(1)) max_abs_diff = float( - re.search( - r"Greatest absolute difference: ([\d\.e-]+)", exception_str - ).group(1) - ) + re.search(r"Greatest absolute difference: ([\d\.e-]+)", + exception_str).group(1)) max_rel_diff = float( - re.search( - r"Greatest relative difference: ([\d\.e-]+)", exception_str - ).group(1) - ) + re.search(r"Greatest relative difference: ([\d\.e-]+)", + exception_str).group(1)) return { "num_diff_elements": num_diff_elements, "total_elements": total_elements, @@ -78,9 +66,8 @@ def extract_diff_metatada(exception_str: str) -> Dict: return {"error": exception_str} -def compare_tensors( - tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float -) -> Dict: +def compare_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, + atol: float) -> dict: """Compare two tensors and return a dictionary with comparison results.""" if tensor1 is None or tensor2 is None: return {"match": False, "error": "One or both tensors are None"} @@ -105,60 +92,8 @@ def compare_tensors( return {"match": True} -def compare_json_values(value1: Any, value2: Any) -> Dict: - """Compare two JSON values and return a dictionary with comparison results.""" - if type(value1) is not type(value2): - return { - "match": False, - "error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}", - } - - if isinstance(value1, dict): - # Compare dictionaries - all_keys = set(value1.keys()) | set(value2.keys()) - mismatches = {} - - for key in all_keys: - if key not in value1: - mismatches[key] = {"error": "Missing in first dict"} - elif key not in value2: - mismatches[key] = {"error": "Missing in second dict"} - else: - comparison = compare_json_values(value1[key], value2[key]) - if not comparison["match"]: - mismatches[key] = comparison - - if mismatches: - return {"match": False, "mismatches": mismatches} - return {"match": True} - - elif isinstance(value1, list): - # Compare lists - if len(value1) != len(value2): - return { - "match": False, - "error": f"Length mismatch: {len(value1)} vs {len(value2)}", - } - - mismatches = {} - for i, (item1, item2) in enumerate(zip(value1, value2)): - comparison = compare_json_values(item1, item2) - if not comparison["match"]: - mismatches[i] = comparison - - if mismatches: - return {"match": False, "mismatches": mismatches} - return {"match": True} - - else: - # Compare primitive values - if value1 == value2: - return {"match": True} - else: - return {"match": False, "value1": value1, "value2": value2} - - -def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]: +def find_tensor_files( + directory: Path) -> dict[str, dict[str, dict[str, list[Path]]]]: """ Find all tensor files in the given directory. @@ -198,23 +133,14 @@ def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Pat if output_tensors: result[step_name][module_name]["outputs"] = output_tensors - # Find JSON metadata files - inputs_json = module_dir / "inputs.json" - if inputs_json.exists(): - result[step_name][module_name]["inputs_json"] = [inputs_json] - - outputs_json = module_dir / "outputs.json" - if outputs_json.exists(): - result[step_name][module_name]["outputs_json"] = [outputs_json] - return result def filter_steps_and_modules( - tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]], - steps: Optional[List[str]] = None, - module_patterns: Optional[List[str]] = None, -) -> Dict[str, Dict[str, Dict[str, List[Path]]]]: + tensor_files: dict[str, dict[str, dict[str, list[Path]]]], + steps: Optional[list[str]] = None, + module_patterns: Optional[list[str]] = None, +) -> dict[str, dict[str, dict[str, list[Path]]]]: """Filter tensor files by steps and module patterns.""" result = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) @@ -223,11 +149,13 @@ def filter_steps_and_modules( step_names = [f"step_{step}" for step in steps] steps_to_include = {step: True for step in step_names} else: - steps_to_include = {step: True for step in tensor_files.keys()} + steps_to_include = {step: True for step in tensor_files} # Compile module patterns if module_patterns: - compiled_patterns = [re.compile(pattern) for pattern in module_patterns] + compiled_patterns = [ + re.compile(pattern) for pattern in module_patterns + ] else: compiled_patterns = None @@ -237,11 +165,10 @@ def filter_steps_and_modules( for module_name, file_types in modules.items(): # Check if module matches any pattern - if compiled_patterns: - if not any( - pattern.search(module_name) for pattern in compiled_patterns - ): - continue + if compiled_patterns and not any( + pattern.search(module_name) + for pattern in compiled_patterns): + continue result[step_name][module_name] = file_types @@ -253,9 +180,9 @@ def compare_directories( dir2: Path, rtol: Optional[float] = None, atol: Optional[float] = None, - steps: Optional[List[str]] = None, - module_patterns: Optional[List[str]] = None, -) -> Dict: + steps: Optional[list[str]] = None, + module_patterns: Optional[list[str]] = None, +) -> dict: """Compare two intermediate logging directories and return a report.""" # Find tensor files in both directories tensor_files1 = find_tensor_files(dir1) @@ -263,8 +190,10 @@ def compare_directories( # Filter by steps and modules if steps or module_patterns: - tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns) - tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns) + tensor_files1 = filter_steps_and_modules(tensor_files1, steps, + module_patterns) + tensor_files2 = filter_steps_and_modules(tensor_files2, steps, + module_patterns) # Get all steps and modules from both directories all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys()) @@ -296,12 +225,12 @@ def compare_directories( # TODO: check if module calls txt exsits dir1_module_call_file = dir1 / step / "module_calls.txt" if dir1_module_call_file.exists(): - with open(dir1 / step / "module_calls.txt", "r") as f: + with open(dir1 / step / "module_calls.txt") as f: all_modules = f.read().splitlines() else: print( - "Warnings: the module call orders are missed, ordering using module alphbetics" - ) + "Warnings: the module call orders are missed, ordering using " + "module alphbetics") all_modules = sorted(set(modules1.keys()) | set(modules2.keys())) step_report["module_call_list"] = [] for module in all_modules: @@ -329,32 +258,17 @@ def compare_directories( step_report["modules"][module] = module_report continue - # Compare JSON metadata - for json_type in ["inputs_json", "outputs_json"]: - json_files1 = modules1[module].get(json_type, []) - json_files2 = modules2[module].get(json_type, []) - - if json_files1 and json_files2: - json1 = load_json(json_files1[0]) - json2 = load_json(json_files2[0]) - - json_comparison = compare_json_values(json1, json2) - json_name = json_type.replace("_json", "") - module_report[f"{json_name}_metadata"] = json_comparison - - # Add file paths for manual checking when there's a mismatch - if not json_comparison.get("match", True): - module_report[f"{json_name}_metadata"]["file1"] = str( - json_files1[0] - ) - module_report[f"{json_name}_metadata"]["file2"] = str( - json_files2[0] - ) - # Compare input tensors - input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])} - input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])} - all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys()) + input_tensors1 = { + p.name: p + for p in modules1[module].get("inputs", []) + } + input_tensors2 = { + p.name: p + for p in modules2[module].get("inputs", []) + } + all_input_names = set(input_tensors1.keys()) | set( + input_tensors2.keys()) for tensor_name in sorted(all_input_names): if tensor_name not in input_tensors1: @@ -389,9 +303,16 @@ def compare_directories( module_report["summary"]["total_tensors"] += 1 # Compare output tensors - output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])} - output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])} - all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys()) + output_tensors1 = { + p.name: p + for p in modules1[module].get("outputs", []) + } + output_tensors2 = { + p.name: p + for p in modules2[module].get("outputs", []) + } + all_output_names = set(output_tensors1.keys()) | set( + output_tensors2.keys()) for tensor_name in sorted(all_output_names): if tensor_name not in output_tensors1: @@ -439,64 +360,58 @@ def compare_directories( # Add overall summary report["summary"] = { - "total_steps": len(all_steps), - "total_modules": sum( - step_report["summary"]["total_modules"] - for step_report in report["steps"].values() - ), - "matching_modules": sum( - step_report["summary"]["matching_modules"] - for step_report in report["steps"].values() - ), - "mismatched_modules": sum( - step_report["summary"]["mismatched_modules"] - for step_report in report["steps"].values() - ), - "missing_modules": sum( - step_report["summary"]["missing_modules"] - for step_report in report["steps"].values() - ), - "total_tensors": sum( - module_report["summary"]["total_tensors"] + "total_steps": + len(all_steps), + "total_modules": + sum(step_report["summary"]["total_modules"] + for step_report in report["steps"].values()), + "matching_modules": + sum(step_report["summary"]["matching_modules"] + for step_report in report["steps"].values()), + "mismatched_modules": + sum(step_report["summary"]["mismatched_modules"] + for step_report in report["steps"].values()), + "missing_modules": + sum(step_report["summary"]["missing_modules"] + for step_report in report["steps"].values()), + "total_tensors": + sum(module_report["summary"]["total_tensors"] for step_report in report["steps"].values() for module_name, module_report in step_report["modules"].items() - if "summary" in module_report - ), - "matching_tensors": sum( - module_report["summary"]["matching_tensors"] + if "summary" in module_report), + "matching_tensors": + sum(module_report["summary"]["matching_tensors"] for step_report in report["steps"].values() for module_name, module_report in step_report["modules"].items() - if "summary" in module_report - ), - "mismatched_tensors": sum( - module_report["summary"]["mismatched_tensors"] + if "summary" in module_report), + "mismatched_tensors": + sum(module_report["summary"]["mismatched_tensors"] for step_report in report["steps"].values() for module_name, module_report in step_report["modules"].items() - if "summary" in module_report - ), - "missing_tensors": sum( - module_report["summary"]["missing_tensors"] + if "summary" in module_report), + "missing_tensors": + sum(module_report["summary"]["missing_tensors"] for step_report in report["steps"].values() for module_name, module_report in step_report["modules"].items() - if "summary" in module_report - ), + if "summary" in module_report), } return report -def generate_markdown_report(report: Dict, verbose: bool = False) -> str: +def generate_markdown_report(report: dict, verbose: bool = False) -> str: """Generate a markdown report from the comparison results.""" lines = [] # Add header lines.append("# Intermediate Logging Comparison Report") lines.append("") - lines.append("Comparing intermediate logging outputs between:") + lines.append("Comparing intermediate logging outputs " + "between:") lines.append(f"- **Directory 1**: `{report['dir1']}`") lines.append(f"- **Directory 2**: `{report['dir2']}`") lines.append("") - lines.append(f"Comparison parameters:") + lines.append("Comparison parameters:") lines.append(f"- Relative tolerance (rtol): {report['rtol']}") lines.append(f"- Absolute tolerance (atol): {report['atol']}") lines.append("") @@ -509,11 +424,13 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: lines.append("|----------|-------|----------|------------|---------|") lines.append(f"| Steps | {summary['total_steps']} | - | - | - |") lines.append( - f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |" - ) + f"| Modules | {summary['total_modules']} | " + f"{summary['matching_modules']} | {summary['mismatched_modules']} | " + f"{summary['missing_modules']} |") lines.append( - f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |" - ) + f"| Tensors | {summary['total_tensors']} | " + f"{summary['matching_tensors']} | {summary['mismatched_tensors']} | " + f"{summary['missing_tensors']} |") lines.append("") # Add step details @@ -523,8 +440,9 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: lines.append(f"## {step_name}") lines.append("") lines.append( - f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules" - ) + f"**Summary**: {step_summary['matching_modules']} matching " + f"modules, {step_summary['mismatched_modules']} mismatched " + f"modules, {step_summary['missing_modules']} missing modules") lines.append("") # Add module details @@ -540,15 +458,14 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: module_summary = module_report["summary"] # Determine module status - if module_summary["mismatched_tensors"] > 0: - status = "❌" - else: - status = "✅" + status = "❌" if module_summary["mismatched_tensors"] > 0 else "✅" lines.append(f"### {status} {module_name}") lines.append("") lines.append( - f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors" + f"**Summary**: {module_summary['matching_tensors']} matching " + f"tensors, {module_summary['mismatched_tensors']} mismatched " + f"tensors, {module_summary['missing_tensors']} missing tensors" ) lines.append("") @@ -558,20 +475,21 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: metadata_comparison = module_report[metadata_type] if not metadata_comparison.get("match", True): file_paths = "" - if ( - "file1" in metadata_comparison - and "file2" in metadata_comparison - ): - file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`" + if ("file1" in metadata_comparison + and "file2" in metadata_comparison): + file_paths = ( + f" - Files: " + f"`{metadata_comparison['file1']}` " + f"vs `{metadata_comparison['file2']}`") lines.append( - f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}" - ) + f"**{metadata_type.capitalize()}**: Mismatch " + f"detected{file_paths}") if verbose and "mismatches" in metadata_comparison: lines.append("```json") lines.append( - json.dumps(metadata_comparison["mismatches"], indent=2) - ) + json.dumps(metadata_comparison["mismatches"], + indent=2)) lines.append("```") lines.append("") @@ -585,8 +503,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: lines.append("|--------|--------|---------|") for tensor_name, comparison in sorted( - module_report["inputs"].items() - ): + module_report["inputs"].items()): if comparison.get("match", False): status = "✅" details = "Tensors match" @@ -595,13 +512,23 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: details = comparison["error"] else: status = "❌" - details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A'):.2e}, " - details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A'):.2e}, " - details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}" + details = ( + f"Max abs diff: " + f"{comparison.get('max_abs_diff', 'N/A')}, ") + details += ( + f"Max relative diff: " + f"{comparison.get('max_rel_diff', 'N/A')}, ") + details += ( + f"Diff elements: " + f"{comparison.get('num_diff_elements', 'N/A')}/" + f"{comparison.get('total_elements', 'N/A')}") if "file1" in comparison and "file2" in comparison: - details += f"
Files: `{comparison['file1']}` vs `{comparison['file2']}`" + details += ( + f"
Files: `{comparison['file1']}` vs " + f"`{comparison['file2']}`") - lines.append(f"| {tensor_name} | {status} | {details} |") + lines.append( + f"| {tensor_name} | {status} | {details} |") lines.append("") @@ -613,8 +540,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: lines.append("|--------|--------|---------|") for tensor_name, comparison in sorted( - module_report["outputs"].items() - ): + module_report["outputs"].items()): if comparison.get("match", False): status = "✅" details = "Tensors match" @@ -623,11 +549,19 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: details = comparison["error"] else: status = "❌" - details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, " - details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, " - details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}" + details = ( + f"Max abs diff: " + f"{comparison.get('max_abs_diff', 'N/A')}, ") + details += ( + f"Max relative diff: " + f"{comparison.get('max_rel_diff', 'N/A')}, ") + details += ( + f"Diff elements: " + f"{comparison.get('num_diff_elements', 'N/A')}/" + f"{comparison.get('total_elements', 'N/A')}") - lines.append(f"| {tensor_name} | {status} | {details} |") + lines.append( + f"| {tensor_name} | {status} | {details} |") lines.append("") @@ -636,15 +570,16 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str: def main(): parser = argparse.ArgumentParser( - description="Compare intermediate logging outputs from two different runs." - ) - parser.add_argument( - "--dir1", required=True, help="First intermediate logging directory" - ) - parser.add_argument( - "--dir2", required=True, help="Second intermediate logging directory" - ) - parser.add_argument("--output", help="Output file for the report (default: stdout)") + description= + "Compare intermediate logging outputs from two different runs.") + parser.add_argument("--dir1", + required=True, + help="First intermediate logging directory") + parser.add_argument("--dir2", + required=True, + help="Second intermediate logging directory") + parser.add_argument("--output", + help="Output file for the report (default: stdout)") parser.add_argument( "--rtol", type=float, @@ -658,11 +593,12 @@ def main(): help="Absolute tolerance for tensor comparison (default: 1e-8)", ) parser.add_argument( - "--steps", help="Comma-separated list of steps to compare (default: all)" - ) + "--steps", + help="Comma-separated list of steps to compare (default: all)") parser.add_argument( "--modules", - help="Comma-separated list of module name patterns to compare (default: all)", + help="Comma-separated list of module name patterns to compare " + "(default: all)", ) parser.add_argument( "--verbose", diff --git a/vllm/config.py b/vllm/config.py index ab5676686d..1c6d4a87d0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,8 +17,7 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, from functools import cached_property from importlib.util import find_spec from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - Protocol, TypeVar, Union, cast, get_args, List, Set) -from re import Pattern + Protocol, TypeVar, Union, cast, get_args) import regex as re import torch @@ -4026,63 +4025,65 @@ class KVEventsConfig: @config -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class IntermediateLoggingConfig: """Configuration for intermediate tensor logging.""" - + output_dir: str = "/tmp/vllm_intermediates" """Directory where to save the intermediate tensors.""" - + reload_input_dir: Optional[str] = None """Directory where to load the inputs for the steps/modules. This is used when we want to check per module numerical gaps instead of accumulated gap to further dive into the actual numerical issues.""" - module_call_match: Optional[List[str]] = None + module_call_match: Optional[list[str]] = None """Match modules by name regex and call index ( a module can be called multiple times in a step) List of regex:call_idx, call_idx is -1 for default for all calls """ - - log_step_ids: List[int] = field(default_factory=lambda: [0]) + + log_step_ids: list[int] = field(default_factory=lambda: [0]) """List of step IDs to log (empty list means log all steps).""" - + log_post_fwd_inputs: bool = False """Whether logging inputs after forwards for each module""" max_tensor_size: Optional[int] = None """Maximum number of elements in tensors to log (None = no limit).""" - + enabled: bool = True """Whether logging is enabled.""" - device_names: List[str] = field(default_factory=list) + device_names: list[str] = field(default_factory=list) """List of device names to log (empty list means log all devices).""" - - _compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False) + + _compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict, + init=False) """Compiled regex patterns for module filtering.""" - + _module_call: dict[str, int] = field(default_factory=dict, init=False) - _step_id_set: Set[int] = field(default_factory=set, init=False) + _step_id_set: set[int] = field(default_factory=set, init=False) """Set of step IDs for faster lookup.""" _output_run_dir: str = "/tmp/vllm_intermediates" """Unique directory to save single run/serve logging result.""" - + def __post_init__(self): """Initialize derived fields after instance creation.""" self._compile_regex_patterns() self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4()) self._step_id_set = set(self.log_step_ids) - + def _compile_regex_patterns(self): """Compile regex patterns for module name filtering.""" from vllm.logger import init_logger logger = init_logger(__name__) - + self._compiled_module_matches = [] - + if self.module_call_match is None: - logger.info("No module name regex patterns provided, will log all modules") + logger.info( + "No module name regex patterns provided, will log all modules") return - + # Compile all patterns for regex_pattern_call_idx in self.module_call_match: try: @@ -4091,15 +4092,16 @@ class IntermediateLoggingConfig: call_idx = -1 if len(splits) > 1: call_idx = int(splits[1]) - compiled_pattern: Pattern[str] = re.compile(regex_pattern) + compiled_pattern: re.Pattern[str] = re.compile(regex_pattern) self._compiled_module_calls[compiled_pattern] = call_idx - logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'") + logger.info("Successfully compiled regex pattern: '%s'", + regex_pattern) except Exception as e: - logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") - raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e + logger.error("Failed to parse module_call_match '%s': %s", + regex_pattern_call_idx, e) - - logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns") + logger.info("Compiled %d regex patterns", + len(self._compiled_module_calls)) def to_dict(self) -> dict: """Convert the config to a dictionary for serialization.""" @@ -4111,12 +4113,12 @@ class IntermediateLoggingConfig: "enabled": self.enabled, "device_names": self.device_names } - + @classmethod def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig": """Parse the CLI value for the speculative config.""" return cls(**dict_value) - + @property def output_run_dir(self) -> str: return self._output_run_dir @@ -4138,7 +4140,6 @@ class IntermediateLoggingConfig: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - class CompilationLevel: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 49e331e183..547f1bb810 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,13 +27,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, GuidedDecodingBackendV1, HfOverrides, IntermediateLoggingConfig, - KVEventsConfig, KVTransferConfig, - LoadConfig, LogprobsMode, LoRAConfig, ModelConfig, - ModelDType, ModelImpl, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, - SchedulerPolicy, SpeculativeConfig, TaskOption, - TokenizerMode, VllmConfig, get_attr_docs, get_field) + KVEventsConfig, KVTransferConfig, LoadConfig, + LogprobsMode, LoRAConfig, ModelConfig, ModelDType, + ModelImpl, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, + RunnerOption, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, TaskOption, TokenizerMode, + VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -400,7 +400,7 @@ class EngineArgs: str] = ModelConfig.logits_processor_pattern speculative_config: Optional[Dict[str, Any]] = None - + show_hidden_metrics_for_version: Optional[str] = \ ObservabilityConfig.show_hidden_metrics_for_version @@ -773,10 +773,13 @@ class EngineArgs: default=None, help="The configurations for intermediate loggings. Should be a " "JSON string.") - - intermediate_log_group.add_argument("--intermediate-log-config-path", type=str, - help="The path to the configurations for intermediate loggings. Should be a string.") - + + intermediate_log_group.add_argument( + "--intermediate-log-config-path", + type=str, + help="The path to the configurations for intermediate loggings. " + "Should be a string.") + # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) observability_group = parser.add_argument_group( @@ -865,9 +868,6 @@ class EngineArgs: vllm_group.add_argument("--additional-config", **vllm_kwargs["additional_config"]) - - - # Other arguments parser.add_argument('--disable-log-stats', action='store_true', @@ -979,11 +979,9 @@ class EngineArgs: use_tqdm_on_load=self.use_tqdm_on_load, pt_load_map_location=self.pt_load_map_location, ) - def create_intermediate_log_config( - self, - ) -> Optional[IntermediateLoggingConfig]: + self, ) -> Optional[IntermediateLoggingConfig]: """Initializes and returns an IntermediateLoggingConfig object based on `intermediate_log_config` or `intermediate_log_config_path`. """ @@ -991,7 +989,7 @@ class EngineArgs: return IntermediateLoggingConfig.from_dict( self.intermediate_log_config) if self.intermediate_log_config_path is not None: - with open(self.intermediate_log_config_path, "r") as f: + with open(self.intermediate_log_config_path) as f: return IntermediateLoggingConfig.from_dict(json.load(f)) return None @@ -1235,8 +1233,7 @@ class EngineArgs: disable_log_stats=self.disable_log_stats, ) - intermediate_log_config = self.create_intermediate_log_config( - ) + intermediate_log_config = self.create_intermediate_log_config() # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid diff --git a/vllm/v1/intermediates/intermediates_logging.py b/vllm/v1/intermediates/intermediates_logging.py index 0024523b7b..ffa19ae221 100644 --- a/vllm/v1/intermediates/intermediates_logging.py +++ b/vllm/v1/intermediates/intermediates_logging.py @@ -7,8 +7,6 @@ This module provides functionality to capture and save intermediate tensors (inputs and outputs) from PyTorch modules during forward passes. """ -import json -import os from contextlib import contextmanager from pathlib import Path from typing import Any, Optional @@ -17,8 +15,6 @@ import torch from torch.utils.hooks import RemovableHandle from vllm.config import IntermediateLoggingConfig - -# Import logger from vllm from vllm.logger import init_logger logger = init_logger(__name__) @@ -91,7 +87,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool: return False # If no patterns are defined, log all modules if not config._compiled_module_calls: - logger.debug("No patterns defined, will log module: %s", module_name) set_il_module_name(module, module_name) set_il_module_call_idx(module, -1) return True @@ -115,7 +110,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool: def is_log_enabled(config): if not config or not config.enabled: - logger.debug("Not logging because config not enabled") return False if torch.compiler.is_compiling(): logger.debug("Not logging because torch.compile is in progress") @@ -161,151 +155,7 @@ def get_current_il_config(): return _global_config -def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any: - try: - # Convert inputs to JSON-serializable format - intermediates_json = convert_intermediates_to_json(intermediates) - with open(path, "w") as f: - json.dump(intermediates_json, f, indent=2) - logger.debug("Saved all intermediates as JSON to %s", path) - except Exception as e: - logger.warning("Failed to save intermediates as JSON: %s", e) - import traceback - - logger.warning(traceback.format_exc()) - - -def convert_intermediates_to_json(tensor: Any) -> Any: - """Convert a intermediates(including tensor) to a JSON-serializable - representation. - - Args: - intermediates: The intermediates to convert. - - Returns: - A JSON-serializable representation of the tensor. - """ - if isinstance(tensor, torch.Tensor): - try: - result = { - "type": "tensor", - "shape": list(tensor.shape), - "dtype": str(tensor.dtype), - "numel": tensor.numel(), - } - - return result - except Exception as e: - # Handle any errors in tensor conversion - return { - "type": "tensor_error", - "error": str(e), - "tensor_type": str(type(tensor)), - } - - elif isinstance(tensor, (list, tuple)): - # For lists/tuples, recursively convert each element - container_type = "list" if isinstance(tensor, list) else "tuple" - - # If it's a large list, only include a sample - if len(tensor) > 20: - return { - "type": container_type, - "length": len(tensor), - "sample": [ - convert_intermediates_to_json(item) for item in tensor[:100] - ], - "note": f"Showing only first 20 of {len(tensor)} items", - } - else: - return { - "type": container_type, - "items": [convert_intermediates_to_json(item) for item in tensor], - } - - elif isinstance(tensor, dict): - # For dictionaries, recursively convert each value - if len(tensor) > 20: - # For large dicts, only include keys and a sample of values - keys = list(tensor.keys()) - sample_keys = keys[:20] - return { - "type": "dict", - "length": len(tensor), - "keys": keys, - "sample": { - k: convert_intermediates_to_json(tensor[k]) for k in sample_keys - }, - "note": f"Showing only first 20 of {len(tensor)} items", - } - else: - return { - "type": "dict", - "items": { - k: convert_intermediates_to_json(v) for k, v in tensor.items() - }, - } - - elif tensor is None: - return None - - elif isinstance(tensor, (int, float, bool, str)): - # Primitive types can be directly serialized - return tensor - - else: - # For other types, use string representation - return {"type": str(type(tensor).__name__), "string_repr": str(tensor)} - - -def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool: - """Utility function to dump tensor metadata to a file. - - Args: - tensor: The tensor to dump. - file_path: Base path where to save the tensor (without extension). - """ - intermediate_log_config = get_current_il_config() - if intermediate_log_config is None: - return False - if ( - intermediate_log_config.max_tensor_size is not None - and tensor.numel() > intermediate_log_config.max_tensor_size - ): - # Save tensor metadata instead of full tensor - tensor_info = { - "shape": list(tensor.shape), - "dtype": str(tensor.dtype), - "device": str(tensor.device), - "numel": tensor.numel(), - "skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size " - f"{intermediate_log_config.max_tensor_size}", - } - os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True) - with open(f"{file_path}.json", "w") as f: - json.dump(tensor_info, f, indent=2) - return True - return False - - -def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any: - if reload_dir is None: - return None - try: - intermediate_log_config = get_current_il_config() - assert intermediate_log_config is not None - replace_dir = str(intermediate_log_config.output_run_dir) - reload_path = save_path.replace(replace_dir, reload_dir) - logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path) - return torch.load(reload_path, map_location=tensor.device) - except Exception as e: - logger.warning("Failed to load tensor from %s: %s", reload_dir, e) - return tensor - - -def save_tensors( - tensor: Any, file_path: str, reload_input_dir: Optional[str] = None -) -> Any: +def save_tensors(tensor: Any, file_path: str) -> Any: """Utility function to dump tensor to a file. Args: @@ -314,52 +164,32 @@ def save_tensors( file_path: Base path where to save the tensor (without extension). """ - # Also save the actual tensor data for tensors if isinstance(tensor, torch.Tensor): - # Check if tensor is too large - if save_tensors_metadata_if_too_large(tensor, file_path): - return - # Get device name device_name = str(tensor.device) - # Skip if device filtering is enabled and this device should not be - # logged intermediate_log_config = get_current_il_config() if not should_log_device(intermediate_log_config, device_name): - logger.debug( - "Skipping tensor on device %s due to device filter", device_name - ) return tensor - # Append device name to file path pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt" try: - # Save tensor directly without detaching or moving to CPU torch.save(tensor, pt_path) - reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir) - if reloaded_tensor is not None: - return reloaded_tensor - logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path) + logger.debug("Saved tensor of shape %s to %s", tensor.shape, + pt_path) except Exception as e: logger.warning("Failed to save tensor to %s: %s", pt_path, e) return tensor if isinstance(tensor, (list, tuple)): - # For collections, also save each item individually - - reloaded_inputs = [] for i, item in enumerate(tensor): - reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir) - reloaded_inputs.append(reloaded) - return tuple(reloaded_inputs) if reloaded_inputs else tensor + save_tensors(item, f"{file_path}_{i}") + return tensor if isinstance(tensor, dict): - reloaded_inputs = {} - # For dictionaries, also save each value individually for k, v in tensor.items(): - reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir) - reloaded_inputs[k] = reloaded - return reloaded_inputs if reloaded_inputs else tensor + save_tensors(v, f"{file_path}_{k}") + return tensor -def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None: +def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], + outputs: Any) -> None: """Hook to increment the global step counter after a forward pass. Args: @@ -381,7 +211,8 @@ def _prepare_module_log_dir( is_pre_fwd: bool = False, ) -> Path: # Create a unique directory for this step if not - dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}" + dump_dir = Path( + intermediate_log_config.output_run_dir) / f"step_{get_step()}" dump_dir.mkdir(exist_ok=True, parents=True) # Create module directory @@ -393,7 +224,8 @@ def _prepare_module_log_dir( if is_pre_fwd: _log_module_call(intermediate_log_config, module_name + suffix) module_dir.mkdir(exist_ok=True, parents=True) - logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir) + logger.debug("Logging module %s inputs/outputs to %s", module_name, + module_dir) return module_dir @@ -401,13 +233,8 @@ def _log_module_call( intermediate_log_config: IntermediateLoggingConfig, module_name: str, ) -> None: - logger.debug("Logging module call for %s", module_name) - # write module name and call to step: - file = ( - Path(intermediate_log_config.output_run_dir) - / f"step_{get_step()}" - / "module_calls.txt" - ) + file = (Path(intermediate_log_config.output_run_dir) / + f"step_{get_step()}" / "module_calls.txt") with open(file, "a") as f: f.write(f"{module_name}\n") @@ -425,7 +252,8 @@ def get_current_step_module_call(module_name: str) -> int: return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0) -def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]: +def prepare_log_current_fwd(module, + is_pre_fwd: bool = False) -> Optional[Path]: intermediate_log_config = get_current_il_config() if intermediate_log_config is None or not intermediate_log_config.enabled: return None @@ -443,15 +271,14 @@ def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]: if is_pre_fwd: update_current_step_module_call(module_name) if should_log: - log_dir = _prepare_module_log_dir( - intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd - ) + log_dir = _prepare_module_log_dir(intermediate_log_config, + module_name, + is_pre_fwd=is_pre_fwd) return log_dir -def log_pre_fwd_hook( - module: torch.nn.Module, inputs: tuple[Any, ...] -) -> tuple[Any, ...]: +def log_pre_fwd_hook(module: torch.nn.Module, + inputs: tuple[Any, ...]) -> tuple[Any, ...]: """Hook to capture module inputs before forward pass. Args: @@ -462,27 +289,12 @@ def log_pre_fwd_hook( The unchanged inputs. """ if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True): - dump_intermediates_to_json(inputs, log_dir / "inputs.json") - intermediate_log_config = get_current_il_config() - if intermediate_log_config is not None: - reload_input_dir = getattr( - intermediate_log_config, - "reload_input_dir", - "/tmp/vllm_intermediates/57f4a3b2-9c4c-4afe-be71-0e95369d74b5", - ) - else: - reload_input_dir = None - reloaded_inputs = save_tensors( - inputs, str(log_dir / "inputs"), reload_input_dir - ) - if reloaded_inputs is not None: - return reloaded_inputs + save_tensors(inputs, str(log_dir / "inputs")) return inputs -def log_post_fwd_hook( - module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any -) -> None: +def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...], + outputs: Any) -> None: """Hook to capture module outputs after forward pass. Args: @@ -491,12 +303,11 @@ def log_post_fwd_hook( outputs: The outputs from the module's forward function. """ if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False): - dump_intermediates_to_json(outputs, log_dir / "outputs.json") save_tensors(outputs, str(log_dir / "outputs")) intermediate_log_config = get_current_il_config() - assert intermediate_log_config is not None, "IL config should not be None" + assert intermediate_log_config is not None, \ + "IL config should not be None" if intermediate_log_config.log_post_fwd_inputs: - dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json") save_tensors(inputs, str(log_dir / "post_fwd_inputs")) @@ -532,14 +343,14 @@ class IntermediatesLogger: def __init__(self, config: IntermediateLoggingConfig): self.config = config - self.hooks: list[ - tuple[str, str, Optional[RemovableHandle], Optional[RemovableHandle]] - ] = [] + self.hooks: list[tuple[str, str, Optional[RemovableHandle], + Optional[RemovableHandle]]] = [] logger.debug("Created IntermediatesLogger with config: %s", config) path = Path(config.output_run_dir) path.mkdir(exist_ok=True, parents=True) # Log configuration - logger.info("Intermediates will be logged in %s", config.output_run_dir) + logger.info("Intermediates will be logged in %s", + config.output_run_dir) def register_hooks(self, model: torch.nn.Module) -> None: """Register hooks for the model. @@ -551,13 +362,11 @@ class IntermediatesLogger: for name, module in model.named_modules(): if name and should_log_module(self.config, name, module): pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook) - logger.debug( - "Registered pre_fwd hook for %s", module.__class__.__name__ - ) + logger.debug("Registered pre_fwd hook for %s", + module.__class__.__name__) post_hook = module.register_forward_hook(log_post_fwd_hook) - logger.debug( - "Registered post_fwd hook for %s", module.__class__.__name__ - ) + logger.debug("Registered post_fwd hook for %s", + module.__class__.__name__) self.hooks.append((name, module, pre_hook, post_hook)) # Register a step counter hook for the root model @@ -578,7 +387,8 @@ class IntermediatesLogger: def register_intermediate_hooks( - model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs + model: torch.nn.Module, + config: Optional[IntermediateLoggingConfig] = None ) -> IntermediatesLogger: """Register hooks to log intermediate tensors for a model. @@ -590,10 +400,6 @@ def register_intermediate_hooks( Returns: An IntermediatesLogger instance that can be used to manage the hooks. """ - if config is None: - # Create config from kwargs - config = IntermediateLoggingConfig.from_dict(kwargs) - logger_instance = IntermediatesLogger(config) logger_instance.register_hooks(model) return logger_instance diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a0d5cd2e40..ec4261aa5a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -27,12 +27,12 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.intermediates.intermediates_logging import intermediate_logging from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase -from vllm.v1.intermediates.intermediates_logging import intermediate_logging logger = init_logger(__name__) diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index bbe601afa9..3c17a51be1 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -6,10 +6,11 @@ from typing import Optional import torch import torch.nn as nn -from vllm.config import VllmConfig, IntermediateLoggingConfig +from vllm.config import IntermediateLoggingConfig, VllmConfig from vllm.logger import init_logger +from vllm.v1.intermediates.intermediates_logging import ( + register_intermediate_hooks) from vllm.v1.kv_cache_interface import KVCacheSpec -from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 logger = init_logger(__name__) @@ -64,27 +65,26 @@ class WorkerBase(WorkerBaseV0): def check_health(self) -> None: """Basic health check (override for device-specific checks).""" return - - def register_intermediate_hooks(self, - config: Optional[IntermediateLoggingConfig] = None, - **kwargs) -> None: + + def register_intermediate_hooks( + self, config: Optional[IntermediateLoggingConfig] = None) -> None: """Register hooks for intermediate tensor logging. This method is called via collective_rpc from the engine core. - It registers hooks on the model to dump intermediate tensors during execution. + It registers hooks on the model to dump intermediate tensors during + execution. Args: - config: Configuration for intermediate logging. If provided, this takes precedence over kwargs. + config: Configuration for intermediate logging. If provided, this + takes precedence over kwargs. """ - if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None: - logger.error("Could not register intermediate hooks: model_runner.model is not accessible") + if self.model_runner is None or not hasattr( + self.model_runner, "model") or self.model_runner.model is None: + logger.error("Could not register intermediate hooks: " + "model_runner.model is not accessible") return model = self.model_runner.model try: - # Register hooks - register_intermediate_hooks(model, config, **kwargs) - # Store the logger instance for potential later hook removal - except Exception as e: - logger.info("Successfully registered intermediate hooks") - logger.error("Error registering intermediate hooks", exc_info=True) - + register_intermediate_hooks(model, config) + except Exception: + logger.exception("Error registering intermediate hooks") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 065bbd26c0..d7f50f713e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -128,21 +128,21 @@ class WorkerBase: def vocab_size(self) -> int: """Get vocabulary size from model configuration.""" return self.model_config.get_vocab_size() - - def register_intermediate_hooks(self, config=None, **kwargs) -> None: + + def register_intermediate_hooks(self, config=None) -> None: """Register hooks for intermediate tensor logging. - This method is a stub for v0 workers. The actual implementation is in v1 workers. - It's included here for compatibility with the collective_rpc mechanism. + This method is a stub for v0 workers. The actual implementation is + in v1 workers. It's included here for compatibility with the + collective_rpc mechanism. Args: config: Configuration for intermediate logging. - **kwargs: Configuration parameters for intermediate logging. - These are ignored in v0 workers. """ logger.warning( "register_intermediate_hooks is not implemented in v0 workers. " - "This is only available in v1 workers. No hooks will be registered.") + "This is only available in v1 workers. No hooks will be registered." + ) return None