Compare commits

..

7 Commits

Author SHA1 Message Date
e793b9a70c Merge remote-tracking branch 'origin/main' into il_tool
Signed-off-by: Lu Fang <fanglu@fb.com>
2025-09-08 17:33:55 -07:00
76c9ec0ddf adjust config type and remove config path for simplicity
Signed-off-by: Lu Fang <fanglu@fb.com>
2025-09-08 17:23:15 -07:00
87c737016d Merge remote-tracking branch 'origin/main' into il_tool
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:48:28 -07:00
ba90794ff1 remove feature for il_tool_compare
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:47:16 -07:00
ab4ab0fd28 address arg utils fix
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:45:13 -07:00
2af83ebdde remove feature for metadata dump and input reload
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:25:17 -07:00
d8bff253d7 add il tool
more changes

Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

fix tp

Signed-off-by: Lu Fang <fanglu@fb.com>

add comparison tool

tmp

add unit test and fix format

Signed-off-by: Lu Fang <fanglu@fb.com>

add comparison script and documentation

Signed-off-by: Lu Fang <fanglu@fb.com>

provide default intermediate logging

Signed-off-by: Lu Fang <fanglu@fb.com>

optional register il

Signed-off-by: Lu Fang <fanglu@fb.com>

add input reload and improve intermediate compare
2025-07-28 18:32:10 -07:00
25 changed files with 1054 additions and 448 deletions

View File

@ -149,25 +149,3 @@ steps:
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"
- label: "Build and publish nightly multi-arch image to DockerHub"
depends_on:
- create-multi-arch-manifest
if: build.env("NIGHTLY") == "1"
agents:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly"
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
- "docker push vllm/vllm-openai:nightly"
- "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
# Clean up old nightly builds (keep only last 14)
- "bash .buildkite/scripts/cleanup-nightly-builds.sh"
plugins:
- docker-login#v3.0.0:
username: vllmbot
password-env: DOCKERHUB_TOKEN
env:
DOCKER_BUILDKIT: "1"

View File

@ -1,97 +0,0 @@
#!/bin/bash
set -ex
# Clean up old nightly builds from DockerHub, keeping only the last 14 builds
# This script uses DockerHub API to list and delete old tags with "nightly-" prefix
# DockerHub API endpoint for vllm/vllm-openai repository
REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"
# Get DockerHub token from environment
if [ -z "$DOCKERHUB_TOKEN" ]; then
echo "Error: DOCKERHUB_TOKEN environment variable is not set"
exit 1
fi
# Function to get all tags from DockerHub
get_all_tags() {
local page=1
local all_tags=""
while true; do
local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \
"$REPO_API_URL?page=$page&page_size=100")
# Get both last_updated timestamp and tag name, separated by |
local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')
if [ -z "$tags" ]; then
break
fi
all_tags="$all_tags$tags"$'\n'
page=$((page + 1))
done
# Sort by timestamp (newest first) and extract just the tag names
echo "$all_tags" | sort -r | cut -d'|' -f2
}
delete_tag() {
local tag_name="$1"
echo "Deleting tag: $tag_name"
local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url")
if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
else
echo "Successfully deleted tag: $tag_name"
fi
}
# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first)
echo "Fetching all tags from DockerHub..."
all_tags=$(get_all_tags)
if [ -z "$all_tags" ]; then
echo "No tags found to clean up"
exit 0
fi
# Count total tags
total_tags=$(echo "$all_tags" | wc -l)
echo "Found $total_tags tags"
# Keep only the last 14 builds (including the current one)
tags_to_keep=14
tags_to_delete=$((total_tags - tags_to_keep))
if [ $tags_to_delete -le 0 ]; then
echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)"
exit 0
fi
echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep"
# Get tags to delete (skip the first $tags_to_keep tags)
tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1)))
if [ -z "$tags_to_delete_list" ]; then
echo "No tags to delete"
exit 0
fi
# Delete old tags
echo "Deleting old tags..."
while IFS= read -r tag; do
if [ -n "$tag" ]; then
delete_tag "$tag"
# Add a small delay to avoid rate limiting
sleep 1
fi
done <<< "$tags_to_delete_list"
echo "Cleanup completed successfully"

14
.github/mergify.yml vendored
View File

@ -273,20 +273,6 @@ pull_request_rules:
users:
- "sangstar"
- name: assign reviewer for modelopt changes
conditions:
- or:
- files~=^vllm/model_executor/layers/quantization/modelopt\.py$
- files~=^vllm/model_executor/layers/quantization/__init__\.py$
- files~=^tests/models/quantization/test_modelopt\.py$
- files~=^tests/quantization/test_modelopt\.py$
- files~=^tests/models/quantization/test_nvfp4\.py$
- files~=^docs/features/quantization/modelopt\.md$
actions:
assign:
users:
- "Edwardf0t1"
- name: remove 'needs-rebase' label when conflict is resolved
conditions:
- -conflict

View File

@ -16,7 +16,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: '3.12'

View File

@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: "3.12"
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"

View File

@ -0,0 +1,65 @@
# Intermediate Tensor Logging
This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.
## Overview
The intermediate tensor logging feature enables you to:
- Log input and output tensors selected by a configurable set of filters
- Filter modules by name using regex patterns
- Filter by module forward-call index (e.g. dump only the 2nd forward call of the same module)
- Filter tensors by device
- Filter by whole-model forward step ID
## Usage
### Enabling via parameters or config file
**Offline Inference example**
Dump all modules on all devices for step 0 (default behavior):
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
```
Dump only the modules of the first layer, on all devices, for step 0:
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```
#### Configuration Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `output_dir` | string | Directory in which to save the intermediate tensors | `/tmp/vllm_intermediates` |
| `module_call_match` | array | Regex patterns to filter module names; to limit logging to the i-th call only, append `:i` | `null` (log all modules) |
| `log_step_ids` | array | List of step IDs to log | `[0]` |
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
| `device_names` | array | List of device names to log | `[]` (log all devices) |
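The configuration is passed to the flag as a single JSON object. A minimal sketch of assembling a fuller configuration in Python and serializing it for `--intermediate-log-config` (all values are illustrative):
```python
import json

il_config = {
    "enabled": True,
    "output_dir": "/tmp/vllm_intermediates",
    # Restrict to layer-0 modules; append ":i" to limit to the i-th call.
    "module_call_match": ["layers\\.0\\."],
    "log_step_ids": [0, 1],        # dump the first two forward steps
    "max_tensor_size": 1_000_000,  # upper bound on elements per logged tensor
    "device_names": [],            # empty list = log all devices
}
print(json.dumps(il_config))  # pass this string to --intermediate-log-config
```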
### Output Directory Structure
When you enable intermediate logging, the system creates a uniquely named run directory (a UUID) under your specified `output_dir`. This keeps multiple logging sessions separate:
```
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
└── step_0
├── model.embed_tokens
│ ├── inputs_0_cuda_0.pt
│ ├── inputs.json
│ ├── outputs_cuda_0.pt
│ └── outputs.json
├── model.layers.0.input_layernorm
│ ├── inputs_0_cuda_0.pt
│ ├── inputs.json
│ ├── outputs_cuda_0.pt
│ └── outputs.json
└── step_1
└── ...
```
Each tensor is saved to a `.pt` file containing the full PyTorch tensor, which can be loaded with `torch.load()`.
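For quick inspection, a saved tensor can be reloaded directly with PyTorch. A minimal sketch, assuming the example layout above (the run UUID, step, and module directories will differ per session):
```python
import torch

# Illustrative path; substitute your run's UUID, step, and module directory.
path = ("/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/"
        "step_0/model.embed_tokens/outputs_cuda_0.pt")
tensor = torch.load(path, map_location="cpu")
print(tensor.shape, tensor.dtype)
```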

View File

@ -0,0 +1,320 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the intermediate tensor logging functionality.
"""
import json
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
import pytest
import torch
import torch.nn as nn
from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
get_current_il_config, get_step, increment_step, intermediate_logging,
register_intermediate_hooks, reset_step, should_log_device,
should_log_module, should_log_step)
class SimpleModel(nn.Module):
"""A simple model for testing."""
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
@pytest.fixture
def temp_output_dir():
"""Create a temporary directory for test outputs."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up after the test
shutil.rmtree(temp_dir)
@pytest.fixture
def simple_model():
"""Create a simple model for testing."""
return SimpleModel()
@pytest.fixture
def il_config(temp_output_dir):
"""Create a basic IntermediateLoggingConfig for testing."""
return IntermediateLoggingConfig(output_dir=temp_output_dir,
enabled=True,
log_step_ids=[0, 1],
module_call_match=[".*linear.*"])
def test_step_counter():
"""Test the step counter functionality."""
# Reset the step counter
reset_step()
assert get_step() == 0
# Increment the step counter
increment_step()
assert get_step() == 1
# Increment again
increment_step()
assert get_step() == 2
# Reset again
reset_step()
assert get_step() == 0
def test_intermediate_logging_context_manager():
"""Test the intermediate_logging context manager."""
# Create a config
config = IntermediateLoggingConfig(enabled=True)
# Initially, there should be no global config
assert get_current_il_config() is None
# Use the context manager
with intermediate_logging(config):
# Inside the context, the global config should be set
assert get_current_il_config() is not None
assert get_current_il_config().enabled is True
# After the context, the global config should be None again
assert get_current_il_config() is None
# Test with a different config
config2 = IntermediateLoggingConfig(enabled=False)
with intermediate_logging(config2):
assert get_current_il_config() is not None
assert get_current_il_config().enabled is False
def test_should_log_step():
"""Test the should_log_step function."""
# Reset step counter
reset_step()
# Create configs with different step IDs
config_all_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[] # Empty list means log all steps
)
config_specific_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4
)
config_disabled = IntermediateLoggingConfig(enabled=False,
log_step_ids=[0, 1, 2])
# Test with all steps config
with intermediate_logging(config_all_steps):
assert should_log_step(config_all_steps) is True # Step 0
increment_step()
assert should_log_step(config_all_steps) is True # Step 1
# Reset step counter
reset_step()
# Test with specific steps config
with intermediate_logging(config_specific_steps):
assert should_log_step(config_specific_steps) is True # Step 0
increment_step()
assert should_log_step(config_specific_steps) is False # Step 1
increment_step()
assert should_log_step(config_specific_steps) is True # Step 2
increment_step()
assert should_log_step(config_specific_steps) is False # Step 3
increment_step()
assert should_log_step(config_specific_steps) is True # Step 4
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_step(config_disabled) is False # Disabled
def test_should_log_device():
"""Test the should_log_device function."""
# Create configs with different device filters
config_all_devices = IntermediateLoggingConfig(
enabled=True,
device_names=[] # Empty list means log all devices
)
config_specific_devices = IntermediateLoggingConfig(
enabled=True,
device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu
)
config_disabled = IntermediateLoggingConfig(enabled=False,
device_names=["cuda:0", "cpu"])
# Test with all devices config
with intermediate_logging(config_all_devices):
assert should_log_device(config_all_devices, "cuda:0") is True
assert should_log_device(config_all_devices, "cuda:1") is True
assert should_log_device(config_all_devices, "cpu") is True
# Test with specific devices config
with intermediate_logging(config_specific_devices):
assert should_log_device(config_specific_devices, "cuda:0") is True
assert should_log_device(config_specific_devices, "cuda:1") is False
assert should_log_device(config_specific_devices, "cpu") is True
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_device(config_disabled, "cuda:0") is False
assert should_log_device(config_disabled, "cpu") is False
def test_should_log_module(simple_model):
"""Test the should_log_module function."""
# Create configs with different module name filters
config_all_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=None # None means log all modules
)
config_specific_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=[".*linear.*"
] # Only log modules with "linear" in the name
)
config_disabled = IntermediateLoggingConfig(enabled=False,
module_call_match=[".*"])
# Test with all modules config
with intermediate_logging(config_all_modules):
assert should_log_module(config_all_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_all_modules, "relu",
simple_model.relu) is True
# Test with specific modules config
with intermediate_logging(config_specific_modules):
assert should_log_module(config_specific_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_specific_modules, "relu",
simple_model.relu) is False
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_module(config_disabled, "linear1",
simple_model.linear1) is False
assert should_log_module(config_disabled, "relu",
simple_model.relu) is False
def test_register_hooks(simple_model, il_config):
"""Test registering hooks on a model."""
# Register hooks
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Check that hooks were registered
assert len(logger_instance.hooks) > 0
# Remove hooks
logger_instance.remove_hooks()
# Check that hooks were removed
assert len(logger_instance.hooks) == 0
@mock.patch(
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir):
"""Test that forward hooks are called during model execution."""
mock_save_tensors.return_value = None
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that the step counter was incremented
assert get_step() == 1
# Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called
assert mock_save_tensors.called
# Remove hooks
logger_instance.remove_hooks()
def test_end_to_end(simple_model, il_config, temp_output_dir):
"""Test the entire intermediate logging workflow end-to-end."""
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that output directories were created
root_dir = Path(il_config._output_run_dir)
assert root_dir.exists()
step_dir = root_dir / "step_0"
assert step_dir.exists()
module_dirs = list(step_dir.glob("*"))
print(f"{module_dirs=}")
assert len(module_dirs) > 0
# Check that input and output files were created
for module_dir in module_dirs:
print(f"{module_dir=}")
if os.path.isdir(module_dir):
inputs_json = module_dir / "inputs.json"
outputs_json = module_dir / "outputs.json"
# Check that JSON files exist
assert inputs_json.exists()
assert outputs_json.exists()
# Check that JSON files contain valid data
with open(inputs_json) as f:
inputs_data = json.load(f)
assert "type" in inputs_data
with open(outputs_json) as f:
outputs_data = json.load(f)
assert "type" in outputs_data
# Check that tensor files exist
tensor_files = list(module_dir.glob("*.pt"))
assert len(tensor_files) > 0
# Remove hooks
logger_instance.remove_hooks()
if __name__ == "__main__":
pytest.main(["-xvs", __file__])
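The end-to-end test above exercises the public hook API on a toy model; the same flow can be applied directly to any `nn.Module`. A minimal sketch reusing that pattern (the output directory is illustrative, and the API is the one added in this change):
```python
import torch
import torch.nn as nn

from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
    intermediate_logging, register_intermediate_hooks, reset_step)

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
config = IntermediateLoggingConfig(
    output_dir="/tmp/vllm_intermediates_demo",  # illustrative location
    enabled=True,
    log_step_ids=[0],  # only dump the first forward step
)

with intermediate_logging(config):
    hooks = register_intermediate_hooks(model, config)
    reset_step()
    model(torch.randn(2, 10))  # tensors land under <output_dir>/<uuid>/step_0/
    hooks.remove_hooks()
```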

View File

@ -3311,6 +3311,119 @@ class KVTransferConfig:
return self.kv_connector_extra_config.get(key, default)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig:
"""Configuration for intermediate tensor logging."""
output_dir: str = "/tmp/vllm_intermediates"
"""Directory where to save the intermediate tensors."""
module_call_match: Optional[list[str]] = None
"""Match modules by name regex and call index (
a module can be called multiple times in a step)
List of regex:call_idx, call_idx is -1 for default for all calls """
log_step_ids: list[int] = field(default_factory=lambda: [0])
"""List of step IDs to log (empty list means log all steps)."""
log_post_fwd_inputs: bool = False
"""Whether logging inputs after forwards for each module"""
max_tensor_size: Optional[int] = None
"""Maximum number of elements in tensors to log (None = no limit)."""
enabled: bool = True
"""Whether logging is enabled."""
device_names: list[str] = field(default_factory=list)
"""List of device names to log (empty list means log all devices)."""
_compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
init=False)
"""Compiled regex patterns for module filtering."""
_module_call: dict[str, int] = field(default_factory=dict, init=False)
_step_id_set: set[int] = field(default_factory=set, init=False)
"""Set of step IDs for faster lookup."""
_output_run_dir: str = "/tmp/vllm_intermediates"
"""Unique directory to save single run/serve logging result."""
def __post_init__(self):
"""Initialize derived fields after instance creation."""
self._compile_regex_patterns()
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
self._step_id_set = set(self.log_step_ids)
def _compile_regex_patterns(self):
"""Compile regex patterns for module name filtering."""
from vllm.logger import init_logger
logger = init_logger(__name__)
self._compiled_module_calls = {}
if self.module_call_match is None:
logger.info(
"No module name regex patterns provided, will log all modules")
return
# Compile all patterns
for regex_pattern_call_idx in self.module_call_match:
try:
splits = regex_pattern_call_idx.split(":", 2)
regex_pattern = splits[0]
call_idx = -1
if len(splits) > 1:
call_idx = int(splits[1])
compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
self._compiled_module_calls[compiled_pattern] = call_idx
logger.info("Successfully compiled regex pattern: '%s'",
regex_pattern)
except Exception as e:
logger.error("Failed to parse module_call_match '%s': %s",
regex_pattern_call_idx, e)
logger.info("Compiled %d regex patterns",
len(self._compiled_module_calls))
def to_dict(self) -> dict:
"""Convert the config to a dictionary for serialization."""
return {
"output_run_dir": self.output_run_dir,
"module_call_match": self.module_call_match,
"log_step_ids": self.log_step_ids,
"max_tensor_size": self.max_tensor_size,
"enabled": self.enabled,
"device_names": self.device_names
}
@classmethod
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
@property
def output_run_dir(self) -> str:
return self._output_run_dir
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# Intermediate logging doesn't affect the computation graph
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@ -3362,6 +3475,8 @@ class VllmConfig:
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
"""The configurations for event publishing."""
intermediate_log_config: Optional[IntermediateLoggingConfig] = None
"""Configuration for intermediate tensor logging."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@ -3446,6 +3561,10 @@ class VllmConfig:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.intermediate_log_config:
vllm_factors.append(self.intermediate_log_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
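A brief sketch of how `IntermediateLoggingConfig` above can be used programmatically, based only on the fields and helpers shown in this diff (values are illustrative):
```python
from vllm.config import IntermediateLoggingConfig

cfg = IntermediateLoggingConfig.from_dict({
    "enabled": True,
    "module_call_match": ["layers\\.0\\."],  # regex[:call_idx] entries
    "log_step_ids": [0],
})
# __post_init__ compiles the regex patterns and derives a unique run directory.
print(cfg.output_run_dir)  # e.g. /tmp/vllm_intermediates/<uuid>
print(cfg.to_dict())       # serializes the user-facing fields plus the run dir
```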

View File

@ -409,6 +409,7 @@ class EngineArgs:
speculative_config: Optional[Dict[str, Any]] = None
show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
otlp_traces_endpoint: Optional[str] = \
@ -456,6 +457,8 @@ class EngineArgs:
async_scheduling: bool = SchedulerConfig.async_scheduling
intermediate_log_config: Optional[dict[str, Any]] = None
kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill
@ -883,6 +886,9 @@ class EngineArgs:
title="VllmConfig",
description=VllmConfig.__doc__,
)
vllm_group.add_argument("--intermediate-log-config",
**vllm_kwargs["intermediate_log_config"])
# We construct SpeculativeConfig using fields from other configs in
# create_engine_config. So we set the type to a JSON string here to
# delay the Pydantic validation that comes with SpeculativeConfig.
@ -1394,7 +1400,6 @@ class EngineArgs:
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
)
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@ -1409,6 +1414,7 @@ class EngineArgs:
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
intermediate_log_config=self.intermediate_log_config,
additional_config=self.additional_config,
)
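A hedged sketch of the programmatic path this wiring enables: setting `intermediate_log_config` on `EngineArgs` (rather than via the CLI flag) and letting it flow into the resulting `VllmConfig`. The model name and config values are placeholders, and the dict-to-config conversion is not shown in this diff:
```python
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enforce_eager=True,
    intermediate_log_config={"enabled": True,
                             "module_call_match": ["layers\\.0\\."]},
)
vllm_config = args.create_engine_config()
assert vllm_config.intermediate_log_config is not None
```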

View File

@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json
import logging
from abc import ABC, abstractmethod
@ -59,14 +57,9 @@ class ConversationContext(ABC):
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack) -> None:
pass
@abstractmethod
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class SimpleContext(ConversationContext):
@ -96,13 +89,9 @@ class SimpleContext(ConversationContext):
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack) -> None:
pass
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
@ -114,7 +103,6 @@ class HarmonyContext(ConversationContext):
self._messages = messages
self.available_tools = available_tools
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
self.called_tools: set[str] = set()
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
@ -246,8 +234,7 @@ class HarmonyContext(ConversationContext):
last_msg = self.messages[-1]
recipient = last_msg.recipient
return recipient is not None and (recipient.startswith("browser.")
or recipient.startswith("python") or
recipient.startswith("container."))
or recipient.startswith("python"))
async def call_tool(self) -> list[Message]:
if not self.messages:
@ -261,9 +248,6 @@ class HarmonyContext(ConversationContext):
elif recipient.startswith("python"):
return await self.call_python_tool(
self._tool_sessions["python"], last_msg)
elif recipient.startswith("container."):
return await self.call_container_tool(
self._tool_sessions["container"], last_msg)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
@ -272,7 +256,6 @@ class HarmonyContext(ConversationContext):
async def call_search_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
self.called_tools.add("browser")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
@ -282,16 +265,12 @@ class HarmonyContext(ConversationContext):
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel)
Message(author=author, content=[content], recipient=Role.ASSISTANT)
]
async def call_python_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
@ -311,63 +290,13 @@ class HarmonyContext(ConversationContext):
]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack) -> None:
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id))
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def call_container_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
"""
Call container tool. Expect this to be run in a stateful docker
with command line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel)
]
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info("Cleaning up tool session for %s",
tool_session._client_info)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools))
self._tool_sessions[
tool_name] = await exit_stack.enter_async_context(
tool_server.new_session(tool_name))
class StreamingHarmonyContext(HarmonyContext):

View File

@ -16,13 +16,11 @@ from openai.types.responses.response_function_web_search import (
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent)
from openai.types.responses.tool import Tool
from openai_harmony import (Author, ChannelConfig, Conversation,
DeveloperContent, HarmonyEncodingName, Message,
ReasoningEffort, Role, StreamableParser,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)
from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, Message, ReasoningEffort,
Role, StreamableParser, SystemContent, TextContent,
ToolDescription, load_harmony_encoding)
from vllm import envs
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
ResponseInputOutputItem)
from vllm.utils import random_uuid
@ -35,20 +33,6 @@ REASONING_EFFORT = {
_harmony_encoding = None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
BUILTIN_TOOLS = {
"web_search_preview",
"code_interpreter",
"container",
}
def has_custom_tools(tool_types: list[str]) -> bool:
return not set(tool_types).issubset(BUILTIN_TOOLS)
def get_encoding():
global _harmony_encoding
@ -64,19 +48,10 @@ def get_system_message(
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
container_description: Optional[str] = None,
instructions: Optional[str] = None,
with_custom_tools: bool = False,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if (instructions is not None
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
current_identity = sys_msg_content.model_identity
new_identity = (f'{current_identity}\n{instructions}'
if current_identity else instructions)
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort])
@ -88,14 +63,6 @@ def get_system_message(
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
if container_description is not None:
sys_msg_content = sys_msg_content.with_tools(container_description)
if not with_custom_tools:
channel_config = sys_msg_content.channel_config
invalid_channel = "commentary"
new_config = ChannelConfig.require_channels(
[c for c in channel_config.valid_channels if c != invalid_channel])
sys_msg_content = sys_msg_content.with_channel_config(new_config)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
@ -119,17 +86,14 @@ def get_developer_message(
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if (instructions is not None
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
if instructions is not None:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter",
"container"):
if tool.type in ("web_search_preview", "code_interpreter"):
# These are built-in tools that are added to the system message.
pass
elif tool.type == "function":
function_tools.append(tool)
else:
@ -172,8 +136,6 @@ def parse_response_input(
TextContent(text=text_prefix + c["text"]) for c in content
]
msg = Message.from_role_and_contents(role, contents)
if role == "assistant":
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None

View File

@ -44,9 +44,8 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext,
SimpleContext, StreamingHarmonyContext)
from vllm.entrypoints.harmony_utils import (
get_developer_message, get_stop_tokens_for_assistant_actions,
get_system_message, get_user_message, has_custom_tools,
parse_output_message, parse_remaining_state, parse_response_input,
render_for_completion)
get_system_message, get_user_message, parse_output_message,
parse_remaining_state, parse_response_input, render_for_completion)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@ -267,8 +266,6 @@ class OpenAIServingResponses(OpenAIServing):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
builtin_tool_list.append("python")
if self.tool_server.has_tool("container"):
builtin_tool_list.append("container")
if self.tool_server is not None:
available_tools = builtin_tool_list
@ -451,8 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack:
try:
await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
await context.init_tool_sessions(self.tool_server, exit_stack)
async for _ in result_generator:
pass
except asyncio.CancelledError:
@ -714,21 +710,13 @@ class OpenAIServingResponses(OpenAIServing):
# New conversation.
reasoning_effort = (request.reasoning.effort
if request.reasoning else None)
# Temporary: OpenAI types doesn't have container tool
# so we used MCP to cover that, up for change
tool_types = [tool.type for tool in request.tools]
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
tool_types.append("container")
enable_browser = ("web_search_preview" in tool_types
and self.tool_server is not None
and self.tool_server.has_tool("browser"))
enable_code_interpreter = ("code_interpreter" in tool_types
and self.tool_server is not None
and self.tool_server.has_tool("python"))
enable_container = ("container" in tool_types
and self.tool_server is not None
and self.tool_server.has_tool("container"))
with_custom_tools = has_custom_tools(tool_types)
sys_msg = get_system_message(
reasoning_effort=reasoning_effort,
browser_description=self.tool_server.get_tool_description(
@ -737,17 +725,11 @@ class OpenAIServingResponses(OpenAIServing):
python_description=self.tool_server.get_tool_description(
"python") if enable_code_interpreter
and self.tool_server is not None else None,
container_description=self.tool_server.get_tool_description(
"container")
if enable_container and self.tool_server is not None else None,
instructions=request.instructions,
with_custom_tools=with_custom_tools,
)
messages.append(sys_msg)
if with_custom_tools:
dev_msg = get_developer_message(
instructions=request.instructions, tools=request.tools)
messages.append(dev_msg)
dev_msg = get_developer_message(request.instructions,
request.tools)
messages.append(dev_msg)
else:
# Continue the previous conversation.
# FIXME(woosuk): Currently, request params like reasoning and
@ -1631,8 +1613,7 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack:
processer = None
if self.use_harmony:
await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
await context.init_tool_sessions(self.tool_server, exit_stack)
processer = self._process_harmony_streaming_events
else:
processer = self._process_simple_streaming_events

View File

@ -86,8 +86,7 @@ class ToolServer(ABC):
pass
@abstractmethod
def new_session(self, tool_name: str,
session_id: str) -> AbstractAsyncContextManager[Any]:
def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
"""
Create a session for the tool.
"""
@ -125,8 +124,7 @@ class MCPToolServer(ToolServer):
description=tool.description,
parameters=tool.inputSchema)
for tool in list_tools_response.tools
],
)
])
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
if tool_from_mcp.name not in self.urls:
self.urls[tool_from_mcp.name] = url
@ -144,16 +142,14 @@ class MCPToolServer(ToolServer):
return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager
async def new_session(self, tool_name: str, session_id: str):
async def new_session(self, tool_name: str):
from mcp import ClientSession
from mcp.client.sse import sse_client
url = self.urls.get(tool_name)
headers = {"x-session-id": session_id}
if not url:
raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url,
headers=headers) as streams, ClientSession(
*streams) as session:
async with sse_client(url=url) as streams, ClientSession(
*streams) as session:
await session.initialize()
yield session
@ -186,7 +182,7 @@ class DemoToolServer(ToolServer):
raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager
async def new_session(self, tool_name: str, session_id: str):
async def new_session(self, tool_name: str):
if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name]

View File

@ -168,8 +168,6 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
@ -1203,15 +1201,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Allows vllm use container tool
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
# Allows harmony instructions to be injected on system messages
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
lambda: bool(
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))),
# Add optional custom scopes for profiling, disable to avoid overheads
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),

View File

@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
if current_platform.is_cuda_alike():
@ -786,7 +786,6 @@ class FusedMoE(CustomOp):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
):
super().__init__()
if params_dtype is None:
@ -798,10 +797,6 @@ class FusedMoE(CustomOp):
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.is_sequence_parallel = is_sequence_parallel
if self.is_sequence_parallel:
self.sp_size = tp_size_
vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
@ -1704,22 +1699,14 @@ class FusedMoE(CustomOp):
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
self.sp_size)
num_tokens = full_hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers,
moe_dp_chunk_size_per_rank)):
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dispatchers)
max_tokens_across_dp)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)

View File

@ -37,6 +37,8 @@ class DeepseekV2Model(nn.Module):
super().__init__()
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = self.config.vocab_size
@ -49,8 +51,11 @@ class DeepseekV2Model(nn.Module):
self.layers = nn.ModuleList([
DeepseekV2DecoderLayer(
vllm_config,
self.config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
) for i in range(self.config.num_hidden_layers)
])

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -43,19 +43,23 @@ class SharedHead(nn.Module):
class DeepSeekMultiTokenPredictorLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config)
def forward(
self,
@ -91,8 +95,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
DeepSeekMultiTokenPredictorLayer(vllm_config,
f"{prefix}.layers.{idx}")
DeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})

View File

@ -32,14 +32,12 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -57,9 +55,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
@ -76,27 +72,19 @@ class DeepseekV2MLP(nn.Module):
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
is_sequence_parallel=False,
prefix: str = "",
) -> None:
super().__init__()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@ -110,58 +98,17 @@ class DeepseekV2MLP(nn.Module):
return x
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
x = nn.functional.pad(x, (0, 0, 0, pad_len))
chunk = x.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(x, 0, start, chunk)
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk",
op_func=sequence_parallel_chunk,
mutates_args=[],
fake_impl=sequence_parallel_chunk_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
class DeepseekV2MoE(nn.Module):
def __init__(
self,
config: Union[DeepseekV2Config, DeepseekV3Config],
parallel_config: ParallelConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group
@ -170,21 +117,6 @@ class DeepseekV2MoE(nn.Module):
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
in ("deepep_high_throughput",
"deepep_low_latency")
and parallel_config.enable_expert_parallel
and self.tp_size > 1)
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
@ -201,8 +133,9 @@ class DeepseekV2MoE(nn.Module):
self.gate.e_score_correction_bias = None
# Load balancing settings.
eplb_config = parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
@ -233,9 +166,7 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
num_redundant_experts=self.n_redundant_experts)
self.shared_experts = None
else:
intermediate_size = (config.moe_intermediate_size *
@ -246,7 +177,6 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
is_sequence_parallel=self.is_sequence_parallel,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
@ -269,22 +199,11 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
num_redundant_experts=self.n_redundant_experts)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# Chunk the hidden states so they aren't replicated across TP ranks.
# This avoids duplicate computation in self.experts.
# TODO: We can replace the all_reduce at the end of attn with a
# reduce_scatter instead of chunking here.
if self.is_sequence_parallel:
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
@ -309,11 +228,7 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
if self.tp_size > 1:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
@ -617,15 +532,16 @@ class DeepseekV2MLAAttention(nn.Module):
class DeepseekV2DecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
def __init__(
self,
config: Union[DeepseekV2Config, DeepseekV3Config],
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@ -662,9 +578,9 @@ class DeepseekV2DecoderLayer(nn.Module):
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
@ -734,7 +650,10 @@ class DeepseekV2Model(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.config = config
self.vocab_size = config.vocab_size
@ -750,7 +669,14 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
lambda prefix: DeepseekV2DecoderLayer(
config,
prefix,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
enable_eplb=enable_eplb,
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:

View File

@ -80,6 +80,10 @@ class EngineCore:
# Setup Model.
self.model_executor = executor_class(vllm_config)
if vllm_config.intermediate_log_config is not None:
self.collective_rpc("register_intermediate_hooks",
args=(vllm_config.intermediate_log_config, ))
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(
executor_fail_callback)
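The `collective_rpc` above asks every worker to attach the logging hooks to its loaded model. A hypothetical sketch of what such a worker-side handler might look like, using the `register_intermediate_hooks` helper added in this change (the real worker wiring may differ):
```python
from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
    register_intermediate_hooks)


class WorkerSketch:  # hypothetical stand-in for the vLLM worker class

    def __init__(self, model):
        self.model = model  # the worker's loaded torch.nn.Module

    def register_intermediate_hooks(self,
                                    config: IntermediateLoggingConfig) -> None:
        # Attach forward hooks so module inputs/outputs are dumped per config.
        self._il_logger = register_intermediate_hooks(self.model, config)
```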

View File

@ -641,13 +641,7 @@ class WorkerProc:
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
"""Main busy loop for Multiprocessing Workers"""
import os, psutil
p = psutil.Process(os.getpid())
i = 0
while True:
if i % 100 == 0:
logger.info("WorkerProc RSS MB: %d", p.memory_info().rss/1024/1024)
i += 1
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
cancel=cancel)
try:

View File

@ -0,0 +1,405 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.
This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
import torch
from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global step counter
_CURRENT_STEP = 0
_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}
IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"
# Utility functions for intermediate logging
def should_log_step(config):
"""Check if the current step should be logged based on the step IDs.
Args:
config: The IntermediateLoggingConfig instance.
Returns:
True if the current step should be logged, False otherwise.
"""
if not is_log_enabled(config):
return False
# If log_step_ids is empty, log all steps
if not config.log_step_ids:
return True
# Otherwise, check if current step is in the set of step IDs to log
return get_step() in config._step_id_set
def should_log_device(config, device_name):
"""Check if a device should be logged based on the device names.
Args:
config: The IntermediateLoggingConfig instance.
device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').
Returns:
True if the device should be logged, False otherwise.
If device_names is empty, all devices are logged.
"""
if not is_log_enabled(config):
return False
# If device_names is empty, log all devices
if not config.device_names:
return True
# Otherwise, check if device_name is in the list of device names to log
return device_name in config.device_names
def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
"""Check if a module should be logged based on the name regex patterns.
Args:
config: The IntermediateLoggingConfig instance.
module_name: The name of the module to check.
Returns:
True if the module should be logged, False otherwise.
If no patterns are defined, all modules are logged.
If patterns are defined, the module is logged if it matches ANY pattern.
"""
if not is_log_enabled(config):
return False
# If no patterns are defined, log all modules
if not config._compiled_module_calls:
set_il_module_name(module, module_name)
set_il_module_call_idx(module, -1)
return True
# Check if the module name matches any of the patterns
for pattern, call_idx in config._compiled_module_calls.items():
match = pattern.search(module_name)
if match:
logger.debug(
"Module %s, %s matches pattern: '%s', call_idx=%s",
module_name,
module.__class__.__name__,
pattern.pattern,
call_idx,
)
set_il_module_name(module, module_name)
set_il_module_call_idx(module, call_idx)
return True
return False
def is_log_enabled(config):
if not config or not config.enabled:
return False
if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress")
return False
return True
def get_il_module_name(module: torch.nn.Module) -> str:
return getattr(module, IL_MODULE_NAME, module.__class__.__name__)
def get_il_module_call_idx(module: torch.nn.Module) -> int:
return getattr(module, IL_MODULE_CALL_IDX, -1)
def set_il_module_name(module: torch.nn.Module, name: str) -> None:
setattr(module, IL_MODULE_NAME, name)
def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
setattr(module, IL_MODULE_CALL_IDX, idx)
_global_config: Optional[IntermediateLoggingConfig] = None
@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
"""
Temporarily sets the global config for the duration of the context.
:param config: Keyword arguments to set as global config
"""
global _global_config
old_config = _global_config
try:
_global_config = config
yield
finally:
_global_config = old_config
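# Usage sketch (illustrative; `example_config` is a hypothetical
# IntermediateLoggingConfig built by the caller): the context manager scopes
# the config to a single execution so the hooks defined below see it while
# the forward pass runs.
def _example_scoped_forward(model: torch.nn.Module, inputs: dict[str, Any],
                            example_config: IntermediateLoggingConfig):
    with intermediate_logging(example_config):
        return model(**inputs)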
def get_current_il_config():
"""Return the config set by the enclosing intermediate_logging context,
if any."""
return _global_config
def save_tensors(tensor: Any, file_path: str) -> Any:
"""Utility function to dump tensor to a file.
Args:
tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
tensors, or a dictionary containing tensors.
file_path: Base path where to save the tensor (without extension).
"""
if isinstance(tensor, torch.Tensor):
device_name = str(tensor.device)
intermediate_log_config = get_current_il_config()
if not should_log_device(intermediate_log_config, device_name):
return tensor
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
try:
torch.save(tensor, pt_path)
logger.debug("Saved tensor of shape %s to %s", tensor.shape,
pt_path)
except Exception as e:
logger.warning("Failed to save tensor to %s: %s", pt_path, e)
return tensor
if isinstance(tensor, (list, tuple)):
for i, item in enumerate(tensor):
save_tensors(item, f"{file_path}_{i}")
return tensor
if isinstance(tensor, dict):
for k, v in tensor.items():
save_tensors(v, f"{file_path}_{k}")
return tensor
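# Offline-inspection sketch (assumption, not part of the tool): files written
# by save_tensors follow "<base>_<device>.pt", e.g. "outputs_cuda_0.pt" under
# "<output_run_dir>/step_<N>/<module_name>/", and can be reloaded with
# torch.load for comparison between runs.
def _example_load_logged_output(step_dir: Path, module_name: str,
                                device_name: str = "cuda:0") -> torch.Tensor:
    pt_path = (Path(step_dir) / module_name /
               f"outputs_{device_name.replace(':', '_')}.pt")
    return torch.load(pt_path, map_location="cpu")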
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to increment the global step counter after a forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
outputs: The outputs from the module's forward function.
"""
if get_current_il_config() is None:
return
# Increment the global step counter
increment_step()
global _CURRENT_STEP_MODULE_CALL_STEP
_CURRENT_STEP_MODULE_CALL_STEP = {}
def _prepare_module_log_dir(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
is_pre_fwd: bool = False,
) -> Path:
# Create a unique directory for this step if it does not already exist
dump_dir = Path(
intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir.mkdir(exist_ok=True, parents=True)
# Create module directory
suffix = ""
module_call_idx = get_current_step_module_call(module_name)
if module_call_idx > 0:
suffix = f"_{module_call_idx}"
module_dir = dump_dir / (module_name + suffix)
if is_pre_fwd:
_log_module_call(intermediate_log_config, module_name + suffix)
module_dir.mkdir(exist_ok=True, parents=True)
logger.debug("Logging module %s inputs/outputs to %s", module_name,
module_dir)
return module_dir
def _log_module_call(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
) -> None:
file = (Path(intermediate_log_config.output_run_dir) /
f"step_{get_step()}" / "module_calls.txt")
with open(file, "a") as f:
f.write(f"{module_name}\n")
def update_current_step_module_call(module_name: str) -> None:
logger.debug("Updating current step module call for %s", module_name)
global _CURRENT_STEP_MODULE_CALL_STEP
if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
_CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
else:
_CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1
def get_current_step_module_call(module_name: str) -> int:
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
def prepare_log_current_fwd(module,
is_pre_fwd: bool = False) -> Optional[Path]:
"""Return the directory to log the current forward call to, or None if
this step, module, or call index should not be logged."""
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None or not intermediate_log_config.enabled:
return None
if not should_log_step(intermediate_log_config):
return None
module_name = get_il_module_name(module)
if is_pre_fwd:
# Advance the per-step call counter before reading it so the pre- and
# post-forward hooks of the same call agree on the call index.
update_current_step_module_call(module_name)
log_call_idx = get_il_module_call_idx(module)
current_call_idx = get_current_step_module_call(module_name)
should_log = True
if log_call_idx >= 0 and current_call_idx != log_call_idx:
should_log = False
log_dir = None
if should_log:
log_dir = _prepare_module_log_dir(intermediate_log_config,
module_name,
is_pre_fwd=is_pre_fwd)
return log_dir
def log_pre_fwd_hook(module: torch.nn.Module,
inputs: tuple[Any, ...]) -> tuple[Any, ...]:
"""Hook to capture module inputs before forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
Returns:
The unchanged inputs.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
save_tensors(inputs, str(log_dir / "inputs"))
return inputs
def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to capture module outputs after forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
outputs: The outputs from the module's forward function.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
save_tensors(outputs, str(log_dir / "outputs"))
intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None, \
"IL config should not be None"
if intermediate_log_config.log_post_fwd_inputs:
save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
def get_step() -> int:
"""Get the current global step counter.
Returns:
The current global step counter.
"""
return _CURRENT_STEP
def increment_step() -> int:
"""Increment the global step counter.
Returns:
The new step counter value.
"""
global _CURRENT_STEP
_CURRENT_STEP += 1
return _CURRENT_STEP
def reset_step() -> None:
"""Reset the global step counter to zero."""
global _CURRENT_STEP
_CURRENT_STEP = 0
class IntermediatesLogger:
"""Class to manage logging of intermediate tensors during model
execution."""
def __init__(self, config: IntermediateLoggingConfig):
self.config = config
self.hooks: list[tuple[str, torch.nn.Module, Optional[RemovableHandle],
Optional[RemovableHandle]]] = []
logger.debug("Created IntermediatesLogger with config: %s", config)
path = Path(config.output_run_dir)
path.mkdir(exist_ok=True, parents=True)
# Log configuration
logger.info("Intermediates will be logged in %s",
config.output_run_dir)
def register_hooks(self, model: torch.nn.Module) -> None:
"""Register hooks for the model.
Args:
model: The PyTorch model to register hooks for.
"""
for name, module in model.named_modules():
if name and should_log_module(self.config, name, module):
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
logger.debug("Registered pre_fwd hook for %s",
module.__class__.__name__)
post_hook = module.register_forward_hook(log_post_fwd_hook)
logger.debug("Registered post_fwd hook for %s",
module.__class__.__name__)
self.hooks.append((name, module, pre_hook, post_hook))
# Register a step counter hook for the root model
step_hook = model.register_forward_hook(step_fwd)
self.hooks.append(("", model, None, step_hook))
logger.info("Registered hooks for %s modules", len(self.hooks))
def remove_hooks(self) -> None:
"""Remove all registered hooks."""
for _, _, pre_hook, post_hook in self.hooks:
if pre_hook is not None:
pre_hook.remove()
if post_hook is not None:
post_hook.remove()
logger.info("Removed %s hooks", len(self.hooks))
self.hooks = []
def register_intermediate_hooks(
model: torch.nn.Module,
config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
"""Register hooks to log intermediate tensors for a model.
Args:
model: The PyTorch model to log intermediates for.
config: Configuration for intermediate logging; required to create the
output directory and decide what to log.
Returns:
An IntermediatesLogger instance that can be used to manage the hooks.
"""
logger_instance = IntermediatesLogger(config)
logger_instance.register_hooks(model)
return logger_instance
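# End-to-end sketch (illustrative only; the model, inputs, and config are
# placeholders supplied by the caller): register hooks once, then wrap each
# forward pass in intermediate_logging so the hooks see an active config.
def _example_register_and_run(model: torch.nn.Module,
                              inputs: dict[str, Any],
                              config: IntermediateLoggingConfig):
    il_logger = register_intermediate_hooks(model, config)
    try:
        with intermediate_logging(config):
            return model(**inputs)
    finally:
        il_logger.remove_hooks()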

View File

@ -27,6 +27,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
@ -362,9 +363,9 @@ class Worker(WorkerBase):
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
with intermediate_logging(self.vllm_config.intermediate_log_config):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output

View File

@ -6,8 +6,10 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.intermediates.intermediates_logging import (
register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
@ -63,3 +65,26 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
def register_intermediate_hooks(
self, config: Optional[IntermediateLoggingConfig] = None) -> None:
"""Register hooks for intermediate tensor logging.
This method is called via collective_rpc from the engine core.
It registers hooks on the model to dump intermediate tensors during
execution.
Args:
config: Configuration for intermediate logging.
"""
if self.model_runner is None or not hasattr(
self.model_runner, "model") or self.model_runner.model is None:
logger.error("Could not register intermediate hooks: "
"model_runner.model is not accessible")
return
model = self.model_runner.model
try:
register_intermediate_hooks(model, config)
except Exception:
logger.exception("Error registering intermediate hooks")

View File

@ -129,6 +129,22 @@ class WorkerBase:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def register_intermediate_hooks(self, config=None) -> None:
"""Register hooks for intermediate tensor logging.
This method is a stub for v0 workers. The actual implementation is
in v1 workers. It's included here for compatibility with the
collective_rpc mechanism.
Args:
config: Configuration for intermediate logging.
"""
logger.warning(
"register_intermediate_hooks is not implemented in v0 workers. "
"This is only available in v1 workers. No hooks will be registered."
)
return None
def shutdown(self) -> None:
"""Clean up resources held by the worker."""
return