Compare commits

7 Commits

| SHA1 |
|---|
| e793b9a70c |
| 76c9ec0ddf |
| 87c737016d |
| ba90794ff1 |
| ab4ab0fd28 |
| 2af83ebdde |
| d8bff253d7 |
@@ -149,25 +149,3 @@ steps:
      - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
    env:
      DOCKER_BUILDKIT: "1"

  - label: "Build and publish nightly multi-arch image to DockerHub"
    depends_on:
      - create-multi-arch-manifest
    if: build.env("NIGHTLY") == "1"
    agents:
      queue: cpu_queue_postmerge
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
      - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly"
      - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
      - "docker push vllm/vllm-openai:nightly"
      - "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
      # Clean up old nightly builds (keep only last 14)
      - "bash .buildkite/scripts/cleanup-nightly-builds.sh"
    plugins:
      - docker-login#v3.0.0:
          username: vllmbot
          password-env: DOCKERHUB_TOKEN
    env:
      DOCKER_BUILDKIT: "1"
@@ -1,97 +0,0 @@
#!/bin/bash

set -ex

# Clean up old nightly builds from DockerHub, keeping only the last 14 builds
# This script uses DockerHub API to list and delete old tags with "nightly-" prefix

# DockerHub API endpoint for vllm/vllm-openai repository
REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"

# Get DockerHub token from environment
if [ -z "$DOCKERHUB_TOKEN" ]; then
    echo "Error: DOCKERHUB_TOKEN environment variable is not set"
    exit 1
fi

# Function to get all tags from DockerHub
get_all_tags() {
    local page=1
    local all_tags=""

    while true; do
        local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \
            "$REPO_API_URL?page=$page&page_size=100")

        # Get both last_updated timestamp and tag name, separated by |
        local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')

        if [ -z "$tags" ]; then
            break
        fi

        all_tags="$all_tags$tags"$'\n'
        page=$((page + 1))
    done

    # Sort by timestamp (newest first) and extract just the tag names
    echo "$all_tags" | sort -r | cut -d'|' -f2
}

delete_tag() {
    local tag_name="$1"
    echo "Deleting tag: $tag_name"

    local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
    local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url")

    if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
        echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
    else
        echo "Successfully deleted tag: $tag_name"
    fi
}

# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first)
echo "Fetching all tags from DockerHub..."
all_tags=$(get_all_tags)

if [ -z "$all_tags" ]; then
    echo "No tags found to clean up"
    exit 0
fi

# Count total tags
total_tags=$(echo "$all_tags" | wc -l)
echo "Found $total_tags tags"

# Keep only the last 14 builds (including the current one)
tags_to_keep=14
tags_to_delete=$((total_tags - tags_to_keep))

if [ $tags_to_delete -le 0 ]; then
    echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)"
    exit 0
fi

echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep"

# Get tags to delete (skip the first $tags_to_keep tags)
tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1)))

if [ -z "$tags_to_delete_list" ]; then
    echo "No tags to delete"
    exit 0
fi

# Delete old tags
echo "Deleting old tags..."
while IFS= read -r tag; do
    if [ -n "$tag" ]; then
        delete_tag "$tag"
        # Add a small delay to avoid rate limiting
        sleep 1
    fi
done <<< "$tags_to_delete_list"

echo "Cleanup completed successfully"
.github/mergify.yml (14 changes)
@@ -273,20 +273,6 @@ pull_request_rules:
      users:
        - "sangstar"

- name: assign reviewer for modelopt changes
  conditions:
    - or:
      - files~=^vllm/model_executor/layers/quantization/modelopt\.py$
      - files~=^vllm/model_executor/layers/quantization/__init__\.py$
      - files~=^tests/models/quantization/test_modelopt\.py$
      - files~=^tests/quantization/test_modelopt\.py$
      - files~=^tests/models/quantization/test_nvfp4\.py$
      - files~=^docs/features/quantization/modelopt\.md$
  actions:
    assign:
      users:
        - "Edwardf0t1"

- name: remove 'needs-rebase' label when conflict is resolved
  conditions:
    - -conflict
.github/workflows/cleanup_pr_body.yml (2 changes)
@@ -16,7 +16,7 @@ jobs:
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
        with:
          python-version: '3.12'
.github/workflows/pre-commit.yml (2 changes)
@@ -17,7 +17,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
    - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        python-version: "3.12"
    - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
docs/contributing/intermediate_logging.md (new file, 65 lines)
@@ -0,0 +1,65 @@
# Intermediate Tensor Logging

This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.

## Overview

The intermediate tensor logging feature enables you to:

- Log input and output tensors for modules selected by a configurable set of filters
- Filter modules by name using regex patterns
- Filter by module forward-call index (e.g. dump only the 2nd forward call of a module within a step)
- Filter tensors by device
- Filter by the model's forward step ID

## Usage

### Enabling via parameters or config file

**Offline Inference example**

Dump all modules, on all devices, for step 0 (default behavior):

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
```

Dump only the first decoder layer's modules, on all devices, for step 0:

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```

#### Configuration Parameters

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `output_dir` | string | Directory where the intermediate tensors are saved | `/tmp/vllm_intermediates` |
| `module_call_match` | array | Regex patterns to filter module names; to restrict logging to the i-th call of a matched module, append `:i` | `null` (log all modules) |
| `log_step_ids` | array | List of step IDs to log | `[0]` |
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
| `device_names` | array | List of device names to log | `[]` (log all devices) |
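
Putting several filters together, the command below is an illustrative sketch: the parameter values are made up, the list form of `module_call_match` follows the config definition, and the `:1` suffix assumes the call-index syntax described in the table above. It restricts logging to the first decoder layer, a single call index, steps 0 and 1, and one device:

```bash
python3 ./examples/offline_inference/llm_engine_example.py \
  --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager \
  --intermediate-log-config '{"enabled": true, "module_call_match": ["layers\\.0\\.:1"], "log_step_ids": [0, 1], "device_names": ["cuda:0"]}'
```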

### Output Directory Structure

When you enable intermediate logging, the system creates a unique run directory (named with a UUID) under your specified `output_dir`. This helps organize multiple logging sessions:

```
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
├── step_0
│   ├── model.embed_tokens
│   │   ├── inputs_0_cuda_0.pt
│   │   ├── inputs.json
│   │   ├── outputs_cuda_0.pt
│   │   └── outputs.json
│   └── model.layers.0.input_layernorm
│       ├── inputs_0_cuda_0.pt
│       ├── inputs.json
│       ├── outputs_cuda_0.pt
│       └── outputs.json
└── step_1/
    └── ...
```

Each tensor is saved in a `.pt` file containing the full PyTorch tensor, which can be loaded with `torch.load()`.
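
As a quick sanity check, you can load a dumped file directly. This is a minimal sketch: the path reuses the illustrative run directory from the layout above, and since the saved object is simply whatever the hook captured (usually a tensor, sometimes a tuple of tensors), it is inspected generically:

```python
import torch

# Illustrative path from the example layout above; substitute your own run directory.
path = ("/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/"
        "step_0/model.embed_tokens/outputs_cuda_0.pt")

# Load on CPU so inspection does not require the GPU the tensors were captured on.
obj = torch.load(path, map_location="cpu")
print(type(obj))
if hasattr(obj, "shape"):
    print(obj.shape, obj.dtype)
```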
tests/v1/test_intermediates_logging.py (new file, 320 lines)
@ -0,0 +1,320 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for the intermediate tensor logging functionality.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import IntermediateLoggingConfig
|
||||
from vllm.v1.intermediates.intermediates_logging import (
|
||||
get_current_il_config, get_step, increment_step, intermediate_logging,
|
||||
register_intermediate_hooks, reset_step, should_log_device,
|
||||
should_log_module, should_log_step)
|
||||
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
"""A simple model for testing."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(10, 20)
|
||||
self.relu = nn.ReLU()
|
||||
self.linear2 = nn.Linear(20, 5)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_dir():
|
||||
"""Create a temporary directory for test outputs."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
# Clean up after the test
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_model():
|
||||
"""Create a simple model for testing."""
|
||||
return SimpleModel()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def il_config(temp_output_dir):
|
||||
"""Create a basic IntermediateLoggingConfig for testing."""
|
||||
return IntermediateLoggingConfig(output_dir=temp_output_dir,
|
||||
enabled=True,
|
||||
log_step_ids=[0, 1],
|
||||
module_call_match=[".*linear.*"])
|
||||
|
||||
|
||||
def test_step_counter():
|
||||
"""Test the step counter functionality."""
|
||||
# Reset the step counter
|
||||
reset_step()
|
||||
assert get_step() == 0
|
||||
|
||||
# Increment the step counter
|
||||
increment_step()
|
||||
assert get_step() == 1
|
||||
|
||||
# Increment again
|
||||
increment_step()
|
||||
assert get_step() == 2
|
||||
|
||||
# Reset again
|
||||
reset_step()
|
||||
assert get_step() == 0
|
||||
|
||||
|
||||
def test_intermediate_logging_context_manager():
|
||||
"""Test the intermediate_logging context manager."""
|
||||
# Create a config
|
||||
config = IntermediateLoggingConfig(enabled=True)
|
||||
|
||||
# Initially, there should be no global config
|
||||
assert get_current_il_config() is None
|
||||
|
||||
# Use the context manager
|
||||
with intermediate_logging(config):
|
||||
# Inside the context, the global config should be set
|
||||
assert get_current_il_config() is not None
|
||||
assert get_current_il_config().enabled is True
|
||||
|
||||
# After the context, the global config should be None again
|
||||
assert get_current_il_config() is None
|
||||
|
||||
# Test with a different config
|
||||
config2 = IntermediateLoggingConfig(enabled=False)
|
||||
with intermediate_logging(config2):
|
||||
assert get_current_il_config() is not None
|
||||
assert get_current_il_config().enabled is False
|
||||
|
||||
|
||||
def test_should_log_step():
|
||||
"""Test the should_log_step function."""
|
||||
# Reset step counter
|
||||
reset_step()
|
||||
|
||||
# Create configs with different step IDs
|
||||
config_all_steps = IntermediateLoggingConfig(
|
||||
enabled=True,
|
||||
log_step_ids=[] # Empty list means log all steps
|
||||
)
|
||||
config_specific_steps = IntermediateLoggingConfig(
|
||||
enabled=True,
|
||||
log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4
|
||||
)
|
||||
config_disabled = IntermediateLoggingConfig(enabled=False,
|
||||
log_step_ids=[0, 1, 2])
|
||||
|
||||
# Test with all steps config
|
||||
with intermediate_logging(config_all_steps):
|
||||
assert should_log_step(config_all_steps) is True # Step 0
|
||||
increment_step()
|
||||
assert should_log_step(config_all_steps) is True # Step 1
|
||||
|
||||
# Reset step counter
|
||||
reset_step()
|
||||
|
||||
# Test with specific steps config
|
||||
with intermediate_logging(config_specific_steps):
|
||||
assert should_log_step(config_specific_steps) is True # Step 0
|
||||
increment_step()
|
||||
assert should_log_step(config_specific_steps) is False # Step 1
|
||||
increment_step()
|
||||
assert should_log_step(config_specific_steps) is True # Step 2
|
||||
increment_step()
|
||||
assert should_log_step(config_specific_steps) is False # Step 3
|
||||
increment_step()
|
||||
assert should_log_step(config_specific_steps) is True # Step 4
|
||||
|
||||
# Test with disabled config
|
||||
with intermediate_logging(config_disabled):
|
||||
assert should_log_step(config_disabled) is False # Disabled
|
||||
|
||||
|
||||
def test_should_log_device():
|
||||
"""Test the should_log_device function."""
|
||||
# Create configs with different device filters
|
||||
config_all_devices = IntermediateLoggingConfig(
|
||||
enabled=True,
|
||||
device_names=[] # Empty list means log all devices
|
||||
)
|
||||
config_specific_devices = IntermediateLoggingConfig(
|
||||
enabled=True,
|
||||
device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu
|
||||
)
|
||||
config_disabled = IntermediateLoggingConfig(enabled=False,
|
||||
device_names=["cuda:0", "cpu"])
|
||||
|
||||
# Test with all devices config
|
||||
with intermediate_logging(config_all_devices):
|
||||
assert should_log_device(config_all_devices, "cuda:0") is True
|
||||
assert should_log_device(config_all_devices, "cuda:1") is True
|
||||
assert should_log_device(config_all_devices, "cpu") is True
|
||||
|
||||
# Test with specific devices config
|
||||
with intermediate_logging(config_specific_devices):
|
||||
assert should_log_device(config_specific_devices, "cuda:0") is True
|
||||
assert should_log_device(config_specific_devices, "cuda:1") is False
|
||||
assert should_log_device(config_specific_devices, "cpu") is True
|
||||
|
||||
# Test with disabled config
|
||||
with intermediate_logging(config_disabled):
|
||||
assert should_log_device(config_disabled, "cuda:0") is False
|
||||
assert should_log_device(config_disabled, "cpu") is False
|
||||
|
||||
|
||||
def test_should_log_module(simple_model):
|
||||
"""Test the should_log_module function."""
|
||||
# Create configs with different module name filters
|
||||
config_all_modules = IntermediateLoggingConfig(
|
||||
enabled=True,
|
||||
module_call_match=None # None means log all modules
|
||||
)
|
||||
config_specific_modules = IntermediateLoggingConfig(
|
||||
enabled=True,
|
||||
module_call_match=[".*linear.*"
|
||||
] # Only log modules with "linear" in the name
|
||||
)
|
||||
config_disabled = IntermediateLoggingConfig(enabled=False,
|
||||
module_call_match=[".*"])
|
||||
|
||||
# Test with all modules config
|
||||
with intermediate_logging(config_all_modules):
|
||||
assert should_log_module(config_all_modules, "linear1",
|
||||
simple_model.linear1) is True
|
||||
assert should_log_module(config_all_modules, "relu",
|
||||
simple_model.relu) is True
|
||||
|
||||
# Test with specific modules config
|
||||
with intermediate_logging(config_specific_modules):
|
||||
assert should_log_module(config_specific_modules, "linear1",
|
||||
simple_model.linear1) is True
|
||||
assert should_log_module(config_specific_modules, "relu",
|
||||
simple_model.relu) is False
|
||||
|
||||
# Test with disabled config
|
||||
with intermediate_logging(config_disabled):
|
||||
assert should_log_module(config_disabled, "linear1",
|
||||
simple_model.linear1) is False
|
||||
assert should_log_module(config_disabled, "relu",
|
||||
simple_model.relu) is False
|
||||
|
||||
|
||||
def test_register_hooks(simple_model, il_config):
|
||||
"""Test registering hooks on a model."""
|
||||
# Register hooks
|
||||
logger_instance = register_intermediate_hooks(simple_model, il_config)
|
||||
|
||||
# Check that hooks were registered
|
||||
assert len(logger_instance.hooks) > 0
|
||||
|
||||
# Remove hooks
|
||||
logger_instance.remove_hooks()
|
||||
|
||||
# Check that hooks were removed
|
||||
assert len(logger_instance.hooks) == 0
|
||||
|
||||
|
||||
@mock.patch(
|
||||
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
|
||||
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
|
||||
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
|
||||
il_config, temp_output_dir):
|
||||
"""Test that forward hooks are called during model execution."""
|
||||
mock_save_tensors.return_value = None
|
||||
# Register hooks
|
||||
with intermediate_logging(il_config):
|
||||
logger_instance = register_intermediate_hooks(simple_model, il_config)
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(2, 10)
|
||||
|
||||
# Reset step counter
|
||||
reset_step()
|
||||
|
||||
# Forward pass
|
||||
simple_model(input_tensor)
|
||||
|
||||
# Check that the step counter was incremented
|
||||
assert get_step() == 1
|
||||
|
||||
# Check that dump_intermediates_to_json and save_tensors were called
|
||||
assert mock_dump_json.called
|
||||
assert mock_save_tensors.called
|
||||
|
||||
# Remove hooks
|
||||
logger_instance.remove_hooks()
|
||||
|
||||
|
||||
def test_end_to_end(simple_model, il_config, temp_output_dir):
|
||||
"""Test the entire intermediate logging workflow end-to-end."""
|
||||
# Register hooks
|
||||
with intermediate_logging(il_config):
|
||||
logger_instance = register_intermediate_hooks(simple_model, il_config)
|
||||
|
||||
# Create input tensor
|
||||
input_tensor = torch.randn(2, 10)
|
||||
|
||||
# Reset step counter
|
||||
reset_step()
|
||||
|
||||
# Forward pass
|
||||
simple_model(input_tensor)
|
||||
|
||||
# Check that output directories were created
|
||||
root_dir = Path(il_config._output_run_dir)
|
||||
assert root_dir.exists()
|
||||
step_dir = root_dir / "step_0"
|
||||
assert step_dir.exists()
|
||||
|
||||
module_dirs = list(step_dir.glob("*"))
|
||||
print(f"{module_dirs=}")
|
||||
assert len(module_dirs) > 0
|
||||
|
||||
# Check that input and output files were created
|
||||
for module_dir in module_dirs:
|
||||
print(f"{module_dir=}")
|
||||
if os.path.isdir(module_dir):
|
||||
inputs_json = module_dir / "inputs.json"
|
||||
outputs_json = module_dir / "outputs.json"
|
||||
|
||||
# Check that JSON files exist
|
||||
assert inputs_json.exists()
|
||||
assert outputs_json.exists()
|
||||
|
||||
# Check that JSON files contain valid data
|
||||
with open(inputs_json) as f:
|
||||
inputs_data = json.load(f)
|
||||
assert "type" in inputs_data
|
||||
|
||||
with open(outputs_json) as f:
|
||||
outputs_data = json.load(f)
|
||||
assert "type" in outputs_data
|
||||
|
||||
# Check that tensor files exist
|
||||
tensor_files = list(module_dir.glob("*.pt"))
|
||||
assert len(tensor_files) > 0
|
||||
|
||||
# Remove hooks
|
||||
logger_instance.remove_hooks()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-xvs", __file__])
|
||||
@ -3311,6 +3311,119 @@ class KVTransferConfig:
|
||||
return self.kv_connector_extra_config.get(key, default)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class IntermediateLoggingConfig:
|
||||
"""Configuration for intermediate tensor logging."""
|
||||
|
||||
output_dir: str = "/tmp/vllm_intermediates"
|
||||
"""Directory where to save the intermediate tensors."""
|
||||
|
||||
module_call_match: Optional[list[str]] = None
|
||||
"""Match modules by name regex and call index (
|
||||
a module can be called multiple times in a step)
|
||||
List of regex:call_idx, call_idx is -1 for default for all calls """
|
||||
|
||||
log_step_ids: list[int] = field(default_factory=lambda: [0])
|
||||
"""List of step IDs to log (empty list means log all steps)."""
|
||||
|
||||
log_post_fwd_inputs: bool = False
|
||||
"""Whether logging inputs after forwards for each module"""
|
||||
|
||||
max_tensor_size: Optional[int] = None
|
||||
"""Maximum number of elements in tensors to log (None = no limit)."""
|
||||
|
||||
enabled: bool = True
|
||||
"""Whether logging is enabled."""
|
||||
device_names: list[str] = field(default_factory=list)
|
||||
"""List of device names to log (empty list means log all devices)."""
|
||||
|
||||
_compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
|
||||
init=False)
|
||||
"""Compiled regex patterns for module filtering."""
|
||||
|
||||
_module_call: dict[str, int] = field(default_factory=dict, init=False)
|
||||
_step_id_set: set[int] = field(default_factory=set, init=False)
|
||||
"""Set of step IDs for faster lookup."""
|
||||
_output_run_dir: str = "/tmp/vllm_intermediates"
|
||||
"""Unique directory to save single run/serve logging result."""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize derived fields after instance creation."""
|
||||
self._compile_regex_patterns()
|
||||
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
|
||||
self._step_id_set = set(self.log_step_ids)
|
||||
|
||||
def _compile_regex_patterns(self):
|
||||
"""Compile regex patterns for module name filtering."""
|
||||
from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
self._compiled_module_matches = []
|
||||
|
||||
if self.module_call_match is None:
|
||||
logger.info(
|
||||
"No module name regex patterns provided, will log all modules")
|
||||
return
|
||||
|
||||
# Compile all patterns
|
||||
for regex_pattern_call_idx in self.module_call_match:
|
||||
try:
|
||||
splits = regex_pattern_call_idx.split(":", 2)
|
||||
regex_pattern = splits[0]
|
||||
call_idx = -1
|
||||
if len(splits) > 1:
|
||||
call_idx = int(splits[1])
|
||||
compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
|
||||
self._compiled_module_calls[compiled_pattern] = call_idx
|
||||
logger.info("Successfully compiled regex pattern: '%s'",
|
||||
regex_pattern)
|
||||
except Exception as e:
|
||||
logger.error("Failed to parse module_call_match '%s': %s",
|
||||
regex_pattern_call_idx, e)
|
||||
|
||||
logger.info("Compiled %d regex patterns",
|
||||
len(self._compiled_module_calls))
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the config to a dictionary for serialization."""
|
||||
return {
|
||||
"output_run_dir": self.output_run_dir,
|
||||
"module_call_match": self.module_call_match,
|
||||
"log_step_ids": self.log_step_ids,
|
||||
"max_tensor_size": self.max_tensor_size,
|
||||
"enabled": self.enabled,
|
||||
"device_names": self.device_names
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
|
||||
"""Parse the CLI value for the speculative config."""
|
||||
return cls(**dict_value)
|
||||
|
||||
@property
|
||||
def output_run_dir(self) -> str:
|
||||
return self._output_run_dir
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# Intermediate logging doesn't affect the computation graph
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class VllmConfig:
|
||||
@ -3362,6 +3475,8 @@ class VllmConfig:
|
||||
"""The configurations for distributed KV cache transfer."""
|
||||
kv_events_config: Optional[KVEventsConfig] = None
|
||||
"""The configurations for event publishing."""
|
||||
intermediate_log_config: Optional[IntermediateLoggingConfig] = None
|
||||
"""Configuration for intermediate tensor logging."""
|
||||
# some opaque config, only used to provide additional information
|
||||
# for the hash computation, mainly used for testing, debugging or out of
|
||||
# tree config registration.
|
||||
@ -3446,6 +3561,10 @@ class VllmConfig:
|
||||
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.intermediate_log_config:
|
||||
vllm_factors.append(self.intermediate_log_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.additional_config:
|
||||
if isinstance(additional_config := self.additional_config, dict):
|
||||
additional_config_hash = hashlib.md5(
|
||||
|
||||
@ -409,6 +409,7 @@ class EngineArgs:
|
||||
|
||||
speculative_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
show_hidden_metrics_for_version: Optional[str] = \
|
||||
ObservabilityConfig.show_hidden_metrics_for_version
|
||||
otlp_traces_endpoint: Optional[str] = \
|
||||
@ -456,6 +457,8 @@ class EngineArgs:
|
||||
|
||||
async_scheduling: bool = SchedulerConfig.async_scheduling
|
||||
|
||||
intermediate_log_config: Optional[dict[str, Any]] = None
|
||||
|
||||
kv_sharing_fast_prefill: bool = \
|
||||
CacheConfig.kv_sharing_fast_prefill
|
||||
|
||||
@ -883,6 +886,9 @@ class EngineArgs:
|
||||
title="VllmConfig",
|
||||
description=VllmConfig.__doc__,
|
||||
)
|
||||
|
||||
vllm_group.add_argument("--intermediate-log-config",
|
||||
**vllm_kwargs["intermediate_log_config"])
|
||||
# We construct SpeculativeConfig using fields from other configs in
|
||||
# create_engine_config. So we set the type to a JSON string here to
|
||||
# delay the Pydantic validation that comes with SpeculativeConfig.
|
||||
@ -1394,7 +1400,6 @@ class EngineArgs:
|
||||
otlp_traces_endpoint=self.otlp_traces_endpoint,
|
||||
collect_detailed_traces=self.collect_detailed_traces,
|
||||
)
|
||||
|
||||
config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
@ -1409,6 +1414,7 @@ class EngineArgs:
|
||||
compilation_config=self.compilation_config,
|
||||
kv_transfer_config=self.kv_transfer_config,
|
||||
kv_events_config=self.kv_events_config,
|
||||
intermediate_log_config=self.intermediate_log_config,
|
||||
additional_config=self.additional_config,
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
@ -59,14 +57,9 @@ class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
|
||||
@ -96,13 +89,9 @@ class SimpleContext(ConversationContext):
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
|
||||
@ -114,7 +103,6 @@ class HarmonyContext(ConversationContext):
|
||||
self._messages = messages
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
@ -246,8 +234,7 @@ class HarmonyContext(ConversationContext):
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (recipient.startswith("browser.")
|
||||
or recipient.startswith("python") or
|
||||
recipient.startswith("container."))
|
||||
or recipient.startswith("python"))
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
@ -261,9 +248,6 @@ class HarmonyContext(ConversationContext):
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
@ -272,7 +256,6 @@ class HarmonyContext(ConversationContext):
|
||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
@ -282,16 +265,12 @@ class HarmonyContext(ConversationContext):
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
Message(author=author, content=[content], recipient=Role.ASSISTANT)
|
||||
]
|
||||
|
||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
@ -311,63 +290,13 @@ class HarmonyContext(ConversationContext):
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id))
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info("Cleaning up tool session for %s",
|
||||
tool_session._client_info)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools))
|
||||
self._tool_sessions[
|
||||
tool_name] = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name))
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
@ -16,13 +16,11 @@ from openai.types.responses.response_function_web_search import (
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent)
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import (Author, ChannelConfig, Conversation,
|
||||
DeveloperContent, HarmonyEncodingName, Message,
|
||||
ReasoningEffort, Role, StreamableParser,
|
||||
SystemContent, TextContent, ToolDescription,
|
||||
load_harmony_encoding)
|
||||
from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||
HarmonyEncodingName, Message, ReasoningEffort,
|
||||
Role, StreamableParser, SystemContent, TextContent,
|
||||
ToolDescription, load_harmony_encoding)
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||
ResponseInputOutputItem)
|
||||
from vllm.utils import random_uuid
|
||||
@ -35,20 +33,6 @@ REASONING_EFFORT = {
|
||||
|
||||
_harmony_encoding = None
|
||||
|
||||
# Builtin tools that should be included in the system message when
|
||||
# they are available and requested by the user.
|
||||
# Tool args are provided by MCP tool descriptions. Output
|
||||
# of the tools are stringified.
|
||||
BUILTIN_TOOLS = {
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
}
|
||||
|
||||
|
||||
def has_custom_tools(tool_types: list[str]) -> bool:
|
||||
return not set(tool_types).issubset(BUILTIN_TOOLS)
|
||||
|
||||
|
||||
def get_encoding():
|
||||
global _harmony_encoding
|
||||
@ -64,19 +48,10 @@ def get_system_message(
|
||||
start_date: Optional[str] = None,
|
||||
browser_description: Optional[str] = None,
|
||||
python_description: Optional[str] = None,
|
||||
container_description: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
with_custom_tools: bool = False,
|
||||
) -> Message:
|
||||
sys_msg_content = SystemContent.new()
|
||||
if model_identity is not None:
|
||||
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
||||
if (instructions is not None
|
||||
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
||||
current_identity = sys_msg_content.model_identity
|
||||
new_identity = (f'{current_identity}\n{instructions}'
|
||||
if current_identity else instructions)
|
||||
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
|
||||
if reasoning_effort is not None:
|
||||
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
||||
REASONING_EFFORT[reasoning_effort])
|
||||
@ -88,14 +63,6 @@ def get_system_message(
|
||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||
if python_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(python_description)
|
||||
if container_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(container_description)
|
||||
if not with_custom_tools:
|
||||
channel_config = sys_msg_content.channel_config
|
||||
invalid_channel = "commentary"
|
||||
new_config = ChannelConfig.require_channels(
|
||||
[c for c in channel_config.valid_channels if c != invalid_channel])
|
||||
sys_msg_content = sys_msg_content.with_channel_config(new_config)
|
||||
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
||||
return sys_msg
|
||||
|
||||
@ -119,17 +86,14 @@ def get_developer_message(
|
||||
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
||||
) -> Message:
|
||||
dev_msg_content = DeveloperContent.new()
|
||||
if (instructions is not None
|
||||
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
||||
if instructions is not None:
|
||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||
if tools is not None:
|
||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||
for tool in tools:
|
||||
if tool.type in ("web_search_preview", "code_interpreter",
|
||||
"container"):
|
||||
if tool.type in ("web_search_preview", "code_interpreter"):
|
||||
# These are built-in tools that are added to the system message.
|
||||
pass
|
||||
|
||||
elif tool.type == "function":
|
||||
function_tools.append(tool)
|
||||
else:
|
||||
@ -172,8 +136,6 @@ def parse_response_input(
|
||||
TextContent(text=text_prefix + c["text"]) for c in content
|
||||
]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
if role == "assistant":
|
||||
msg = msg.with_channel("final")
|
||||
elif response_msg["type"] == "function_call_output":
|
||||
call_id = response_msg["call_id"]
|
||||
call_response: Optional[ResponseFunctionToolCall] = None
|
||||
|
||||
@ -44,9 +44,8 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext,
|
||||
SimpleContext, StreamingHarmonyContext)
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_developer_message, get_stop_tokens_for_assistant_actions,
|
||||
get_system_message, get_user_message, has_custom_tools,
|
||||
parse_output_message, parse_remaining_state, parse_response_input,
|
||||
render_for_completion)
|
||||
get_system_message, get_user_message, parse_output_message,
|
||||
parse_remaining_state, parse_response_input, render_for_completion)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -267,8 +266,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
builtin_tool_list.append("browser")
|
||||
if self.tool_server.has_tool("python"):
|
||||
builtin_tool_list.append("python")
|
||||
if self.tool_server.has_tool("container"):
|
||||
builtin_tool_list.append("container")
|
||||
|
||||
if self.tool_server is not None:
|
||||
available_tools = builtin_tool_list
|
||||
@ -451,8 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
try:
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||
request.request_id)
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||
async for _ in result_generator:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
@ -714,21 +710,13 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# New conversation.
|
||||
reasoning_effort = (request.reasoning.effort
|
||||
if request.reasoning else None)
|
||||
# Temporary: OpenAI types doesn't have container tool
|
||||
# so we used MCP to cover that, up for change
|
||||
tool_types = [tool.type for tool in request.tools]
|
||||
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
|
||||
tool_types.append("container")
|
||||
enable_browser = ("web_search_preview" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("browser"))
|
||||
enable_code_interpreter = ("code_interpreter" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("python"))
|
||||
enable_container = ("container" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("container"))
|
||||
with_custom_tools = has_custom_tools(tool_types)
|
||||
sys_msg = get_system_message(
|
||||
reasoning_effort=reasoning_effort,
|
||||
browser_description=self.tool_server.get_tool_description(
|
||||
@ -737,17 +725,11 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
python_description=self.tool_server.get_tool_description(
|
||||
"python") if enable_code_interpreter
|
||||
and self.tool_server is not None else None,
|
||||
container_description=self.tool_server.get_tool_description(
|
||||
"container")
|
||||
if enable_container and self.tool_server is not None else None,
|
||||
instructions=request.instructions,
|
||||
with_custom_tools=with_custom_tools,
|
||||
)
|
||||
messages.append(sys_msg)
|
||||
if with_custom_tools:
|
||||
dev_msg = get_developer_message(
|
||||
instructions=request.instructions, tools=request.tools)
|
||||
messages.append(dev_msg)
|
||||
dev_msg = get_developer_message(request.instructions,
|
||||
request.tools)
|
||||
messages.append(dev_msg)
|
||||
else:
|
||||
# Continue the previous conversation.
|
||||
# FIXME(woosuk): Currently, request params like reasoning and
|
||||
@ -1631,8 +1613,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
processer = None
|
||||
if self.use_harmony:
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||
request.request_id)
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||
processer = self._process_harmony_streaming_events
|
||||
else:
|
||||
processer = self._process_simple_streaming_events
|
||||
|
||||
@ -86,8 +86,7 @@ class ToolServer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def new_session(self, tool_name: str,
|
||||
session_id: str) -> AbstractAsyncContextManager[Any]:
|
||||
def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
|
||||
"""
|
||||
Create a session for the tool.
|
||||
"""
|
||||
@ -125,8 +124,7 @@ class MCPToolServer(ToolServer):
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema)
|
||||
for tool in list_tools_response.tools
|
||||
],
|
||||
)
|
||||
])
|
||||
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
|
||||
if tool_from_mcp.name not in self.urls:
|
||||
self.urls[tool_from_mcp.name] = url
|
||||
@ -144,16 +142,14 @@ class MCPToolServer(ToolServer):
|
||||
return self.harmony_tool_descriptions.get(tool_name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self, tool_name: str, session_id: str):
|
||||
async def new_session(self, tool_name: str):
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
url = self.urls.get(tool_name)
|
||||
headers = {"x-session-id": session_id}
|
||||
if not url:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
async with sse_client(url=url,
|
||||
headers=headers) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
async with sse_client(url=url) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
@ -186,7 +182,7 @@ class DemoToolServer(ToolServer):
|
||||
raise ValueError(f"Unknown tool {tool_name}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self, tool_name: str, session_id: str):
|
||||
async def new_session(self, tool_name: str):
|
||||
if tool_name not in self.tools:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
yield self.tools[tool_name]
|
||||
|
||||
vllm/envs.py (11 changes)

@@ -168,8 +168,6 @@ if TYPE_CHECKING:
    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
    VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
    VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
    VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
    VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False

@@ -1203,15 +1201,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_TUNED_CONFIG_FOLDER":
    lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),

    # Allows vllm use container tool
    "VLLM_GPT_OSS_USE_CONTAINER_TOOL":
    lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),

    # Allows harmony instructions to be injected on system messages
    "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
    lambda: bool(
        int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))),

    # Add optional custom scopes for profiling, disable to avoid overheads
    "VLLM_CUSTOM_SCOPES_FOR_PROFILING":
    lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
|
||||
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
|
||||
round_up)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
@ -786,7 +786,6 @@ class FusedMoE(CustomOp):
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@ -798,10 +797,6 @@ class FusedMoE(CustomOp):
|
||||
dp_size_ = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
if self.is_sequence_parallel:
|
||||
self.sp_size = tp_size_
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
@ -1704,22 +1699,14 @@ class FusedMoE(CustomOp):
|
||||
|
||||
ctx = get_forward_context()
|
||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
||||
|
||||
# If the input to the MoE is sequence parallel then divide by sp_size
|
||||
# to find the maximum number of tokens for any individual dispatcher.
|
||||
if self.is_sequence_parallel:
|
||||
max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
|
||||
self.sp_size)
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
for chunk_idx, chunk_start_ in enumerate(
|
||||
range(0, max_tokens_across_dispatchers,
|
||||
moe_dp_chunk_size_per_rank)):
|
||||
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
|
||||
max_tokens_across_dispatchers)
|
||||
max_tokens_across_dp)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
|
||||
@ -37,6 +37,8 @@ class DeepseekV2Model(nn.Module):
|
||||
super().__init__()
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
|
||||
@ -49,8 +51,11 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DeepseekV2DecoderLayer(
|
||||
vllm_config,
|
||||
self.config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -43,19 +43,23 @@ class SharedHead(nn.Module):
|
||||
|
||||
class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
|
||||
cache_config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -91,8 +95,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
DeepSeekMultiTokenPredictorLayer(vllm_config,
|
||||
f"{prefix}.layers.{idx}")
|
||||
DeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
|
||||
@ -32,14 +32,12 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -57,9 +55,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -76,27 +72,19 @@ class DeepseekV2MLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
is_sequence_parallel=False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# If is_sequence_parallel, the input and output tensors are sharded
|
||||
# across the ranks within the tp_group. In this case the weights are
|
||||
# replicated and no collective ops are needed.
|
||||
# Otherwise we use standard TP with an allreduce at the end.
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
disable_tp=is_sequence_parallel,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
disable_tp=is_sequence_parallel,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@ -110,58 +98,17 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
# even though we explicitly pad to avoid this.
|
||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# all_gather needs the sequence length to be divisible by tp_size
|
||||
seq_len = x.size(0)
|
||||
remainder = seq_len % tp_size
|
||||
if remainder != 0:
|
||||
pad_len = tp_size - remainder
|
||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
||||
|
||||
chunk = x.shape[0] // tp_size
|
||||
start = tp_rank * chunk
|
||||
return torch.narrow(x, 0, start, chunk)
|
||||
|
||||
|
||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
seq_len = cdiv(x.size(0), tp_size)
|
||||
shape = list(x.shape)
|
||||
shape[0] = seq_len
|
||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
||||
return out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
mutates_args=[],
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):

    def __init__(
        self,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        parallel_config: ParallelConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
@ -170,21 +117,6 @@ class DeepseekV2MoE(nn.Module):
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        # The all_reduce at the end of attention (during o_proj) means that
        # inputs are replicated across each rank of the tensor parallel group.
        # If using expert parallelism with DeepEP All2All ops, replicated
        # tokens result in useless duplicate computation and communication.
        #
        # In this case, ensure the input to the experts is sequence parallel
        # to avoid the excess work.
        #
        # Not needed for pplx-kernels as it can handle duplicate input tokens.
        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
                                     in ("deepep_high_throughput",
                                         "deepep_low_latency")
                                     and parallel_config.enable_expert_parallel
                                     and self.tp_size > 1)

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
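For reference on the is_sequence_parallel flag set just above, a minimal sketch of when it turns on; the backend string, EP setting, and TP degree are assumed values standing in for envs.VLLM_ALL2ALL_BACKEND and the parallel config, not part of the diff:

all2all_backend = "deepep_low_latency"   # stands in for envs.VLLM_ALL2ALL_BACKEND
enable_expert_parallel = True            # parallel_config.enable_expert_parallel
tp_size = 4

is_sequence_parallel = (all2all_backend in ("deepep_high_throughput",
                                            "deepep_low_latency")
                        and enable_expert_parallel
                        and tp_size > 1)
assert is_sequence_parallel              # a pplx backend or tp_size == 1 leaves it False
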
@ -201,8 +133,9 @@ class DeepseekV2MoE(nn.Module):
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
@ -233,9 +166,7 @@ class DeepseekV2MoE(nn.Module):
                routed_scaling_factor=1.0,
                e_score_correction_bias=self.gate.e_score_correction_bias,
                enable_eplb=self.enable_eplb,
                num_redundant_experts=self.n_redundant_experts,
                is_sequence_parallel=self.is_sequence_parallel,
            )
                num_redundant_experts=self.n_redundant_experts)
            self.shared_experts = None
        else:
            intermediate_size = (config.moe_intermediate_size *
@ -246,7 +177,6 @@ class DeepseekV2MoE(nn.Module):
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
@ -269,22 +199,11 @@ class DeepseekV2MoE(nn.Module):
            routed_scaling_factor=1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )
            num_redundant_experts=self.n_redundant_experts)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
                hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

@ -309,11 +228,7 @@ class DeepseekV2MoE(nn.Module):
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
        if self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))
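The sequence-parallel chunk in forward() and the all_gather before returning are inverses up to padding; a standalone sketch with assumed shapes (not part of the diff) shows why the [:num_tokens] slice is needed:

import torch

tp_size, num_tokens, hidden = 4, 10, 8
padded = 12                                # 10 rounded up to a multiple of 4
chunks = [torch.randn(padded // tp_size, hidden) for _ in range(tp_size)]

# Gathering the per-rank chunks along dim 0 restores the padded length...
gathered = torch.cat(chunks, dim=0)
assert gathered.shape == (padded, hidden)

# ...so the trailing pad rows must be dropped to recover the original tokens.
restored = gathered[:num_tokens]
assert restored.shape == (num_tokens, hidden)
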
@ -617,15 +532,16 @@ class DeepseekV2MLAAttention(nn.Module):

class DeepseekV2DecoderLayer(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    def __init__(
        self,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
@ -662,9 +578,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = DeepseekV2MLP(
@ -734,7 +650,10 @@ class DeepseekV2Model(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        enable_eplb = vllm_config.parallel_config.enable_eplb
        self.config = config

        self.vocab_size = config.vocab_size
@ -750,7 +669,14 @@ class DeepseekV2Model(nn.Module):

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
            lambda prefix: DeepseekV2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:

@ -80,6 +80,10 @@ class EngineCore:

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if vllm_config.intermediate_log_config is not None:
            self.collective_rpc("register_intermediate_hooks",
                                args=(vllm_config.intermediate_log_config, ))

        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

@ -641,13 +641,7 @@ class WorkerProc:

    def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
        """Main busy loop for Multiprocessing Workers"""
        import os, psutil
        p = psutil.Process(os.getpid())
        i = 0
        while True:
            if i % 100 == 0:
                logger.info("WorkerProc RSS MB: %d",
                            p.memory_info().rss / 1024 / 1024)
            i += 1
            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
                cancel=cancel)
            try:

0    vllm/v1/intermediates/__init__.py    Normal file
405  vllm/v1/intermediates/intermediates_logging.py    Normal file
@ -0,0 +1,405 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.

This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""

from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional

import torch
from torch.utils.hooks import RemovableHandle

from vllm.config import IntermediateLoggingConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

# Global step counter
_CURRENT_STEP = 0

_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}

IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"

# Utility functions for intermediate logging


def should_log_step(config):
    """Check if the current step should be logged based on the step IDs.

    Args:
        config: The IntermediateLoggingConfig instance.

    Returns:
        True if the current step should be logged, False otherwise.
    """
    if not is_log_enabled(config):
        return False

    # If log_step_ids is empty, log all steps
    if not config.log_step_ids:
        return True

    # Otherwise, check if current step is in the set of step IDs to log
    return get_step() in config._step_id_set


def should_log_device(config, device_name):
    """Check if a device should be logged based on the device names.

    Args:
        config: The IntermediateLoggingConfig instance.
        device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').

    Returns:
        True if the device should be logged, False otherwise.
        If device_names is empty, all devices are logged.
    """
    if not is_log_enabled(config):
        return False
    # If device_names is empty, log all devices
    if not config.device_names:
        return True

    # Otherwise, check if device_name is in the list of device names to log
    return device_name in config.device_names


def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
    """Check if a module should be logged based on the name regex patterns.

    Args:
        config: The IntermediateLoggingConfig instance.
        module_name: The name of the module to check.
        module: The module itself; matched modules are tagged with their
            logging name and call index.

    Returns:
        True if the module should be logged, False otherwise.
        If no patterns are defined, all modules are logged.
        If patterns are defined, the module is logged if it matches ANY pattern.
    """
    if not is_log_enabled(config):
        return False
    # If no patterns are defined, log all modules
    if not config._compiled_module_calls:
        set_il_module_name(module, module_name)
        set_il_module_call_idx(module, -1)
        return True

    # Check if the module name matches any of the patterns
    for pattern, call_idx in config._compiled_module_calls.items():
        match = pattern.search(module_name)
        if match:
            logger.debug(
                "Module %s, %s matches pattern: '%s', call_idx=%s",
                module_name,
                module.__class__.__name__,
                pattern.pattern,
                call_idx,
            )
            set_il_module_name(module, module_name)
            set_il_module_call_idx(module, call_idx)
            return True
    return False

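To make the module filtering concrete, here is a hedged sketch of how a pattern-to-call-index mapping selects modules; the pattern and module names are illustrative assumptions, not the actual config schema:

import re

# Hypothetical filter: log only the MoE gate modules.
compiled_module_calls = {re.compile(r"mlp\.gate$"): 0}

for name in ("model.layers.0.mlp.gate", "model.layers.0.mlp.experts",
             "model.layers.1.mlp.gate"):
    matched = any(p.search(name) for p in compiled_module_calls)
    print(name, "->", "log" if matched else "skip")
# model.layers.0.mlp.gate -> log
# model.layers.0.mlp.experts -> skip
# model.layers.1.mlp.gate -> log
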
def is_log_enabled(config):
    if not config or not config.enabled:
        return False
    if torch.compiler.is_compiling():
        logger.debug("Not logging because torch.compile is in progress")
        return False
    return True


def get_il_module_name(module: torch.nn.Module) -> str:
    return getattr(module, IL_MODULE_NAME, module.__class__.__name__)


def get_il_module_call_idx(module: torch.nn.Module) -> int:
    return getattr(module, IL_MODULE_CALL_IDX, -1)


def set_il_module_name(module: torch.nn.Module, name: str) -> None:
    setattr(module, IL_MODULE_NAME, name)


def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
    setattr(module, IL_MODULE_CALL_IDX, idx)


_global_config: Optional[IntermediateLoggingConfig] = None


@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
    """
    Temporarily sets the global config for the duration of the context.

    :param config: The IntermediateLoggingConfig to install as the global
        config while the context is active.
    """
    global _global_config
    old_config = _global_config
    try:
        _global_config = config
        yield
    finally:
        _global_config = old_config


def get_current_il_config():
    return _global_config

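A minimal usage sketch of the context manager; the IntermediateLoggingConfig constructor arguments shown are assumptions, since the config class is defined elsewhere in this change:

# Hypothetical constructor arguments; the real field names may differ.
cfg = IntermediateLoggingConfig(enabled=True, output_run_dir="/tmp/il_run")

assert get_current_il_config() is None       # assuming nothing was active before
with intermediate_logging(cfg):
    assert get_current_il_config() is cfg    # hooks see cfg inside the context
assert get_current_il_config() is None       # previous value restored on exit
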
def save_tensors(tensor: Any, file_path: str) -> Any:
    """Utility function to dump tensor to a file.

    Args:
        tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
            tensors, or a dictionary containing tensors.
        file_path: Base path where to save the tensor (without extension).
    """

    if isinstance(tensor, torch.Tensor):
        device_name = str(tensor.device)
        intermediate_log_config = get_current_il_config()
        if not should_log_device(intermediate_log_config, device_name):
            return tensor
        pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
        try:
            torch.save(tensor, pt_path)
            logger.debug("Saved tensor of shape %s to %s", tensor.shape,
                         pt_path)
        except Exception as e:
            logger.warning("Failed to save tensor to %s: %s", pt_path, e)
        return tensor

    if isinstance(tensor, (list, tuple)):
        for i, item in enumerate(tensor):
            save_tensors(item, f"{file_path}_{i}")
        return tensor
    if isinstance(tensor, dict):
        for k, v in tensor.items():
            save_tensors(v, f"{file_path}_{k}")
        return tensor

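The recursion above also determines the on-disk file names; the following illustration (hypothetical base path and input structure) traces the f-string suffixes in the code:

# A hypothetical hook input: a tuple whose second element is a dict of tensors,
# all on cuda:0, saved under the base path "<module_dir>/inputs".
#   save_tensors((x, {"residual": r}), "<module_dir>/inputs")
# would produce files like:
#   <module_dir>/inputs_0_cuda_0.pt            # x (tuple index 0)
#   <module_dir>/inputs_1_residual_cuda_0.pt   # r (tuple index 1, dict key "residual")
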
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
             outputs: Any) -> None:
    """Hook to increment the global step counter after a forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    if get_current_il_config() is None:
        return
    # Increment the global step counter
    increment_step()
    global _CURRENT_STEP_MODULE_CALL_STEP
    _CURRENT_STEP_MODULE_CALL_STEP = {}


def _prepare_module_log_dir(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
    is_pre_fwd: bool = False,
) -> Path:
    # Create a unique directory for this step if it does not exist yet
    dump_dir = Path(
        intermediate_log_config.output_run_dir) / f"step_{get_step()}"
    dump_dir.mkdir(exist_ok=True, parents=True)

    # Create module directory
    suffix = ""
    module_call_idx = get_current_step_module_call(module_name)
    if module_call_idx > 0:
        suffix = f"_{module_call_idx}"
    module_dir = dump_dir / (module_name + suffix)
    if is_pre_fwd:
        _log_module_call(intermediate_log_config, module_name + suffix)
    module_dir.mkdir(exist_ok=True, parents=True)
    logger.debug("Logging module %s inputs/outputs to %s", module_name,
                 module_dir)
    return module_dir


def _log_module_call(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
) -> None:
    file = (Path(intermediate_log_config.output_run_dir) /
            f"step_{get_step()}" / "module_calls.txt")
    with open(file, "a") as f:
        f.write(f"{module_name}\n")

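Putting the step and module directories together, the resulting layout looks roughly like this; module names and step numbers are illustrative:

# <output_run_dir>/
#   step_3/
#     module_calls.txt                 # one line per logged module call
#     model.layers.0.mlp/              # first call of this module in the step
#       inputs_0_cuda_0.pt
#       outputs_cuda_0.pt
#     model.layers.0.mlp_1/            # a second call gets the "_1" suffix
#       ...
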
def update_current_step_module_call(module_name: str) -> None:
    logger.debug("Updating current step module call for %s", module_name)
    global _CURRENT_STEP_MODULE_CALL_STEP
    if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
        _CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
    else:
        _CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1


def get_current_step_module_call(module_name: str) -> int:
    return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)


def prepare_log_current_fwd(module,
                            is_pre_fwd: bool = False) -> Optional[Path]:
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None or not intermediate_log_config.enabled:
        return None
    if not should_log_step(intermediate_log_config):
        return None

    module_name = get_il_module_name(module)
    log_call_idx = get_il_module_call_idx(module)
    current_call_idx = get_current_step_module_call(module_name)
    should_log = True
    if log_call_idx >= 0 and current_call_idx != log_call_idx:
        should_log = False

    log_dir = None
    if is_pre_fwd:
        update_current_step_module_call(module_name)
    if should_log:
        log_dir = _prepare_module_log_dir(intermediate_log_config,
                                          module_name,
                                          is_pre_fwd=is_pre_fwd)
    return log_dir

def log_pre_fwd_hook(module: torch.nn.Module,
                     inputs: tuple[Any, ...]) -> tuple[Any, ...]:
    """Hook to capture module inputs before forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.

    Returns:
        The unchanged inputs.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
        save_tensors(inputs, str(log_dir / "inputs"))
    return inputs


def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
                      outputs: Any) -> None:
    """Hook to capture module outputs after forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
        save_tensors(outputs, str(log_dir / "outputs"))
        intermediate_log_config = get_current_il_config()
        assert intermediate_log_config is not None, \
            "IL config should not be None"
        if intermediate_log_config.log_post_fwd_inputs:
            save_tensors(inputs, str(log_dir / "post_fwd_inputs"))


def get_step() -> int:
    """Get the current global step counter.

    Returns:
        The current global step counter.
    """
    return _CURRENT_STEP


def increment_step() -> int:
    """Increment the global step counter.

    Returns:
        The new step counter value.
    """
    global _CURRENT_STEP
    _CURRENT_STEP += 1
    return _CURRENT_STEP


def reset_step() -> None:
    """Reset the global step counter to zero."""
    global _CURRENT_STEP
    _CURRENT_STEP = 0

class IntermediatesLogger:
    """Class to manage logging of intermediate tensors during model
    execution."""

    def __init__(self, config: IntermediateLoggingConfig):
        self.config = config
        self.hooks: list[tuple[str, torch.nn.Module,
                               Optional[RemovableHandle],
                               Optional[RemovableHandle]]] = []
        logger.debug("Created IntermediatesLogger with config: %s", config)
        path = Path(config.output_run_dir)
        path.mkdir(exist_ok=True, parents=True)
        # Log configuration
        logger.info("Intermediates will be logged in %s",
                    config.output_run_dir)

    def register_hooks(self, model: torch.nn.Module) -> None:
        """Register hooks for the model.

        Args:
            model: The PyTorch model to register hooks for.
        """

        for name, module in model.named_modules():
            if name and should_log_module(self.config, name, module):
                pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
                logger.debug("Registered pre_fwd hook for %s",
                             module.__class__.__name__)
                post_hook = module.register_forward_hook(log_post_fwd_hook)
                logger.debug("Registered post_fwd hook for %s",
                             module.__class__.__name__)
                self.hooks.append((name, module, pre_hook, post_hook))

        # Register a step counter hook for the root model
        step_hook = model.register_forward_hook(step_fwd)
        self.hooks.append(("", model, None, step_hook))
        logger.info("Registered hooks for %s modules", len(self.hooks))

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for _, _, pre_hook, post_hook in self.hooks:
            if pre_hook is not None:
                pre_hook.remove()
            if post_hook is not None:
                post_hook.remove()

        logger.info("Removed %s hooks", len(self.hooks))
        self.hooks = []


def register_intermediate_hooks(
    model: torch.nn.Module,
    config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
    """Register hooks to log intermediate tensors for a model.

    Args:
        model: The PyTorch model to log intermediates for.
        config: Configuration for intermediate logging.

    Returns:
        An IntermediatesLogger instance that can be used to manage the hooks.
    """
    logger_instance = IntermediatesLogger(config)
    logger_instance.register_hooks(model)
    return logger_instance

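A minimal end-to-end sketch of the hook lifecycle on a toy module; the IntermediateLoggingConfig arguments are assumptions, since its fields are defined elsewhere in this change:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Hypothetical config values; only enabled/output_run_dir are used directly here.
cfg = IntermediateLoggingConfig(enabled=True, output_run_dir="/tmp/il_demo")

il = register_intermediate_hooks(model, cfg)
with intermediate_logging(cfg):          # hooks consult the active config
    model(torch.randn(4, 8))             # tensors land under /tmp/il_demo/step_*/
il.remove_hooks()                        # detach all pre/post forward hooks
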
@ -27,6 +27,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, ModelRunnerOutput)
@ -362,9 +363,9 @@ class Worker(WorkerBase):
        intermediate_tensors = IntermediateTensors(
            get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        with intermediate_logging(self.vllm_config.intermediate_log_config):
            output = self.model_runner.execute_model(scheduler_output,
                                                     intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

@ -6,8 +6,10 @@ from typing import Optional
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.intermediates.intermediates_logging import (
    register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

@ -63,3 +65,26 @@ class WorkerBase(WorkerBaseV0):
    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return

    def register_intermediate_hooks(
            self, config: Optional[IntermediateLoggingConfig] = None) -> None:
        """Register hooks for intermediate tensor logging.

        This method is called via collective_rpc from the engine core.
        It registers hooks on the model to dump intermediate tensors during
        execution.

        Args:
            config: Configuration for intermediate logging.
        """
        if self.model_runner is None or not hasattr(
                self.model_runner, "model") or self.model_runner.model is None:
            logger.error("Could not register intermediate hooks: "
                         "model_runner.model is not accessible")
            return
        model = self.model_runner.model
        try:
            register_intermediate_hooks(model, config)
        except Exception:
            logger.exception("Error registering intermediate hooks")

@ -129,6 +129,22 @@ class WorkerBase:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()

    def register_intermediate_hooks(self, config=None) -> None:
        """Register hooks for intermediate tensor logging.

        This method is a stub for v0 workers. The actual implementation is
        in v1 workers. It's included here for compatibility with the
        collective_rpc mechanism.

        Args:
            config: Configuration for intermediate logging.
        """
        logger.warning(
            "register_intermediate_hooks is not implemented in v0 workers. "
            "This is only available in v1 workers. No hooks will be registered."
        )
        return None

    def shutdown(self) -> None:
        """Clean up resources held by the worker."""
        return
