Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e793b9a70c | |||
| 76c9ec0ddf | |||
| 87c737016d | |||
| ba90794ff1 | |||
| ab4ab0fd28 | |||
| 2af83ebdde | |||
| d8bff253d7 |
@ -149,25 +149,3 @@ steps:
|
|||||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
|
|
||||||
- label: "Build and publish nightly multi-arch image to DockerHub"
|
|
||||||
depends_on:
|
|
||||||
- create-multi-arch-manifest
|
|
||||||
if: build.env("NIGHTLY") == "1"
|
|
||||||
agents:
|
|
||||||
queue: cpu_queue_postmerge
|
|
||||||
commands:
|
|
||||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
|
||||||
- "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
|
||||||
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly"
|
|
||||||
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
|
|
||||||
- "docker push vllm/vllm-openai:nightly"
|
|
||||||
- "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
|
|
||||||
# Clean up old nightly builds (keep only last 14)
|
|
||||||
- "bash .buildkite/scripts/cleanup-nightly-builds.sh"
|
|
||||||
plugins:
|
|
||||||
- docker-login#v3.0.0:
|
|
||||||
username: vllmbot
|
|
||||||
password-env: DOCKERHUB_TOKEN
|
|
||||||
env:
|
|
||||||
DOCKER_BUILDKIT: "1"
|
|
||||||
|
|||||||
@ -1,97 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -ex
|
|
||||||
|
|
||||||
# Clean up old nightly builds from DockerHub, keeping only the last 14 builds
|
|
||||||
# This script uses DockerHub API to list and delete old tags with "nightly-" prefix
|
|
||||||
|
|
||||||
# DockerHub API endpoint for vllm/vllm-openai repository
|
|
||||||
REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"
|
|
||||||
|
|
||||||
# Get DockerHub token from environment
|
|
||||||
if [ -z "$DOCKERHUB_TOKEN" ]; then
|
|
||||||
echo "Error: DOCKERHUB_TOKEN environment variable is not set"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Function to get all tags from DockerHub
|
|
||||||
get_all_tags() {
|
|
||||||
local page=1
|
|
||||||
local all_tags=""
|
|
||||||
|
|
||||||
while true; do
|
|
||||||
local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \
|
|
||||||
"$REPO_API_URL?page=$page&page_size=100")
|
|
||||||
|
|
||||||
# Get both last_updated timestamp and tag name, separated by |
|
|
||||||
local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')
|
|
||||||
|
|
||||||
if [ -z "$tags" ]; then
|
|
||||||
break
|
|
||||||
fi
|
|
||||||
|
|
||||||
all_tags="$all_tags$tags"$'\n'
|
|
||||||
page=$((page + 1))
|
|
||||||
done
|
|
||||||
|
|
||||||
# Sort by timestamp (newest first) and extract just the tag names
|
|
||||||
echo "$all_tags" | sort -r | cut -d'|' -f2
|
|
||||||
}
|
|
||||||
|
|
||||||
delete_tag() {
|
|
||||||
local tag_name="$1"
|
|
||||||
echo "Deleting tag: $tag_name"
|
|
||||||
|
|
||||||
local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
|
|
||||||
local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url")
|
|
||||||
|
|
||||||
if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
|
|
||||||
echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
|
|
||||||
else
|
|
||||||
echo "Successfully deleted tag: $tag_name"
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first)
|
|
||||||
echo "Fetching all tags from DockerHub..."
|
|
||||||
all_tags=$(get_all_tags)
|
|
||||||
|
|
||||||
if [ -z "$all_tags" ]; then
|
|
||||||
echo "No tags found to clean up"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Count total tags
|
|
||||||
total_tags=$(echo "$all_tags" | wc -l)
|
|
||||||
echo "Found $total_tags tags"
|
|
||||||
|
|
||||||
# Keep only the last 14 builds (including the current one)
|
|
||||||
tags_to_keep=14
|
|
||||||
tags_to_delete=$((total_tags - tags_to_keep))
|
|
||||||
|
|
||||||
if [ $tags_to_delete -le 0 ]; then
|
|
||||||
echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep"
|
|
||||||
|
|
||||||
# Get tags to delete (skip the first $tags_to_keep tags)
|
|
||||||
tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1)))
|
|
||||||
|
|
||||||
if [ -z "$tags_to_delete_list" ]; then
|
|
||||||
echo "No tags to delete"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Delete old tags
|
|
||||||
echo "Deleting old tags..."
|
|
||||||
while IFS= read -r tag; do
|
|
||||||
if [ -n "$tag" ]; then
|
|
||||||
delete_tag "$tag"
|
|
||||||
# Add a small delay to avoid rate limiting
|
|
||||||
sleep 1
|
|
||||||
fi
|
|
||||||
done <<< "$tags_to_delete_list"
|
|
||||||
|
|
||||||
echo "Cleanup completed successfully"
|
|
||||||
14
.github/mergify.yml
vendored
14
.github/mergify.yml
vendored
@ -273,20 +273,6 @@ pull_request_rules:
|
|||||||
users:
|
users:
|
||||||
- "sangstar"
|
- "sangstar"
|
||||||
|
|
||||||
- name: assign reviewer for modelopt changes
|
|
||||||
conditions:
|
|
||||||
- or:
|
|
||||||
- files~=^vllm/model_executor/layers/quantization/modelopt\.py$
|
|
||||||
- files~=^vllm/model_executor/layers/quantization/__init__\.py$
|
|
||||||
- files~=^tests/models/quantization/test_modelopt\.py$
|
|
||||||
- files~=^tests/quantization/test_modelopt\.py$
|
|
||||||
- files~=^tests/models/quantization/test_nvfp4\.py$
|
|
||||||
- files~=^docs/features/quantization/modelopt\.md$
|
|
||||||
actions:
|
|
||||||
assign:
|
|
||||||
users:
|
|
||||||
- "Edwardf0t1"
|
|
||||||
|
|
||||||
- name: remove 'needs-rebase' label when conflict is resolved
|
- name: remove 'needs-rebase' label when conflict is resolved
|
||||||
conditions:
|
conditions:
|
||||||
- -conflict
|
- -conflict
|
||||||
|
|||||||
2
.github/workflows/cleanup_pr_body.yml
vendored
2
.github/workflows/cleanup_pr_body.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.12"
|
||||||
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||||
|
|||||||
65
docs/contributing/intermediate_logging.md
Normal file
65
docs/contributing/intermediate_logging.md
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Intermediate Tensor Logging
|
||||||
|
|
||||||
|
This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The intermediate tensor logging feature enables you to:
|
||||||
|
|
||||||
|
- Log input and output tensors from a configured set of filters
|
||||||
|
- Filter modules by name using regex patterns
|
||||||
|
- Filter module fwd call index (e.g. dump 2nd call of forward pass on same module)
|
||||||
|
- Filter tensors by device
|
||||||
|
- Filter whole model fwd step id
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Enabling via parameters or config file
|
||||||
|
|
||||||
|
**Offline Inference example**
|
||||||
|
|
||||||
|
Dump all modules, all devices for step 0 (default behavior)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Dump first layers module, all devices for step 0
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
#### Configuration Parameters
|
||||||
|
|
||||||
|
| Parameter | Type | Description | Default |
|
||||||
|
|-----------|------|-------------|---------|
|
||||||
|
| `output_dir` | string | Directory where to save the intermediate tensors | `/tmp/vllm_intermediates` |
|
||||||
|
| `module_call_match` | array | Regex patterns to filter module names, if limti to ith call only, add `:i` | `null` (log all modules) |
|
||||||
|
| `log_step_ids` | array | List of step IDs to log | `[0]` |
|
||||||
|
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
|
||||||
|
| `device_names` | array | List of device names to log | `[]` (log all devices) |
|
||||||
|
|
||||||
|
### Output Directory Structure
|
||||||
|
|
||||||
|
When you enable intermediate logging, the system creates a timestamped directory under your specified `output_dir`. This helps organize multiple logging sessions:
|
||||||
|
|
||||||
|
```
|
||||||
|
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
|
||||||
|
└── step_0
|
||||||
|
├── model.embed_tokens
|
||||||
|
│ ├── inputs_0_cuda_0.pt
|
||||||
|
│ ├── inputs.json
|
||||||
|
│ ├── outputs_cuda_0.pt
|
||||||
|
│ └── outputs.json
|
||||||
|
├── model.layers.0.input_layernorm
|
||||||
|
│ ├── inputs_0_cuda_0.pt
|
||||||
|
│ ├── inputs.json
|
||||||
|
│ ├── outputs_cuda_0.pt
|
||||||
|
│ └── outputs.json
|
||||||
|
└── step_1/
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Each tensor is saved in a `.pt` file containing the full PyTorch tensors (can be loaded with `torch.load()`)
|
||||||
320
tests/v1/test_intermediates_logging.py
Normal file
320
tests/v1/test_intermediates_logging.py
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Tests for the intermediate tensor logging functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import IntermediateLoggingConfig
|
||||||
|
from vllm.v1.intermediates.intermediates_logging import (
|
||||||
|
get_current_il_config, get_step, increment_step, intermediate_logging,
|
||||||
|
register_intermediate_hooks, reset_step, should_log_device,
|
||||||
|
should_log_module, should_log_step)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleModel(nn.Module):
|
||||||
|
"""A simple model for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(10, 20)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.linear2 = nn.Linear(20, 5)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_output_dir():
|
||||||
|
"""Create a temporary directory for test outputs."""
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
yield temp_dir
|
||||||
|
# Clean up after the test
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def simple_model():
|
||||||
|
"""Create a simple model for testing."""
|
||||||
|
return SimpleModel()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def il_config(temp_output_dir):
|
||||||
|
"""Create a basic IntermediateLoggingConfig for testing."""
|
||||||
|
return IntermediateLoggingConfig(output_dir=temp_output_dir,
|
||||||
|
enabled=True,
|
||||||
|
log_step_ids=[0, 1],
|
||||||
|
module_call_match=[".*linear.*"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_step_counter():
|
||||||
|
"""Test the step counter functionality."""
|
||||||
|
# Reset the step counter
|
||||||
|
reset_step()
|
||||||
|
assert get_step() == 0
|
||||||
|
|
||||||
|
# Increment the step counter
|
||||||
|
increment_step()
|
||||||
|
assert get_step() == 1
|
||||||
|
|
||||||
|
# Increment again
|
||||||
|
increment_step()
|
||||||
|
assert get_step() == 2
|
||||||
|
|
||||||
|
# Reset again
|
||||||
|
reset_step()
|
||||||
|
assert get_step() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_intermediate_logging_context_manager():
|
||||||
|
"""Test the intermediate_logging context manager."""
|
||||||
|
# Create a config
|
||||||
|
config = IntermediateLoggingConfig(enabled=True)
|
||||||
|
|
||||||
|
# Initially, there should be no global config
|
||||||
|
assert get_current_il_config() is None
|
||||||
|
|
||||||
|
# Use the context manager
|
||||||
|
with intermediate_logging(config):
|
||||||
|
# Inside the context, the global config should be set
|
||||||
|
assert get_current_il_config() is not None
|
||||||
|
assert get_current_il_config().enabled is True
|
||||||
|
|
||||||
|
# After the context, the global config should be None again
|
||||||
|
assert get_current_il_config() is None
|
||||||
|
|
||||||
|
# Test with a different config
|
||||||
|
config2 = IntermediateLoggingConfig(enabled=False)
|
||||||
|
with intermediate_logging(config2):
|
||||||
|
assert get_current_il_config() is not None
|
||||||
|
assert get_current_il_config().enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_log_step():
|
||||||
|
"""Test the should_log_step function."""
|
||||||
|
# Reset step counter
|
||||||
|
reset_step()
|
||||||
|
|
||||||
|
# Create configs with different step IDs
|
||||||
|
config_all_steps = IntermediateLoggingConfig(
|
||||||
|
enabled=True,
|
||||||
|
log_step_ids=[] # Empty list means log all steps
|
||||||
|
)
|
||||||
|
config_specific_steps = IntermediateLoggingConfig(
|
||||||
|
enabled=True,
|
||||||
|
log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4
|
||||||
|
)
|
||||||
|
config_disabled = IntermediateLoggingConfig(enabled=False,
|
||||||
|
log_step_ids=[0, 1, 2])
|
||||||
|
|
||||||
|
# Test with all steps config
|
||||||
|
with intermediate_logging(config_all_steps):
|
||||||
|
assert should_log_step(config_all_steps) is True # Step 0
|
||||||
|
increment_step()
|
||||||
|
assert should_log_step(config_all_steps) is True # Step 1
|
||||||
|
|
||||||
|
# Reset step counter
|
||||||
|
reset_step()
|
||||||
|
|
||||||
|
# Test with specific steps config
|
||||||
|
with intermediate_logging(config_specific_steps):
|
||||||
|
assert should_log_step(config_specific_steps) is True # Step 0
|
||||||
|
increment_step()
|
||||||
|
assert should_log_step(config_specific_steps) is False # Step 1
|
||||||
|
increment_step()
|
||||||
|
assert should_log_step(config_specific_steps) is True # Step 2
|
||||||
|
increment_step()
|
||||||
|
assert should_log_step(config_specific_steps) is False # Step 3
|
||||||
|
increment_step()
|
||||||
|
assert should_log_step(config_specific_steps) is True # Step 4
|
||||||
|
|
||||||
|
# Test with disabled config
|
||||||
|
with intermediate_logging(config_disabled):
|
||||||
|
assert should_log_step(config_disabled) is False # Disabled
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_log_device():
|
||||||
|
"""Test the should_log_device function."""
|
||||||
|
# Create configs with different device filters
|
||||||
|
config_all_devices = IntermediateLoggingConfig(
|
||||||
|
enabled=True,
|
||||||
|
device_names=[] # Empty list means log all devices
|
||||||
|
)
|
||||||
|
config_specific_devices = IntermediateLoggingConfig(
|
||||||
|
enabled=True,
|
||||||
|
device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu
|
||||||
|
)
|
||||||
|
config_disabled = IntermediateLoggingConfig(enabled=False,
|
||||||
|
device_names=["cuda:0", "cpu"])
|
||||||
|
|
||||||
|
# Test with all devices config
|
||||||
|
with intermediate_logging(config_all_devices):
|
||||||
|
assert should_log_device(config_all_devices, "cuda:0") is True
|
||||||
|
assert should_log_device(config_all_devices, "cuda:1") is True
|
||||||
|
assert should_log_device(config_all_devices, "cpu") is True
|
||||||
|
|
||||||
|
# Test with specific devices config
|
||||||
|
with intermediate_logging(config_specific_devices):
|
||||||
|
assert should_log_device(config_specific_devices, "cuda:0") is True
|
||||||
|
assert should_log_device(config_specific_devices, "cuda:1") is False
|
||||||
|
assert should_log_device(config_specific_devices, "cpu") is True
|
||||||
|
|
||||||
|
# Test with disabled config
|
||||||
|
with intermediate_logging(config_disabled):
|
||||||
|
assert should_log_device(config_disabled, "cuda:0") is False
|
||||||
|
assert should_log_device(config_disabled, "cpu") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_log_module(simple_model):
|
||||||
|
"""Test the should_log_module function."""
|
||||||
|
# Create configs with different module name filters
|
||||||
|
config_all_modules = IntermediateLoggingConfig(
|
||||||
|
enabled=True,
|
||||||
|
module_call_match=None # None means log all modules
|
||||||
|
)
|
||||||
|
config_specific_modules = IntermediateLoggingConfig(
|
||||||
|
enabled=True,
|
||||||
|
module_call_match=[".*linear.*"
|
||||||
|
] # Only log modules with "linear" in the name
|
||||||
|
)
|
||||||
|
config_disabled = IntermediateLoggingConfig(enabled=False,
|
||||||
|
module_call_match=[".*"])
|
||||||
|
|
||||||
|
# Test with all modules config
|
||||||
|
with intermediate_logging(config_all_modules):
|
||||||
|
assert should_log_module(config_all_modules, "linear1",
|
||||||
|
simple_model.linear1) is True
|
||||||
|
assert should_log_module(config_all_modules, "relu",
|
||||||
|
simple_model.relu) is True
|
||||||
|
|
||||||
|
# Test with specific modules config
|
||||||
|
with intermediate_logging(config_specific_modules):
|
||||||
|
assert should_log_module(config_specific_modules, "linear1",
|
||||||
|
simple_model.linear1) is True
|
||||||
|
assert should_log_module(config_specific_modules, "relu",
|
||||||
|
simple_model.relu) is False
|
||||||
|
|
||||||
|
# Test with disabled config
|
||||||
|
with intermediate_logging(config_disabled):
|
||||||
|
assert should_log_module(config_disabled, "linear1",
|
||||||
|
simple_model.linear1) is False
|
||||||
|
assert should_log_module(config_disabled, "relu",
|
||||||
|
simple_model.relu) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_hooks(simple_model, il_config):
|
||||||
|
"""Test registering hooks on a model."""
|
||||||
|
# Register hooks
|
||||||
|
logger_instance = register_intermediate_hooks(simple_model, il_config)
|
||||||
|
|
||||||
|
# Check that hooks were registered
|
||||||
|
assert len(logger_instance.hooks) > 0
|
||||||
|
|
||||||
|
# Remove hooks
|
||||||
|
logger_instance.remove_hooks()
|
||||||
|
|
||||||
|
# Check that hooks were removed
|
||||||
|
assert len(logger_instance.hooks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
|
||||||
|
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
|
||||||
|
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
|
||||||
|
il_config, temp_output_dir):
|
||||||
|
"""Test that forward hooks are called during model execution."""
|
||||||
|
mock_save_tensors.return_value = None
|
||||||
|
# Register hooks
|
||||||
|
with intermediate_logging(il_config):
|
||||||
|
logger_instance = register_intermediate_hooks(simple_model, il_config)
|
||||||
|
|
||||||
|
# Create input tensor
|
||||||
|
input_tensor = torch.randn(2, 10)
|
||||||
|
|
||||||
|
# Reset step counter
|
||||||
|
reset_step()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
simple_model(input_tensor)
|
||||||
|
|
||||||
|
# Check that the step counter was incremented
|
||||||
|
assert get_step() == 1
|
||||||
|
|
||||||
|
# Check that dump_intermediates_to_json and save_tensors were called
|
||||||
|
assert mock_dump_json.called
|
||||||
|
assert mock_save_tensors.called
|
||||||
|
|
||||||
|
# Remove hooks
|
||||||
|
logger_instance.remove_hooks()
|
||||||
|
|
||||||
|
|
||||||
|
def test_end_to_end(simple_model, il_config, temp_output_dir):
|
||||||
|
"""Test the entire intermediate logging workflow end-to-end."""
|
||||||
|
# Register hooks
|
||||||
|
with intermediate_logging(il_config):
|
||||||
|
logger_instance = register_intermediate_hooks(simple_model, il_config)
|
||||||
|
|
||||||
|
# Create input tensor
|
||||||
|
input_tensor = torch.randn(2, 10)
|
||||||
|
|
||||||
|
# Reset step counter
|
||||||
|
reset_step()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
simple_model(input_tensor)
|
||||||
|
|
||||||
|
# Check that output directories were created
|
||||||
|
root_dir = Path(il_config._output_run_dir)
|
||||||
|
assert root_dir.exists()
|
||||||
|
step_dir = root_dir / "step_0"
|
||||||
|
assert step_dir.exists()
|
||||||
|
|
||||||
|
module_dirs = list(step_dir.glob("*"))
|
||||||
|
print(f"{module_dirs=}")
|
||||||
|
assert len(module_dirs) > 0
|
||||||
|
|
||||||
|
# Check that input and output files were created
|
||||||
|
for module_dir in module_dirs:
|
||||||
|
print(f"{module_dir=}")
|
||||||
|
if os.path.isdir(module_dir):
|
||||||
|
inputs_json = module_dir / "inputs.json"
|
||||||
|
outputs_json = module_dir / "outputs.json"
|
||||||
|
|
||||||
|
# Check that JSON files exist
|
||||||
|
assert inputs_json.exists()
|
||||||
|
assert outputs_json.exists()
|
||||||
|
|
||||||
|
# Check that JSON files contain valid data
|
||||||
|
with open(inputs_json) as f:
|
||||||
|
inputs_data = json.load(f)
|
||||||
|
assert "type" in inputs_data
|
||||||
|
|
||||||
|
with open(outputs_json) as f:
|
||||||
|
outputs_data = json.load(f)
|
||||||
|
assert "type" in outputs_data
|
||||||
|
|
||||||
|
# Check that tensor files exist
|
||||||
|
tensor_files = list(module_dir.glob("*.pt"))
|
||||||
|
assert len(tensor_files) > 0
|
||||||
|
|
||||||
|
# Remove hooks
|
||||||
|
logger_instance.remove_hooks()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(["-xvs", __file__])
|
||||||
@ -3311,6 +3311,119 @@ class KVTransferConfig:
|
|||||||
return self.kv_connector_extra_config.get(key, default)
|
return self.kv_connector_extra_config.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
|
class IntermediateLoggingConfig:
|
||||||
|
"""Configuration for intermediate tensor logging."""
|
||||||
|
|
||||||
|
output_dir: str = "/tmp/vllm_intermediates"
|
||||||
|
"""Directory where to save the intermediate tensors."""
|
||||||
|
|
||||||
|
module_call_match: Optional[list[str]] = None
|
||||||
|
"""Match modules by name regex and call index (
|
||||||
|
a module can be called multiple times in a step)
|
||||||
|
List of regex:call_idx, call_idx is -1 for default for all calls """
|
||||||
|
|
||||||
|
log_step_ids: list[int] = field(default_factory=lambda: [0])
|
||||||
|
"""List of step IDs to log (empty list means log all steps)."""
|
||||||
|
|
||||||
|
log_post_fwd_inputs: bool = False
|
||||||
|
"""Whether logging inputs after forwards for each module"""
|
||||||
|
|
||||||
|
max_tensor_size: Optional[int] = None
|
||||||
|
"""Maximum number of elements in tensors to log (None = no limit)."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
"""Whether logging is enabled."""
|
||||||
|
device_names: list[str] = field(default_factory=list)
|
||||||
|
"""List of device names to log (empty list means log all devices)."""
|
||||||
|
|
||||||
|
_compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
|
||||||
|
init=False)
|
||||||
|
"""Compiled regex patterns for module filtering."""
|
||||||
|
|
||||||
|
_module_call: dict[str, int] = field(default_factory=dict, init=False)
|
||||||
|
_step_id_set: set[int] = field(default_factory=set, init=False)
|
||||||
|
"""Set of step IDs for faster lookup."""
|
||||||
|
_output_run_dir: str = "/tmp/vllm_intermediates"
|
||||||
|
"""Unique directory to save single run/serve logging result."""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Initialize derived fields after instance creation."""
|
||||||
|
self._compile_regex_patterns()
|
||||||
|
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
|
||||||
|
self._step_id_set = set(self.log_step_ids)
|
||||||
|
|
||||||
|
def _compile_regex_patterns(self):
|
||||||
|
"""Compile regex patterns for module name filtering."""
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
self._compiled_module_matches = []
|
||||||
|
|
||||||
|
if self.module_call_match is None:
|
||||||
|
logger.info(
|
||||||
|
"No module name regex patterns provided, will log all modules")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Compile all patterns
|
||||||
|
for regex_pattern_call_idx in self.module_call_match:
|
||||||
|
try:
|
||||||
|
splits = regex_pattern_call_idx.split(":", 2)
|
||||||
|
regex_pattern = splits[0]
|
||||||
|
call_idx = -1
|
||||||
|
if len(splits) > 1:
|
||||||
|
call_idx = int(splits[1])
|
||||||
|
compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
|
||||||
|
self._compiled_module_calls[compiled_pattern] = call_idx
|
||||||
|
logger.info("Successfully compiled regex pattern: '%s'",
|
||||||
|
regex_pattern)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to parse module_call_match '%s': %s",
|
||||||
|
regex_pattern_call_idx, e)
|
||||||
|
|
||||||
|
logger.info("Compiled %d regex patterns",
|
||||||
|
len(self._compiled_module_calls))
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert the config to a dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"output_run_dir": self.output_run_dir,
|
||||||
|
"module_call_match": self.module_call_match,
|
||||||
|
"log_step_ids": self.log_step_ids,
|
||||||
|
"max_tensor_size": self.max_tensor_size,
|
||||||
|
"enabled": self.enabled,
|
||||||
|
"device_names": self.device_names
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
|
||||||
|
"""Parse the CLI value for the speculative config."""
|
||||||
|
return cls(**dict_value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_run_dir(self) -> str:
|
||||||
|
return self._output_run_dir
|
||||||
|
|
||||||
|
def compute_hash(self) -> str:
|
||||||
|
"""
|
||||||
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
ensure that it is included in the factors list if
|
||||||
|
it affects the computation graph.
|
||||||
|
|
||||||
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
that affect the structure of the computation
|
||||||
|
graph from input ids/embeddings to the final hidden states,
|
||||||
|
excluding anything before input ids/embeddings and after
|
||||||
|
the final hidden states.
|
||||||
|
"""
|
||||||
|
# Intermediate logging doesn't affect the computation graph
|
||||||
|
factors: list[Any] = []
|
||||||
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
|
usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
class VllmConfig:
|
class VllmConfig:
|
||||||
@ -3362,6 +3475,8 @@ class VllmConfig:
|
|||||||
"""The configurations for distributed KV cache transfer."""
|
"""The configurations for distributed KV cache transfer."""
|
||||||
kv_events_config: Optional[KVEventsConfig] = None
|
kv_events_config: Optional[KVEventsConfig] = None
|
||||||
"""The configurations for event publishing."""
|
"""The configurations for event publishing."""
|
||||||
|
intermediate_log_config: Optional[IntermediateLoggingConfig] = None
|
||||||
|
"""Configuration for intermediate tensor logging."""
|
||||||
# some opaque config, only used to provide additional information
|
# some opaque config, only used to provide additional information
|
||||||
# for the hash computation, mainly used for testing, debugging or out of
|
# for the hash computation, mainly used for testing, debugging or out of
|
||||||
# tree config registration.
|
# tree config registration.
|
||||||
@ -3446,6 +3561,10 @@ class VllmConfig:
|
|||||||
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
||||||
else:
|
else:
|
||||||
vllm_factors.append("None")
|
vllm_factors.append("None")
|
||||||
|
if self.intermediate_log_config:
|
||||||
|
vllm_factors.append(self.intermediate_log_config.compute_hash())
|
||||||
|
else:
|
||||||
|
vllm_factors.append("None")
|
||||||
if self.additional_config:
|
if self.additional_config:
|
||||||
if isinstance(additional_config := self.additional_config, dict):
|
if isinstance(additional_config := self.additional_config, dict):
|
||||||
additional_config_hash = hashlib.md5(
|
additional_config_hash = hashlib.md5(
|
||||||
|
|||||||
@ -409,6 +409,7 @@ class EngineArgs:
|
|||||||
|
|
||||||
speculative_config: Optional[Dict[str, Any]] = None
|
speculative_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
show_hidden_metrics_for_version: Optional[str] = \
|
show_hidden_metrics_for_version: Optional[str] = \
|
||||||
ObservabilityConfig.show_hidden_metrics_for_version
|
ObservabilityConfig.show_hidden_metrics_for_version
|
||||||
otlp_traces_endpoint: Optional[str] = \
|
otlp_traces_endpoint: Optional[str] = \
|
||||||
@ -456,6 +457,8 @@ class EngineArgs:
|
|||||||
|
|
||||||
async_scheduling: bool = SchedulerConfig.async_scheduling
|
async_scheduling: bool = SchedulerConfig.async_scheduling
|
||||||
|
|
||||||
|
intermediate_log_config: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
kv_sharing_fast_prefill: bool = \
|
kv_sharing_fast_prefill: bool = \
|
||||||
CacheConfig.kv_sharing_fast_prefill
|
CacheConfig.kv_sharing_fast_prefill
|
||||||
|
|
||||||
@ -883,6 +886,9 @@ class EngineArgs:
|
|||||||
title="VllmConfig",
|
title="VllmConfig",
|
||||||
description=VllmConfig.__doc__,
|
description=VllmConfig.__doc__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vllm_group.add_argument("--intermediate-log-config",
|
||||||
|
**vllm_kwargs["intermediate_log_config"])
|
||||||
# We construct SpeculativeConfig using fields from other configs in
|
# We construct SpeculativeConfig using fields from other configs in
|
||||||
# create_engine_config. So we set the type to a JSON string here to
|
# create_engine_config. So we set the type to a JSON string here to
|
||||||
# delay the Pydantic validation that comes with SpeculativeConfig.
|
# delay the Pydantic validation that comes with SpeculativeConfig.
|
||||||
@ -1394,7 +1400,6 @@ class EngineArgs:
|
|||||||
otlp_traces_endpoint=self.otlp_traces_endpoint,
|
otlp_traces_endpoint=self.otlp_traces_endpoint,
|
||||||
collect_detailed_traces=self.collect_detailed_traces,
|
collect_detailed_traces=self.collect_detailed_traces,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = VllmConfig(
|
config = VllmConfig(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
@ -1409,6 +1414,7 @@ class EngineArgs:
|
|||||||
compilation_config=self.compilation_config,
|
compilation_config=self.compilation_config,
|
||||||
kv_transfer_config=self.kv_transfer_config,
|
kv_transfer_config=self.kv_transfer_config,
|
||||||
kv_events_config=self.kv_events_config,
|
kv_events_config=self.kv_events_config,
|
||||||
|
intermediate_log_config=self.intermediate_log_config,
|
||||||
additional_config=self.additional_config,
|
additional_config=self.additional_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -59,14 +57,9 @@ class ConversationContext(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
exit_stack: AsyncExitStack,
|
exit_stack: AsyncExitStack) -> None:
|
||||||
request_id: str) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def cleanup_session(self) -> None:
|
|
||||||
raise NotImplementedError("Should not be called.")
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleContext(ConversationContext):
|
class SimpleContext(ConversationContext):
|
||||||
|
|
||||||
@ -96,13 +89,9 @@ class SimpleContext(ConversationContext):
|
|||||||
raise NotImplementedError("Should not be called.")
|
raise NotImplementedError("Should not be called.")
|
||||||
|
|
||||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
exit_stack: AsyncExitStack,
|
exit_stack: AsyncExitStack) -> None:
|
||||||
request_id: str) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def cleanup_session(self) -> None:
|
|
||||||
raise NotImplementedError("Should not be called.")
|
|
||||||
|
|
||||||
|
|
||||||
class HarmonyContext(ConversationContext):
|
class HarmonyContext(ConversationContext):
|
||||||
|
|
||||||
@ -114,7 +103,6 @@ class HarmonyContext(ConversationContext):
|
|||||||
self._messages = messages
|
self._messages = messages
|
||||||
self.available_tools = available_tools
|
self.available_tools = available_tools
|
||||||
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||||
self.called_tools: set[str] = set()
|
|
||||||
|
|
||||||
self.parser = get_streamable_parser_for_assistant()
|
self.parser = get_streamable_parser_for_assistant()
|
||||||
self.num_init_messages = len(messages)
|
self.num_init_messages = len(messages)
|
||||||
@ -246,8 +234,7 @@ class HarmonyContext(ConversationContext):
|
|||||||
last_msg = self.messages[-1]
|
last_msg = self.messages[-1]
|
||||||
recipient = last_msg.recipient
|
recipient = last_msg.recipient
|
||||||
return recipient is not None and (recipient.startswith("browser.")
|
return recipient is not None and (recipient.startswith("browser.")
|
||||||
or recipient.startswith("python") or
|
or recipient.startswith("python"))
|
||||||
recipient.startswith("container."))
|
|
||||||
|
|
||||||
async def call_tool(self) -> list[Message]:
|
async def call_tool(self) -> list[Message]:
|
||||||
if not self.messages:
|
if not self.messages:
|
||||||
@ -261,9 +248,6 @@ class HarmonyContext(ConversationContext):
|
|||||||
elif recipient.startswith("python"):
|
elif recipient.startswith("python"):
|
||||||
return await self.call_python_tool(
|
return await self.call_python_tool(
|
||||||
self._tool_sessions["python"], last_msg)
|
self._tool_sessions["python"], last_msg)
|
||||||
elif recipient.startswith("container."):
|
|
||||||
return await self.call_container_tool(
|
|
||||||
self._tool_sessions["container"], last_msg)
|
|
||||||
raise ValueError("No tool call found")
|
raise ValueError("No tool call found")
|
||||||
|
|
||||||
def render_for_completion(self) -> list[int]:
|
def render_for_completion(self) -> list[int]:
|
||||||
@ -272,7 +256,6 @@ class HarmonyContext(ConversationContext):
|
|||||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||||
Tool],
|
Tool],
|
||||||
last_msg: Message) -> list[Message]:
|
last_msg: Message) -> list[Message]:
|
||||||
self.called_tools.add("browser")
|
|
||||||
if isinstance(tool_session, Tool):
|
if isinstance(tool_session, Tool):
|
||||||
return await tool_session.get_result(self)
|
return await tool_session.get_result(self)
|
||||||
tool_name = last_msg.recipient.split(".")[1]
|
tool_name = last_msg.recipient.split(".")[1]
|
||||||
@ -282,16 +265,12 @@ class HarmonyContext(ConversationContext):
|
|||||||
content = TextContent(text=result_str)
|
content = TextContent(text=result_str)
|
||||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||||
return [
|
return [
|
||||||
Message(author=author,
|
Message(author=author, content=[content], recipient=Role.ASSISTANT)
|
||||||
content=[content],
|
|
||||||
recipient=Role.ASSISTANT,
|
|
||||||
channel=last_msg.channel)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||||
Tool],
|
Tool],
|
||||||
last_msg: Message) -> list[Message]:
|
last_msg: Message) -> list[Message]:
|
||||||
self.called_tools.add("python")
|
|
||||||
if isinstance(tool_session, Tool):
|
if isinstance(tool_session, Tool):
|
||||||
return await tool_session.get_result(self)
|
return await tool_session.get_result(self)
|
||||||
param = {
|
param = {
|
||||||
@ -311,63 +290,13 @@ class HarmonyContext(ConversationContext):
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
exit_stack: AsyncExitStack,
|
exit_stack: AsyncExitStack) -> None:
|
||||||
request_id: str) -> None:
|
|
||||||
if tool_server:
|
if tool_server:
|
||||||
for tool_name in self.available_tools:
|
for tool_name in self.available_tools:
|
||||||
if tool_name not in self._tool_sessions:
|
if tool_name not in self._tool_sessions:
|
||||||
tool_session = await exit_stack.enter_async_context(
|
self._tool_sessions[
|
||||||
tool_server.new_session(tool_name, request_id))
|
tool_name] = await exit_stack.enter_async_context(
|
||||||
self._tool_sessions[tool_name] = tool_session
|
tool_server.new_session(tool_name))
|
||||||
exit_stack.push_async_exit(self.cleanup_session)
|
|
||||||
|
|
||||||
async def call_container_tool(self, tool_session: Union["ClientSession",
|
|
||||||
Tool],
|
|
||||||
last_msg: Message) -> list[Message]:
|
|
||||||
"""
|
|
||||||
Call container tool. Expect this to be run in a stateful docker
|
|
||||||
with command line terminal.
|
|
||||||
The official container tool would at least
|
|
||||||
expect the following format:
|
|
||||||
- for tool name: exec
|
|
||||||
- args:
|
|
||||||
{
|
|
||||||
"cmd":List[str] "command to execute",
|
|
||||||
"workdir":optional[str] "current working directory",
|
|
||||||
"env":optional[object/dict] "environment variables",
|
|
||||||
"session_name":optional[str] "session name",
|
|
||||||
"timeout":optional[int] "timeout in seconds",
|
|
||||||
"user":optional[str] "user name",
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
self.called_tools.add("container")
|
|
||||||
if isinstance(tool_session, Tool):
|
|
||||||
return await tool_session.get_result(self)
|
|
||||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
|
||||||
args = json.loads(last_msg.content[0].text)
|
|
||||||
result = await tool_session.call_tool(tool_name, args)
|
|
||||||
result_str = result.content[0].text
|
|
||||||
content = TextContent(text=result_str)
|
|
||||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
|
||||||
return [
|
|
||||||
Message(author=author,
|
|
||||||
content=[content],
|
|
||||||
recipient=Role.ASSISTANT,
|
|
||||||
channel=last_msg.channel)
|
|
||||||
]
|
|
||||||
|
|
||||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
|
||||||
"""Can be used as coro to used in __aexit__"""
|
|
||||||
|
|
||||||
async def cleanup_tool_session(tool_session):
|
|
||||||
if not isinstance(tool_session, Tool):
|
|
||||||
logger.info("Cleaning up tool session for %s",
|
|
||||||
tool_session._client_info)
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
await tool_session.call_tool("cleanup_session", {})
|
|
||||||
|
|
||||||
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
|
|
||||||
for tool in self.called_tools))
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingHarmonyContext(HarmonyContext):
|
class StreamingHarmonyContext(HarmonyContext):
|
||||||
|
|||||||
@ -16,13 +16,11 @@ from openai.types.responses.response_function_web_search import (
|
|||||||
from openai.types.responses.response_reasoning_item import (
|
from openai.types.responses.response_reasoning_item import (
|
||||||
Content as ResponseReasoningTextContent)
|
Content as ResponseReasoningTextContent)
|
||||||
from openai.types.responses.tool import Tool
|
from openai.types.responses.tool import Tool
|
||||||
from openai_harmony import (Author, ChannelConfig, Conversation,
|
from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||||
DeveloperContent, HarmonyEncodingName, Message,
|
HarmonyEncodingName, Message, ReasoningEffort,
|
||||||
ReasoningEffort, Role, StreamableParser,
|
Role, StreamableParser, SystemContent, TextContent,
|
||||||
SystemContent, TextContent, ToolDescription,
|
ToolDescription, load_harmony_encoding)
|
||||||
load_harmony_encoding)
|
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||||
ResponseInputOutputItem)
|
ResponseInputOutputItem)
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
@ -35,20 +33,6 @@ REASONING_EFFORT = {
|
|||||||
|
|
||||||
_harmony_encoding = None
|
_harmony_encoding = None
|
||||||
|
|
||||||
# Builtin tools that should be included in the system message when
|
|
||||||
# they are available and requested by the user.
|
|
||||||
# Tool args are provided by MCP tool descriptions. Output
|
|
||||||
# of the tools are stringified.
|
|
||||||
BUILTIN_TOOLS = {
|
|
||||||
"web_search_preview",
|
|
||||||
"code_interpreter",
|
|
||||||
"container",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def has_custom_tools(tool_types: list[str]) -> bool:
|
|
||||||
return not set(tool_types).issubset(BUILTIN_TOOLS)
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoding():
|
def get_encoding():
|
||||||
global _harmony_encoding
|
global _harmony_encoding
|
||||||
@ -64,19 +48,10 @@ def get_system_message(
|
|||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
browser_description: Optional[str] = None,
|
browser_description: Optional[str] = None,
|
||||||
python_description: Optional[str] = None,
|
python_description: Optional[str] = None,
|
||||||
container_description: Optional[str] = None,
|
|
||||||
instructions: Optional[str] = None,
|
|
||||||
with_custom_tools: bool = False,
|
|
||||||
) -> Message:
|
) -> Message:
|
||||||
sys_msg_content = SystemContent.new()
|
sys_msg_content = SystemContent.new()
|
||||||
if model_identity is not None:
|
if model_identity is not None:
|
||||||
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
||||||
if (instructions is not None
|
|
||||||
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
|
||||||
current_identity = sys_msg_content.model_identity
|
|
||||||
new_identity = (f'{current_identity}\n{instructions}'
|
|
||||||
if current_identity else instructions)
|
|
||||||
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
|
|
||||||
if reasoning_effort is not None:
|
if reasoning_effort is not None:
|
||||||
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
||||||
REASONING_EFFORT[reasoning_effort])
|
REASONING_EFFORT[reasoning_effort])
|
||||||
@ -88,14 +63,6 @@ def get_system_message(
|
|||||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||||
if python_description is not None:
|
if python_description is not None:
|
||||||
sys_msg_content = sys_msg_content.with_tools(python_description)
|
sys_msg_content = sys_msg_content.with_tools(python_description)
|
||||||
if container_description is not None:
|
|
||||||
sys_msg_content = sys_msg_content.with_tools(container_description)
|
|
||||||
if not with_custom_tools:
|
|
||||||
channel_config = sys_msg_content.channel_config
|
|
||||||
invalid_channel = "commentary"
|
|
||||||
new_config = ChannelConfig.require_channels(
|
|
||||||
[c for c in channel_config.valid_channels if c != invalid_channel])
|
|
||||||
sys_msg_content = sys_msg_content.with_channel_config(new_config)
|
|
||||||
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
||||||
return sys_msg
|
return sys_msg
|
||||||
|
|
||||||
@ -119,17 +86,14 @@ def get_developer_message(
|
|||||||
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
||||||
) -> Message:
|
) -> Message:
|
||||||
dev_msg_content = DeveloperContent.new()
|
dev_msg_content = DeveloperContent.new()
|
||||||
if (instructions is not None
|
if instructions is not None:
|
||||||
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
|
||||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||||
if tools is not None:
|
if tools is not None:
|
||||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.type in ("web_search_preview", "code_interpreter",
|
if tool.type in ("web_search_preview", "code_interpreter"):
|
||||||
"container"):
|
|
||||||
# These are built-in tools that are added to the system message.
|
# These are built-in tools that are added to the system message.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif tool.type == "function":
|
elif tool.type == "function":
|
||||||
function_tools.append(tool)
|
function_tools.append(tool)
|
||||||
else:
|
else:
|
||||||
@ -172,8 +136,6 @@ def parse_response_input(
|
|||||||
TextContent(text=text_prefix + c["text"]) for c in content
|
TextContent(text=text_prefix + c["text"]) for c in content
|
||||||
]
|
]
|
||||||
msg = Message.from_role_and_contents(role, contents)
|
msg = Message.from_role_and_contents(role, contents)
|
||||||
if role == "assistant":
|
|
||||||
msg = msg.with_channel("final")
|
|
||||||
elif response_msg["type"] == "function_call_output":
|
elif response_msg["type"] == "function_call_output":
|
||||||
call_id = response_msg["call_id"]
|
call_id = response_msg["call_id"]
|
||||||
call_response: Optional[ResponseFunctionToolCall] = None
|
call_response: Optional[ResponseFunctionToolCall] = None
|
||||||
|
|||||||
@ -44,9 +44,8 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext,
|
|||||||
SimpleContext, StreamingHarmonyContext)
|
SimpleContext, StreamingHarmonyContext)
|
||||||
from vllm.entrypoints.harmony_utils import (
|
from vllm.entrypoints.harmony_utils import (
|
||||||
get_developer_message, get_stop_tokens_for_assistant_actions,
|
get_developer_message, get_stop_tokens_for_assistant_actions,
|
||||||
get_system_message, get_user_message, has_custom_tools,
|
get_system_message, get_user_message, parse_output_message,
|
||||||
parse_output_message, parse_remaining_state, parse_response_input,
|
parse_remaining_state, parse_response_input, render_for_completion)
|
||||||
render_for_completion)
|
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -267,8 +266,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
builtin_tool_list.append("browser")
|
builtin_tool_list.append("browser")
|
||||||
if self.tool_server.has_tool("python"):
|
if self.tool_server.has_tool("python"):
|
||||||
builtin_tool_list.append("python")
|
builtin_tool_list.append("python")
|
||||||
if self.tool_server.has_tool("container"):
|
|
||||||
builtin_tool_list.append("container")
|
|
||||||
|
|
||||||
if self.tool_server is not None:
|
if self.tool_server is not None:
|
||||||
available_tools = builtin_tool_list
|
available_tools = builtin_tool_list
|
||||||
@ -451,8 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
async with AsyncExitStack() as exit_stack:
|
async with AsyncExitStack() as exit_stack:
|
||||||
try:
|
try:
|
||||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||||
request.request_id)
|
|
||||||
async for _ in result_generator:
|
async for _ in result_generator:
|
||||||
pass
|
pass
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -714,21 +710,13 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
# New conversation.
|
# New conversation.
|
||||||
reasoning_effort = (request.reasoning.effort
|
reasoning_effort = (request.reasoning.effort
|
||||||
if request.reasoning else None)
|
if request.reasoning else None)
|
||||||
# Temporary: OpenAI types doesn't have container tool
|
|
||||||
# so we used MCP to cover that, up for change
|
|
||||||
tool_types = [tool.type for tool in request.tools]
|
tool_types = [tool.type for tool in request.tools]
|
||||||
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
|
|
||||||
tool_types.append("container")
|
|
||||||
enable_browser = ("web_search_preview" in tool_types
|
enable_browser = ("web_search_preview" in tool_types
|
||||||
and self.tool_server is not None
|
and self.tool_server is not None
|
||||||
and self.tool_server.has_tool("browser"))
|
and self.tool_server.has_tool("browser"))
|
||||||
enable_code_interpreter = ("code_interpreter" in tool_types
|
enable_code_interpreter = ("code_interpreter" in tool_types
|
||||||
and self.tool_server is not None
|
and self.tool_server is not None
|
||||||
and self.tool_server.has_tool("python"))
|
and self.tool_server.has_tool("python"))
|
||||||
enable_container = ("container" in tool_types
|
|
||||||
and self.tool_server is not None
|
|
||||||
and self.tool_server.has_tool("container"))
|
|
||||||
with_custom_tools = has_custom_tools(tool_types)
|
|
||||||
sys_msg = get_system_message(
|
sys_msg = get_system_message(
|
||||||
reasoning_effort=reasoning_effort,
|
reasoning_effort=reasoning_effort,
|
||||||
browser_description=self.tool_server.get_tool_description(
|
browser_description=self.tool_server.get_tool_description(
|
||||||
@ -737,17 +725,11 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
python_description=self.tool_server.get_tool_description(
|
python_description=self.tool_server.get_tool_description(
|
||||||
"python") if enable_code_interpreter
|
"python") if enable_code_interpreter
|
||||||
and self.tool_server is not None else None,
|
and self.tool_server is not None else None,
|
||||||
container_description=self.tool_server.get_tool_description(
|
|
||||||
"container")
|
|
||||||
if enable_container and self.tool_server is not None else None,
|
|
||||||
instructions=request.instructions,
|
|
||||||
with_custom_tools=with_custom_tools,
|
|
||||||
)
|
)
|
||||||
messages.append(sys_msg)
|
messages.append(sys_msg)
|
||||||
if with_custom_tools:
|
dev_msg = get_developer_message(request.instructions,
|
||||||
dev_msg = get_developer_message(
|
request.tools)
|
||||||
instructions=request.instructions, tools=request.tools)
|
messages.append(dev_msg)
|
||||||
messages.append(dev_msg)
|
|
||||||
else:
|
else:
|
||||||
# Continue the previous conversation.
|
# Continue the previous conversation.
|
||||||
# FIXME(woosuk): Currently, request params like reasoning and
|
# FIXME(woosuk): Currently, request params like reasoning and
|
||||||
@ -1631,8 +1613,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
async with AsyncExitStack() as exit_stack:
|
async with AsyncExitStack() as exit_stack:
|
||||||
processer = None
|
processer = None
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||||
request.request_id)
|
|
||||||
processer = self._process_harmony_streaming_events
|
processer = self._process_harmony_streaming_events
|
||||||
else:
|
else:
|
||||||
processer = self._process_simple_streaming_events
|
processer = self._process_simple_streaming_events
|
||||||
|
|||||||
@ -86,8 +86,7 @@ class ToolServer(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def new_session(self, tool_name: str,
|
def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
|
||||||
session_id: str) -> AbstractAsyncContextManager[Any]:
|
|
||||||
"""
|
"""
|
||||||
Create a session for the tool.
|
Create a session for the tool.
|
||||||
"""
|
"""
|
||||||
@ -125,8 +124,7 @@ class MCPToolServer(ToolServer):
|
|||||||
description=tool.description,
|
description=tool.description,
|
||||||
parameters=tool.inputSchema)
|
parameters=tool.inputSchema)
|
||||||
for tool in list_tools_response.tools
|
for tool in list_tools_response.tools
|
||||||
],
|
])
|
||||||
)
|
|
||||||
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
|
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
|
||||||
if tool_from_mcp.name not in self.urls:
|
if tool_from_mcp.name not in self.urls:
|
||||||
self.urls[tool_from_mcp.name] = url
|
self.urls[tool_from_mcp.name] = url
|
||||||
@ -144,16 +142,14 @@ class MCPToolServer(ToolServer):
|
|||||||
return self.harmony_tool_descriptions.get(tool_name)
|
return self.harmony_tool_descriptions.get(tool_name)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def new_session(self, tool_name: str, session_id: str):
|
async def new_session(self, tool_name: str):
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
url = self.urls.get(tool_name)
|
url = self.urls.get(tool_name)
|
||||||
headers = {"x-session-id": session_id}
|
|
||||||
if not url:
|
if not url:
|
||||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||||
async with sse_client(url=url,
|
async with sse_client(url=url) as streams, ClientSession(
|
||||||
headers=headers) as streams, ClientSession(
|
*streams) as session:
|
||||||
*streams) as session:
|
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
@ -186,7 +182,7 @@ class DemoToolServer(ToolServer):
|
|||||||
raise ValueError(f"Unknown tool {tool_name}")
|
raise ValueError(f"Unknown tool {tool_name}")
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def new_session(self, tool_name: str, session_id: str):
|
async def new_session(self, tool_name: str):
|
||||||
if tool_name not in self.tools:
|
if tool_name not in self.tools:
|
||||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||||
yield self.tools[tool_name]
|
yield self.tools[tool_name]
|
||||||
|
|||||||
11
vllm/envs.py
11
vllm/envs.py
@ -168,8 +168,6 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||||
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
|
||||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
|
||||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -1203,15 +1201,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_TUNED_CONFIG_FOLDER":
|
"VLLM_TUNED_CONFIG_FOLDER":
|
||||||
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
||||||
|
|
||||||
# Allows vllm use container tool
|
|
||||||
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
|
|
||||||
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
|
|
||||||
|
|
||||||
# Allows harmony instructions to be injected on system messages
|
|
||||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
|
|
||||||
lambda: bool(
|
|
||||||
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))),
|
|
||||||
|
|
||||||
# Add optional custom scopes for profiling, disable to avoid overheads
|
# Add optional custom scopes for profiling, disable to avoid overheads
|
||||||
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
||||||
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
||||||
|
|||||||
@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
|
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
|
||||||
round_up)
|
round_up)
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
@ -786,7 +786,6 @@ class FusedMoE(CustomOp):
|
|||||||
enable_eplb: bool = False,
|
enable_eplb: bool = False,
|
||||||
num_redundant_experts: int = 0,
|
num_redundant_experts: int = 0,
|
||||||
has_bias: bool = False,
|
has_bias: bool = False,
|
||||||
is_sequence_parallel=False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
@ -798,10 +797,6 @@ class FusedMoE(CustomOp):
|
|||||||
dp_size_ = (dp_size
|
dp_size_ = (dp_size
|
||||||
if dp_size is not None else get_dp_group().world_size)
|
if dp_size is not None else get_dp_group().world_size)
|
||||||
|
|
||||||
self.is_sequence_parallel = is_sequence_parallel
|
|
||||||
if self.is_sequence_parallel:
|
|
||||||
self.sp_size = tp_size_
|
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||||
FusedMoEParallelConfig.make(
|
FusedMoEParallelConfig.make(
|
||||||
@ -1704,22 +1699,14 @@ class FusedMoE(CustomOp):
|
|||||||
|
|
||||||
ctx = get_forward_context()
|
ctx = get_forward_context()
|
||||||
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
|
||||||
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
|
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||||
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
|
||||||
|
|
||||||
# If the input to the MoE is sequence parallel then divide by sp_size
|
|
||||||
# to find the maximum number of tokens for any individual dispatcher.
|
|
||||||
if self.is_sequence_parallel:
|
|
||||||
max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
|
|
||||||
self.sp_size)
|
|
||||||
|
|
||||||
num_tokens = full_hidden_states.size(0)
|
num_tokens = full_hidden_states.size(0)
|
||||||
for chunk_idx, chunk_start_ in enumerate(
|
for chunk_idx, chunk_start_ in enumerate(
|
||||||
range(0, max_tokens_across_dispatchers,
|
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
|
||||||
moe_dp_chunk_size_per_rank)):
|
|
||||||
chunk_start = chunk_start_
|
chunk_start = chunk_start_
|
||||||
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
|
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
|
||||||
max_tokens_across_dispatchers)
|
max_tokens_across_dp)
|
||||||
# clamp start and end
|
# clamp start and end
|
||||||
chunk_start = min(chunk_start, num_tokens - 1)
|
chunk_start = min(chunk_start, num_tokens - 1)
|
||||||
chunk_end = min(chunk_end, num_tokens)
|
chunk_end = min(chunk_end, num_tokens)
|
||||||
|
|||||||
@ -37,6 +37,8 @@ class DeepseekV2Model(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = vllm_config. \
|
self.config = vllm_config. \
|
||||||
speculative_config.draft_model_config.hf_config
|
speculative_config.draft_model_config.hf_config
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
self.vocab_size = self.config.vocab_size
|
self.vocab_size = self.config.vocab_size
|
||||||
|
|
||||||
@ -49,8 +51,11 @@ class DeepseekV2Model(nn.Module):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
DeepseekV2DecoderLayer(
|
DeepseekV2DecoderLayer(
|
||||||
vllm_config,
|
self.config,
|
||||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||||
|
model_config=model_config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
) for i in range(self.config.num_hidden_layers)
|
) for i in range(self.config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
@ -43,19 +43,23 @@ class SharedHead(nn.Module):
|
|||||||
|
|
||||||
class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
prefix: str,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
|
|
||||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
bias=False)
|
bias=False)
|
||||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
|
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
|
||||||
|
cache_config, quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -91,8 +95,13 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
|||||||
# to map the exact layer index from weights
|
# to map the exact layer index from weights
|
||||||
self.layers = torch.nn.ModuleDict({
|
self.layers = torch.nn.ModuleDict({
|
||||||
str(idx):
|
str(idx):
|
||||||
DeepSeekMultiTokenPredictorLayer(vllm_config,
|
DeepSeekMultiTokenPredictorLayer(
|
||||||
f"{prefix}.layers.{idx}")
|
config,
|
||||||
|
f"{prefix}.layers.{idx}",
|
||||||
|
model_config=vllm_config.model_config,
|
||||||
|
cache_config=vllm_config.cache_config,
|
||||||
|
quant_config=vllm_config.quant_config,
|
||||||
|
)
|
||||||
for idx in range(self.mtp_start_layer_idx,
|
for idx in range(self.mtp_start_layer_idx,
|
||||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||||
})
|
})
|
||||||
|
|||||||
@ -32,14 +32,12 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||||
|
get_current_vllm_config)
|
||||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_world_size)
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
tensor_model_parallel_all_gather)
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -57,9 +55,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import cdiv, direct_register_custom_op
|
|
||||||
|
|
||||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||||
@ -76,27 +72,19 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
is_sequence_parallel=False,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# If is_sequence_parallel, the input and output tensors are sharded
|
|
||||||
# across the ranks within the tp_group. In this case the weights are
|
|
||||||
# replicated and no collective ops are needed.
|
|
||||||
# Otherwise we use standard TP with an allreduce at the end.
|
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size, [intermediate_size] * 2,
|
hidden_size, [intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
disable_tp=is_sequence_parallel,
|
|
||||||
prefix=f"{prefix}.gate_up_proj")
|
prefix=f"{prefix}.gate_up_proj")
|
||||||
self.down_proj = RowParallelLinear(intermediate_size,
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=reduce_results,
|
reduce_results=reduce_results,
|
||||||
disable_tp=is_sequence_parallel,
|
|
||||||
prefix=f"{prefix}.down_proj")
|
prefix=f"{prefix}.down_proj")
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
@ -110,58 +98,17 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Chunk x along the num_tokens axis for sequence parallelism
|
|
||||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
|
||||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
|
||||||
# even though we explicitly pad to avoid this.
|
|
||||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
# all_gather needs the sequence length to be divisible by tp_size
|
|
||||||
seq_len = x.size(0)
|
|
||||||
remainder = seq_len % tp_size
|
|
||||||
if remainder != 0:
|
|
||||||
pad_len = tp_size - remainder
|
|
||||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
|
||||||
|
|
||||||
chunk = x.shape[0] // tp_size
|
|
||||||
start = tp_rank * chunk
|
|
||||||
return torch.narrow(x, 0, start, chunk)
|
|
||||||
|
|
||||||
|
|
||||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
seq_len = cdiv(x.size(0), tp_size)
|
|
||||||
shape = list(x.shape)
|
|
||||||
shape[0] = seq_len
|
|
||||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="sequence_parallel_chunk",
|
|
||||||
op_func=sequence_parallel_chunk,
|
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=sequence_parallel_chunk_fake,
|
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MoE(nn.Module):
|
class DeepseekV2MoE(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||||
parallel_config: ParallelConfig,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
enable_eplb: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
|
|
||||||
self.ep_group = get_ep_group().device_group
|
self.ep_group = get_ep_group().device_group
|
||||||
@ -170,21 +117,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
self.n_routed_experts: int = config.n_routed_experts
|
self.n_routed_experts: int = config.n_routed_experts
|
||||||
self.n_shared_experts: int = config.n_shared_experts
|
self.n_shared_experts: int = config.n_shared_experts
|
||||||
|
|
||||||
# The all_reduce at the end of attention (during o_proj) means that
|
|
||||||
# inputs are replicated across each rank of the tensor parallel group.
|
|
||||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
|
||||||
# tokens results in useless duplicate computation and communication.
|
|
||||||
#
|
|
||||||
# In this case, ensure the input to the experts is sequence parallel
|
|
||||||
# to avoid the excess work.
|
|
||||||
#
|
|
||||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
|
||||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
|
||||||
in ("deepep_high_throughput",
|
|
||||||
"deepep_low_latency")
|
|
||||||
and parallel_config.enable_expert_parallel
|
|
||||||
and self.tp_size > 1)
|
|
||||||
|
|
||||||
if config.hidden_act != "silu":
|
if config.hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||||
"Only silu is supported for now.")
|
"Only silu is supported for now.")
|
||||||
@ -201,8 +133,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
self.gate.e_score_correction_bias = None
|
self.gate.e_score_correction_bias = None
|
||||||
|
|
||||||
# Load balancing settings.
|
# Load balancing settings.
|
||||||
eplb_config = parallel_config.eplb_config
|
vllm_config = get_current_vllm_config()
|
||||||
self.enable_eplb = parallel_config.enable_eplb
|
eplb_config = vllm_config.parallel_config.eplb_config
|
||||||
|
self.enable_eplb = enable_eplb
|
||||||
|
|
||||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||||
self.n_logical_experts = self.n_routed_experts
|
self.n_logical_experts = self.n_routed_experts
|
||||||
@ -233,9 +166,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
routed_scaling_factor=1.0,
|
routed_scaling_factor=1.0,
|
||||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts)
|
||||||
is_sequence_parallel=self.is_sequence_parallel,
|
|
||||||
)
|
|
||||||
self.shared_experts = None
|
self.shared_experts = None
|
||||||
else:
|
else:
|
||||||
intermediate_size = (config.moe_intermediate_size *
|
intermediate_size = (config.moe_intermediate_size *
|
||||||
@ -246,7 +177,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
is_sequence_parallel=self.is_sequence_parallel,
|
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
prefix=f"{prefix}.shared_experts",
|
prefix=f"{prefix}.shared_experts",
|
||||||
)
|
)
|
||||||
@ -269,22 +199,11 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
routed_scaling_factor=1.0,
|
routed_scaling_factor=1.0,
|
||||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts)
|
||||||
is_sequence_parallel=self.is_sequence_parallel,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
# Chunk the hidden states so they aren't replicated across TP ranks.
|
|
||||||
# This avoids duplicate computation in self.experts.
|
|
||||||
# TODO: We can replace the all_reduce at the end of attn with a
|
|
||||||
# reduce_scatter instead of chunking here.
|
|
||||||
if self.is_sequence_parallel:
|
|
||||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
|
||||||
hidden_states)
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
@ -309,11 +228,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
assert shared_output is not None
|
assert shared_output is not None
|
||||||
final_hidden_states += shared_output
|
final_hidden_states += shared_output
|
||||||
|
|
||||||
if self.is_sequence_parallel:
|
if self.tp_size > 1:
|
||||||
final_hidden_states = tensor_model_parallel_all_gather(
|
|
||||||
final_hidden_states, 0)
|
|
||||||
final_hidden_states = final_hidden_states[:num_tokens]
|
|
||||||
elif self.tp_size > 1:
|
|
||||||
final_hidden_states = (
|
final_hidden_states = (
|
||||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||||
final_hidden_states))
|
final_hidden_states))
|
||||||
@ -617,15 +532,16 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||||
|
prefix: str,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
model_config = vllm_config.model_config
|
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
parallel_config = vllm_config.parallel_config
|
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
@ -662,9 +578,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
and layer_idx % config.moe_layer_freq == 0):
|
and layer_idx % config.moe_layer_freq == 0):
|
||||||
self.mlp = DeepseekV2MoE(
|
self.mlp = DeepseekV2MoE(
|
||||||
config=config,
|
config=config,
|
||||||
parallel_config=parallel_config,
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.mlp",
|
||||||
|
enable_eplb=enable_eplb,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.mlp = DeepseekV2MLP(
|
self.mlp = DeepseekV2MLP(
|
||||||
@ -734,7 +650,10 @@ class DeepseekV2Model(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
enable_eplb = vllm_config.parallel_config.enable_eplb
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
@ -750,7 +669,14 @@ class DeepseekV2Model(nn.Module):
|
|||||||
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
|
lambda prefix: DeepseekV2DecoderLayer(
|
||||||
|
config,
|
||||||
|
prefix,
|
||||||
|
model_config=model_config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
enable_eplb=enable_eplb,
|
||||||
|
),
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
|
|||||||
@ -80,6 +80,10 @@ class EngineCore:
|
|||||||
|
|
||||||
# Setup Model.
|
# Setup Model.
|
||||||
self.model_executor = executor_class(vllm_config)
|
self.model_executor = executor_class(vllm_config)
|
||||||
|
if vllm_config.intermediate_log_config is not None:
|
||||||
|
self.collective_rpc("register_intermediate_hooks",
|
||||||
|
args=(vllm_config.intermediate_log_config, ))
|
||||||
|
|
||||||
if executor_fail_callback is not None:
|
if executor_fail_callback is not None:
|
||||||
self.model_executor.register_failure_callback(
|
self.model_executor.register_failure_callback(
|
||||||
executor_fail_callback)
|
executor_fail_callback)
|
||||||
|
|||||||
@ -641,13 +641,7 @@ class WorkerProc:
|
|||||||
|
|
||||||
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
|
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
|
||||||
"""Main busy loop for Multiprocessing Workers"""
|
"""Main busy loop for Multiprocessing Workers"""
|
||||||
import os, psutil
|
|
||||||
p = psutil.Process(os.getpid())
|
|
||||||
i = 0
|
|
||||||
while True:
|
while True:
|
||||||
if i % 100 == 0:
|
|
||||||
logger.info("WorkerProc RSS MB: %d", p.memory_info().rss/1024/1024)
|
|
||||||
i += 1
|
|
||||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||||
cancel=cancel)
|
cancel=cancel)
|
||||||
try:
|
try:
|
||||||
|
|||||||
0
vllm/v1/intermediates/__init__.py
Normal file
0
vllm/v1/intermediates/__init__.py
Normal file
405
vllm/v1/intermediates/intermediates_logging.py
Normal file
405
vllm/v1/intermediates/intermediates_logging.py
Normal file
@ -0,0 +1,405 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Module for logging intermediate tensors during model execution.
|
||||||
|
|
||||||
|
This module provides functionality to capture and save intermediate tensors
|
||||||
|
(inputs and outputs) from PyTorch modules during forward passes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
|
from vllm.config import IntermediateLoggingConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Global step counter
|
||||||
|
_CURRENT_STEP = 0
|
||||||
|
|
||||||
|
_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}
|
||||||
|
|
||||||
|
IL_MODULE_NAME = "_il_module_name"
|
||||||
|
IL_MODULE_CALL_IDX = "_il_module_call_idx"
|
||||||
|
|
||||||
|
# Utility functions for intermediate logging
|
||||||
|
|
||||||
|
|
||||||
|
def should_log_step(config):
|
||||||
|
"""Check if the current step should be logged based on the step IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The IntermediateLoggingConfig instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the current step should be logged, False otherwise.
|
||||||
|
"""
|
||||||
|
if not is_log_enabled(config):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If log_step_ids is empty, log all steps
|
||||||
|
if not config.log_step_ids:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Otherwise, check if current step is in the set of step IDs to log
|
||||||
|
return get_step() in config._step_id_set
|
||||||
|
|
||||||
|
|
||||||
|
def should_log_device(config, device_name):
|
||||||
|
"""Check if a device should be logged based on the device names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The IntermediateLoggingConfig instance.
|
||||||
|
device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the device should be logged, False otherwise.
|
||||||
|
If device_names is empty, all devices are logged.
|
||||||
|
"""
|
||||||
|
if not is_log_enabled(config):
|
||||||
|
return False
|
||||||
|
# If device_names is empty, log all devices
|
||||||
|
if not config.device_names:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Otherwise, check if device_name is in the list of device names to log
|
||||||
|
return device_name in config.device_names
|
||||||
|
|
||||||
|
|
||||||
|
def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
|
||||||
|
"""Check if a module should be logged based on the name regex patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The IntermediateLoggingConfig instance.
|
||||||
|
module_name: The name of the module to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the module should be logged, False otherwise.
|
||||||
|
If no patterns are defined, all modules are logged.
|
||||||
|
If patterns are defined, the module is logged if it matches ANY pattern.
|
||||||
|
"""
|
||||||
|
if not is_log_enabled(config):
|
||||||
|
return False
|
||||||
|
# If no patterns are defined, log all modules
|
||||||
|
if not config._compiled_module_calls:
|
||||||
|
set_il_module_name(module, module_name)
|
||||||
|
set_il_module_call_idx(module, -1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if the module name matches any of the patterns
|
||||||
|
for pattern, call_idx in config._compiled_module_calls.items():
|
||||||
|
match = pattern.search(module_name)
|
||||||
|
if match:
|
||||||
|
logger.debug(
|
||||||
|
"Module %s, %s matches pattern: '%s', call_idx=%s",
|
||||||
|
module_name,
|
||||||
|
module.__class__.__name__,
|
||||||
|
pattern.pattern,
|
||||||
|
call_idx,
|
||||||
|
)
|
||||||
|
set_il_module_name(module, module_name)
|
||||||
|
set_il_module_call_idx(module, call_idx)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_log_enabled(config):
|
||||||
|
if not config or not config.enabled:
|
||||||
|
return False
|
||||||
|
if torch.compiler.is_compiling():
|
||||||
|
logger.debug("Not logging because torch.compile is in progress")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_il_module_name(module: torch.nn.Module) -> str:
|
||||||
|
return getattr(module, IL_MODULE_NAME, module.__class__.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_il_module_call_idx(module: torch.nn.Module) -> int:
|
||||||
|
return getattr(module, IL_MODULE_CALL_IDX, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def set_il_module_name(module: torch.nn.Module, name: str) -> None:
|
||||||
|
setattr(module, IL_MODULE_NAME, name)
|
||||||
|
|
||||||
|
|
||||||
|
def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
|
||||||
|
setattr(module, IL_MODULE_CALL_IDX, idx)
|
||||||
|
|
||||||
|
|
||||||
|
_global_config: Optional[IntermediateLoggingConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
|
||||||
|
"""
|
||||||
|
Temporarily sets the global config for the duration of the context.
|
||||||
|
:param config: Keyword arguments to set as global config
|
||||||
|
"""
|
||||||
|
global _global_config
|
||||||
|
old_config = _global_config
|
||||||
|
try:
|
||||||
|
_global_config = config
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_global_config = old_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_il_config():
|
||||||
|
return _global_config
|
||||||
|
|
||||||
|
|
||||||
|
def save_tensors(tensor: Any, file_path: str) -> Any:
|
||||||
|
"""Utility function to dump tensor to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
|
||||||
|
tensors, or a dictionary containing tensors.
|
||||||
|
file_path: Base path where to save the tensor (without extension).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(tensor, torch.Tensor):
|
||||||
|
device_name = str(tensor.device)
|
||||||
|
intermediate_log_config = get_current_il_config()
|
||||||
|
if not should_log_device(intermediate_log_config, device_name):
|
||||||
|
return tensor
|
||||||
|
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
|
||||||
|
try:
|
||||||
|
torch.save(tensor, pt_path)
|
||||||
|
logger.debug("Saved tensor of shape %s to %s", tensor.shape,
|
||||||
|
pt_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to save tensor to %s: %s", pt_path, e)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
if isinstance(tensor, (list, tuple)):
|
||||||
|
for i, item in enumerate(tensor):
|
||||||
|
save_tensors(item, f"{file_path}_{i}")
|
||||||
|
return tensor
|
||||||
|
if isinstance(tensor, dict):
|
||||||
|
for k, v in tensor.items():
|
||||||
|
save_tensors(v, f"{file_path}_{k}")
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
|
||||||
|
outputs: Any) -> None:
|
||||||
|
"""Hook to increment the global step counter after a forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The PyTorch module being executed.
|
||||||
|
inputs: The inputs to the module's forward function.
|
||||||
|
outputs: The outputs from the module's forward function.
|
||||||
|
"""
|
||||||
|
if get_current_il_config() is None:
|
||||||
|
return
|
||||||
|
# Increment the global step counter
|
||||||
|
increment_step()
|
||||||
|
global _CURRENT_STEP_MODULE_CALL_STEP
|
||||||
|
_CURRENT_STEP_MODULE_CALL_STEP = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_module_log_dir(
|
||||||
|
intermediate_log_config: IntermediateLoggingConfig,
|
||||||
|
module_name: str,
|
||||||
|
is_pre_fwd: bool = False,
|
||||||
|
) -> Path:
|
||||||
|
# Create a unique directory for this step if not
|
||||||
|
dump_dir = Path(
|
||||||
|
intermediate_log_config.output_run_dir) / f"step_{get_step()}"
|
||||||
|
dump_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
# Create module directory
|
||||||
|
suffix = ""
|
||||||
|
module_call_idx = get_current_step_module_call(module_name)
|
||||||
|
if module_call_idx > 0:
|
||||||
|
suffix = f"_{module_call_idx}"
|
||||||
|
module_dir = dump_dir / (module_name + suffix)
|
||||||
|
if is_pre_fwd:
|
||||||
|
_log_module_call(intermediate_log_config, module_name + suffix)
|
||||||
|
module_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
logger.debug("Logging module %s inputs/outputs to %s", module_name,
|
||||||
|
module_dir)
|
||||||
|
return module_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _log_module_call(
|
||||||
|
intermediate_log_config: IntermediateLoggingConfig,
|
||||||
|
module_name: str,
|
||||||
|
) -> None:
|
||||||
|
file = (Path(intermediate_log_config.output_run_dir) /
|
||||||
|
f"step_{get_step()}" / "module_calls.txt")
|
||||||
|
with open(file, "a") as f:
|
||||||
|
f.write(f"{module_name}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def update_current_step_module_call(module_name: str) -> None:
|
||||||
|
logger.debug("Updating current step module call for %s", module_name)
|
||||||
|
global _CURRENT_STEP_MODULE_CALL_STEP
|
||||||
|
if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
|
||||||
|
_CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
|
||||||
|
else:
|
||||||
|
_CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_step_module_call(module_name: str) -> int:
|
||||||
|
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_log_current_fwd(module,
|
||||||
|
is_pre_fwd: bool = False) -> Optional[Path]:
|
||||||
|
intermediate_log_config = get_current_il_config()
|
||||||
|
if intermediate_log_config is None or not intermediate_log_config.enabled:
|
||||||
|
return None
|
||||||
|
if not should_log_step(intermediate_log_config):
|
||||||
|
return None
|
||||||
|
|
||||||
|
module_name = get_il_module_name(module)
|
||||||
|
log_call_idx = get_il_module_call_idx(module)
|
||||||
|
current_call_idx = get_current_step_module_call(module_name)
|
||||||
|
should_log = True
|
||||||
|
if log_call_idx >= 0 and current_call_idx != log_call_idx:
|
||||||
|
should_log = False
|
||||||
|
|
||||||
|
log_dir = None
|
||||||
|
if is_pre_fwd:
|
||||||
|
update_current_step_module_call(module_name)
|
||||||
|
if should_log:
|
||||||
|
log_dir = _prepare_module_log_dir(intermediate_log_config,
|
||||||
|
module_name,
|
||||||
|
is_pre_fwd=is_pre_fwd)
|
||||||
|
return log_dir
|
||||||
|
|
||||||
|
|
||||||
|
def log_pre_fwd_hook(module: torch.nn.Module,
|
||||||
|
inputs: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||||
|
"""Hook to capture module inputs before forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The PyTorch module being executed.
|
||||||
|
inputs: The inputs to the module's forward function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The unchanged inputs.
|
||||||
|
"""
|
||||||
|
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
|
||||||
|
save_tensors(inputs, str(log_dir / "inputs"))
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
|
||||||
|
outputs: Any) -> None:
|
||||||
|
"""Hook to capture module outputs after forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The PyTorch module being executed.
|
||||||
|
inputs: The inputs to the module's forward function.
|
||||||
|
outputs: The outputs from the module's forward function.
|
||||||
|
"""
|
||||||
|
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
|
||||||
|
save_tensors(outputs, str(log_dir / "outputs"))
|
||||||
|
intermediate_log_config = get_current_il_config()
|
||||||
|
assert intermediate_log_config is not None, \
|
||||||
|
"IL config should not be None"
|
||||||
|
if intermediate_log_config.log_post_fwd_inputs:
|
||||||
|
save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_step() -> int:
|
||||||
|
"""Get the current global step counter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current global step counter.
|
||||||
|
"""
|
||||||
|
return _CURRENT_STEP
|
||||||
|
|
||||||
|
|
||||||
|
def increment_step() -> int:
|
||||||
|
"""Increment the global step counter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new step counter value.
|
||||||
|
"""
|
||||||
|
global _CURRENT_STEP
|
||||||
|
_CURRENT_STEP += 1
|
||||||
|
return _CURRENT_STEP
|
||||||
|
|
||||||
|
|
||||||
|
def reset_step() -> None:
|
||||||
|
"""Reset the global step counter to zero."""
|
||||||
|
global _CURRENT_STEP
|
||||||
|
_CURRENT_STEP = 0
|
||||||
|
|
||||||
|
|
||||||
|
class IntermediatesLogger:
|
||||||
|
"""Class to manage logging of intermediate tensors during model
|
||||||
|
execution."""
|
||||||
|
|
||||||
|
def __init__(self, config: IntermediateLoggingConfig):
|
||||||
|
self.config = config
|
||||||
|
self.hooks: list[tuple[str, str, Optional[RemovableHandle],
|
||||||
|
Optional[RemovableHandle]]] = []
|
||||||
|
logger.debug("Created IntermediatesLogger with config: %s", config)
|
||||||
|
path = Path(config.output_run_dir)
|
||||||
|
path.mkdir(exist_ok=True, parents=True)
|
||||||
|
# Log configuration
|
||||||
|
logger.info("Intermediates will be logged in %s",
|
||||||
|
config.output_run_dir)
|
||||||
|
|
||||||
|
def register_hooks(self, model: torch.nn.Module) -> None:
|
||||||
|
"""Register hooks for the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model to register hooks for.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if name and should_log_module(self.config, name, module):
|
||||||
|
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
|
||||||
|
logger.debug("Registered pre_fwd hook for %s",
|
||||||
|
module.__class__.__name__)
|
||||||
|
post_hook = module.register_forward_hook(log_post_fwd_hook)
|
||||||
|
logger.debug("Registered post_fwd hook for %s",
|
||||||
|
module.__class__.__name__)
|
||||||
|
self.hooks.append((name, module, pre_hook, post_hook))
|
||||||
|
|
||||||
|
# Register a step counter hook for the root model
|
||||||
|
step_hook = model.register_forward_hook(step_fwd)
|
||||||
|
self.hooks.append(("", model, None, step_hook))
|
||||||
|
logger.info("Registered hooks for %s modules", len(self.hooks))
|
||||||
|
|
||||||
|
def remove_hooks(self) -> None:
|
||||||
|
"""Remove all registered hooks."""
|
||||||
|
for _, _, pre_hook, post_hook in self.hooks:
|
||||||
|
if pre_hook is not None:
|
||||||
|
pre_hook.remove()
|
||||||
|
if post_hook is not None:
|
||||||
|
post_hook.remove()
|
||||||
|
|
||||||
|
logger.info("Removed %s hooks", len(self.hooks))
|
||||||
|
self.hooks = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_intermediate_hooks(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
config: Optional[IntermediateLoggingConfig] = None
|
||||||
|
) -> IntermediatesLogger:
|
||||||
|
"""Register hooks to log intermediate tensors for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model to log intermediates for.
|
||||||
|
config: Configuration for intermediate logging. If provided, this takes
|
||||||
|
precedence over kwargs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An IntermediatesLogger instance that can be used to manage the hooks.
|
||||||
|
"""
|
||||||
|
logger_instance = IntermediatesLogger(config)
|
||||||
|
logger_instance.register_hooks(model)
|
||||||
|
return logger_instance
|
||||||
@ -27,6 +27,7 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
||||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||||
|
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
DraftTokenIds, ModelRunnerOutput)
|
DraftTokenIds, ModelRunnerOutput)
|
||||||
@ -362,9 +363,9 @@ class Worker(WorkerBase):
|
|||||||
intermediate_tensors = IntermediateTensors(
|
intermediate_tensors = IntermediateTensors(
|
||||||
get_pp_group().recv_tensor_dict(
|
get_pp_group().recv_tensor_dict(
|
||||||
all_gather_group=get_tp_group()))
|
all_gather_group=get_tp_group()))
|
||||||
|
with intermediate_logging(self.vllm_config.intermediate_log_config):
|
||||||
output = self.model_runner.execute_model(scheduler_output,
|
output = self.model_runner.execute_model(scheduler_output,
|
||||||
intermediate_tensors)
|
intermediate_tensors)
|
||||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@ -6,8 +6,10 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import IntermediateLoggingConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.intermediates.intermediates_logging import (
|
||||||
|
register_intermediate_hooks)
|
||||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||||
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
|
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
|
||||||
|
|
||||||
@ -63,3 +65,26 @@ class WorkerBase(WorkerBaseV0):
|
|||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
"""Basic health check (override for device-specific checks)."""
|
"""Basic health check (override for device-specific checks)."""
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def register_intermediate_hooks(
|
||||||
|
self, config: Optional[IntermediateLoggingConfig] = None) -> None:
|
||||||
|
"""Register hooks for intermediate tensor logging.
|
||||||
|
|
||||||
|
This method is called via collective_rpc from the engine core.
|
||||||
|
It registers hooks on the model to dump intermediate tensors during
|
||||||
|
execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration for intermediate logging. If provided, this
|
||||||
|
takes precedence over kwargs.
|
||||||
|
"""
|
||||||
|
if self.model_runner is None or not hasattr(
|
||||||
|
self.model_runner, "model") or self.model_runner.model is None:
|
||||||
|
logger.error("Could not register intermediate hooks: "
|
||||||
|
"model_runner.model is not accessible")
|
||||||
|
return
|
||||||
|
model = self.model_runner.model
|
||||||
|
try:
|
||||||
|
register_intermediate_hooks(model, config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error registering intermediate hooks")
|
||||||
|
|||||||
@ -129,6 +129,22 @@ class WorkerBase:
|
|||||||
"""Get vocabulary size from model configuration."""
|
"""Get vocabulary size from model configuration."""
|
||||||
return self.model_config.get_vocab_size()
|
return self.model_config.get_vocab_size()
|
||||||
|
|
||||||
|
def register_intermediate_hooks(self, config=None) -> None:
|
||||||
|
"""Register hooks for intermediate tensor logging.
|
||||||
|
|
||||||
|
This method is a stub for v0 workers. The actual implementation is
|
||||||
|
in v1 workers. It's included here for compatibility with the
|
||||||
|
collective_rpc mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration for intermediate logging.
|
||||||
|
"""
|
||||||
|
logger.warning(
|
||||||
|
"register_intermediate_hooks is not implemented in v0 workers. "
|
||||||
|
"This is only available in v1 workers. No hooks will be registered."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Clean up resources held by the worker."""
|
"""Clean up resources held by the worker."""
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user