Compare commits
2 Commits
woosuk/min
...
support_gl
| Author | SHA1 | Date | |
|---|---|---|---|
| 98d535eb4f | |||
| a46e279909 |
@ -1,12 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
|
||||
|
||||
"""
|
||||
To run this example, run the following commands simultaneously with
|
||||
@ -22,37 +25,67 @@ send a request to the instance with DP rank 1.
|
||||
"""
|
||||
|
||||
|
||||
def _do_background_logging(engine, interval, stop_event):
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
asyncio.run(engine.do_log_stats())
|
||||
stop_event.wait(interval)
|
||||
except Exception as e:
|
||||
print(f"vLLM background logging shutdown: {e}")
|
||||
pass
|
||||
|
||||
|
||||
async def main():
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="ibm-research/PowerMoE-3b",
|
||||
data_parallel_size=2,
|
||||
tensor_parallel_size=1,
|
||||
dtype="auto",
|
||||
max_model_len=2048,
|
||||
data_parallel_address="127.0.0.1",
|
||||
data_parallel_rpc_port=62300,
|
||||
data_parallel_size_local=1,
|
||||
enforce_eager=True,
|
||||
enable_log_requests=True,
|
||||
disable_custom_all_reduce=True,
|
||||
)
|
||||
|
||||
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
|
||||
return LoggingStatLogger(config, rank)
|
||||
|
||||
engine_client = AsyncLLMEngine.from_engine_args(
|
||||
engine_args,
|
||||
# Example: Using both regular loggers and aggregated logger
|
||||
stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
|
||||
)
|
||||
stop_logging_event = threading.Event()
|
||||
logging_thread = threading.Thread(
|
||||
target=_do_background_logging,
|
||||
args=(engine_client, 5, stop_logging_event),
|
||||
daemon=True,
|
||||
)
|
||||
logging_thread.start()
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
max_tokens=100,
|
||||
)
|
||||
num_prompts = 10
|
||||
for i in range(num_prompts):
|
||||
prompt = "Who won the 2004 World Series?"
|
||||
final_output: Optional[RequestOutput] = None
|
||||
async for output in engine_client.generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id=f"abcdef-{i}",
|
||||
data_parallel_rank=1,
|
||||
):
|
||||
final_output = output
|
||||
if final_output:
|
||||
print(final_output.outputs[0].text)
|
||||
|
||||
prompt = "Who won the 2004 World Series?"
|
||||
final_output: Optional[RequestOutput] = None
|
||||
async for output in engine_client.generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id="abcdef",
|
||||
data_parallel_rank=1,
|
||||
):
|
||||
final_output = output
|
||||
if final_output:
|
||||
print(final_output.outputs[0].text)
|
||||
stop_logging_event.set()
|
||||
logging_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.metrics.loggers import LoggingStatLogger
|
||||
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger):
|
||||
self.log = MagicMock()
|
||||
|
||||
|
||||
class MockAggregatedStatLogger(AggregatedStatLogger):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_indexes: Optional[list[int]] = None):
|
||||
super().__init__(vllm_config, engine_indexes)
|
||||
self.log = MagicMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_customize_loggers(monkeypatch):
|
||||
"""Test that we can customize the loggers.
|
||||
@ -415,6 +424,35 @@ async def test_customize_loggers(monkeypatch):
|
||||
stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_customize_aggregated_loggers(monkeypatch):
|
||||
"""Test that we can customize the aggregated loggers.
|
||||
If a customized logger is provided at the init, it should
|
||||
be added to the default loggers.
|
||||
"""
|
||||
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(
|
||||
TEXT_ENGINE_ARGS,
|
||||
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
|
||||
)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
await engine.do_log_stats()
|
||||
|
||||
stat_loggers = engine.logger_manager.per_engine_logger_dict
|
||||
assert len(stat_loggers) == 1
|
||||
assert len(
|
||||
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
||||
aggregated_loggers = engine.logger_manager.aggregated_loggers
|
||||
assert len(aggregated_loggers) == 1
|
||||
aggregated_loggers[0].log.assert_called_once()
|
||||
stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE,
|
||||
SupportsMultiModal, SupportsPP, SupportsTranscription,
|
||||
SupportsV0Only, has_inner_state, supports_lora,
|
||||
supports_mrope, supports_multimodal, supports_pp,
|
||||
supports_transcription, supports_v0_only)
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
||||
SupportsPP, SupportsTranscription, SupportsV0Only,
|
||||
has_inner_state, supports_lora, supports_multimodal,
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
|
||||
is_pooling_model, is_text_generation_model)
|
||||
from .registry import ModelRegistry
|
||||
@ -22,8 +21,6 @@ __all__ = [
|
||||
"supports_lora",
|
||||
"SupportsMultiModal",
|
||||
"supports_multimodal",
|
||||
"SupportsMRoPE",
|
||||
"supports_mrope",
|
||||
"SupportsPP",
|
||||
"supports_pp",
|
||||
"SupportsTranscription",
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.whisper.tokenization_whisper import LANGUAGES
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
||||
@ -853,70 +852,3 @@ def supports_eagle3(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
|
||||
return isinstance(model, SupportsEagle3)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsMRoPE(Protocol):
|
||||
"""The interface required for all models that support M-RoPE."""
|
||||
|
||||
supports_mrope: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model supports M-RoPE.
|
||||
|
||||
Note:
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""
|
||||
Get M-RoPE input positions and delta value for this specific model.
|
||||
|
||||
This method should be implemented by each model that supports M-RoPE
|
||||
to provide model-specific logic for computing input positions.
|
||||
|
||||
Args:
|
||||
input_tokens: List of input token IDs
|
||||
hf_config: HuggingFace model configuration
|
||||
image_grid_thw: Image grid dimensions (t, h, w)
|
||||
video_grid_thw: Video grid dimensions (t, h, w)
|
||||
second_per_grid_ts: Seconds per grid timestep for videos
|
||||
context_len: Context length
|
||||
seq_len: Sequence length
|
||||
audio_feature_lengths: Audio feature lengths for multimodal models
|
||||
use_audio_in_video: Whether to use audio in video for interleaving
|
||||
|
||||
Returns:
|
||||
Tuple of (llm_positions, mrope_position_delta)
|
||||
- llm_positions: Tensor of shape [3, num_tokens]
|
||||
with T/H/W positions
|
||||
- mrope_position_delta: Delta for position calculations
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]:
|
||||
...
|
||||
|
||||
|
||||
def supports_mrope(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]:
|
||||
return isinstance(model, SupportsMRoPE)
|
||||
|
||||
@ -32,7 +32,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers import AutoConfig, BatchFeature, PretrainedConfig
|
||||
from transformers import AutoConfig, BatchFeature
|
||||
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
|
||||
Qwen2VLProcessor)
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||
@ -73,7 +73,7 @@ from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
@ -1096,7 +1096,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
info=Qwen2VLProcessingInfo,
|
||||
dummy_inputs=Qwen2VLDummyInputsBuilder)
|
||||
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsLoRA, SupportsPP, SupportsMRoPE):
|
||||
SupportsLoRA, SupportsPP):
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@ -1109,118 +1109,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get M-RoPE input positions for Qwen2-VL model."""
|
||||
if image_grid_thw is None:
|
||||
image_grid_thw = []
|
||||
if video_grid_thw is None:
|
||||
video_grid_thw = []
|
||||
if second_per_grid_ts is None:
|
||||
second_per_grid_ts = []
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(hf_config.vision_config,
|
||||
"tokens_per_second", 1.0)
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
video_second_per_grid_t = 0.0
|
||||
if remain_images > 0:
|
||||
try:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
except ValueError:
|
||||
ed_image = len(input_tokens) + 1
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if remain_videos > 0:
|
||||
try:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
except ValueError:
|
||||
ed_video = len(input_tokens) + 1
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_second_per_grid_t = 1.0
|
||||
if second_per_grid_ts:
|
||||
video_second_per_grid_t = second_per_grid_ts[video_index]
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = \
|
||||
t, h // spatial_merge_size, w // spatial_merge_size
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
|
||||
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
|
||||
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
|
||||
tokens_per_second).long().flatten()
|
||||
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||||
llm_grid_t, -1, llm_grid_w).flatten()
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||||
llm_grid_t, llm_grid_h, -1).flatten()
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 -
|
||||
len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
|
||||
@ -206,11 +206,12 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
)
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
returned_lse = lse[:, :H].contiguous(
|
||||
) if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
if self.need_to_return_lse_for_decode:
|
||||
lse = lse[:, :H].contiguous()
|
||||
|
||||
return out, lse
|
||||
return out, returned_lse
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
|
||||
@ -18,7 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
|
||||
type["AggregatedStatLoggerBase"]]
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
@ -48,6 +50,16 @@ class StatLoggerBase(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class AggregatedStatLoggerBase(StatLoggerBase):
|
||||
"""Abstract base class for loggers that
|
||||
aggregates statistics across multiple engines."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, vllm_config: VllmConfig,
|
||||
engine_indexes: Optional[list[int]]):
|
||||
...
|
||||
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
@ -61,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.spec_decoding_logging = SpecDecodingLogging()
|
||||
self.last_prompt_throughput: float = 0.0
|
||||
self.last_generation_throughput: float = 0.0
|
||||
self.engine_is_idle = False
|
||||
|
||||
def _reset(self, now):
|
||||
self.last_log_time = now
|
||||
@ -100,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
|
||||
def log(self):
|
||||
def get_log_stats(self):
|
||||
now = time.monotonic()
|
||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
||||
generation_throughput = self._get_throughput(
|
||||
self.num_generation_tokens, now)
|
||||
|
||||
self._reset(now)
|
||||
|
||||
scheduler_stats = self.last_scheduler_stats
|
||||
|
||||
log_fn = logger.info
|
||||
if not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput)):
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
self.last_generation_throughput = generation_throughput
|
||||
self.last_prompt_throughput = prompt_throughput
|
||||
self.engine_is_idle = not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput))
|
||||
|
||||
def log(self):
|
||||
self.get_log_stats()
|
||||
log_fn = logger.info
|
||||
if self.engine_is_idle:
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
# Format and print output.
|
||||
log_fn(
|
||||
"Engine %03d: "
|
||||
@ -128,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%",
|
||||
self.engine_index,
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
scheduler_stats.num_running_reqs,
|
||||
scheduler_stats.num_waiting_reqs,
|
||||
scheduler_stats.kv_cache_usage * 100,
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput,
|
||||
self.last_scheduler_stats.num_running_reqs,
|
||||
self.last_scheduler_stats.num_waiting_reqs,
|
||||
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
@ -145,7 +158,61 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_idxs: Optional[list[int]] = None):
|
||||
if engine_idxs is None:
|
||||
engine_idxs = [0]
|
||||
self.engine_idxs = engine_idxs
|
||||
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
|
||||
|
||||
def record(
|
||||
self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: int = 0,
|
||||
):
|
||||
if engine_idx not in self.engine_idxs:
|
||||
logger.warning("Unexpected engine_idx: %d", engine_idx)
|
||||
return
|
||||
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
|
||||
engine_idx)
|
||||
|
||||
def log(self):
|
||||
self.get_log_stats()
|
||||
log_fn = logger.info
|
||||
if self.engine_is_idle:
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
# Format and print output.
|
||||
log_fn(
|
||||
"%s Engines Aggregated: "
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Waiting: %d reqs, "
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%",
|
||||
len(self.engine_idxs),
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput,
|
||||
self.last_scheduler_stats.num_running_reqs,
|
||||
self.last_scheduler_stats.num_waiting_reqs,
|
||||
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
|
||||
def log_engine_initialized(self):
|
||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||
logger.info(
|
||||
"%d Engines: vllm cache_config_info with initialization "
|
||||
"after num_gpu_blocks is: %d", len(self.engine_idxs),
|
||||
self.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
|
||||
class PrometheusStatLogger(AggregatedStatLoggerBase):
|
||||
_gauge_cls = prometheus_client.Gauge
|
||||
_counter_cls = prometheus_client.Counter
|
||||
_histogram_cls = prometheus_client.Histogram
|
||||
@ -674,23 +741,32 @@ class StatLoggerManager:
|
||||
|
||||
# engine_idx: StatLogger
|
||||
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
||||
prometheus_factory = PrometheusStatLogger
|
||||
self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
|
||||
|
||||
aggregated_loggers_factories = set()
|
||||
for engine_idx in self.engine_idxs:
|
||||
loggers: list[StatLoggerBase] = []
|
||||
for logger_factory in factories:
|
||||
# If we get a custom prometheus logger, use that
|
||||
# instead. This is typically used for the ray case.
|
||||
if (isinstance(logger_factory, type)
|
||||
and issubclass(logger_factory, PrometheusStatLogger)):
|
||||
prometheus_factory = logger_factory
|
||||
continue
|
||||
loggers.append(logger_factory(vllm_config,
|
||||
engine_idx)) # type: ignore
|
||||
# If we get a custom prometheus logger or aggregated logger,
|
||||
# We initialize it separately with all engine idxs.
|
||||
# A custom prometheus logger is typically used for the ray.
|
||||
if (isinstance(logger_factory, type) and issubclass(
|
||||
logger_factory, AggregatedStatLoggerBase)):
|
||||
aggregated_loggers_factories.add(logger_factory)
|
||||
else:
|
||||
loggers.append(logger_factory(vllm_config,
|
||||
engine_idx)) # type: ignore
|
||||
self.per_engine_logger_dict[engine_idx] = loggers
|
||||
|
||||
# For Prometheus, need to share the metrics between EngineCores.
|
||||
# If no custom aggregated logger is provide,
|
||||
# we by default use PrometheusStatLogger
|
||||
if not aggregated_loggers_factories:
|
||||
aggregated_loggers_factories.add(PrometheusStatLogger)
|
||||
# For custom aggregated logger(or default Prometheus Logger)
|
||||
# need to share the metrics between EngineCores.
|
||||
# Each EngineCore's metrics are expressed as a unique label.
|
||||
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
|
||||
for aggregated_loggers_factory in aggregated_loggers_factories:
|
||||
self.aggregated_loggers.append(
|
||||
aggregated_loggers_factory(vllm_config, engine_idxs))
|
||||
|
||||
def record(
|
||||
self,
|
||||
@ -704,18 +780,19 @@ class StatLoggerManager:
|
||||
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
||||
for logger in per_engine_loggers:
|
||||
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||
|
||||
self.prometheus_logger.record(scheduler_stats, iteration_stats,
|
||||
engine_idx)
|
||||
for logger in self.aggregated_loggers:
|
||||
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||
|
||||
def log(self):
|
||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||
for logger in per_engine_loggers:
|
||||
logger.log()
|
||||
for logger in self.aggregated_loggers:
|
||||
logger.log()
|
||||
|
||||
def log_engine_initialized(self):
|
||||
self.prometheus_logger.log_engine_initialized()
|
||||
|
||||
for agg_logger in self.aggregated_loggers:
|
||||
agg_logger.log_engine_initialized()
|
||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||
for logger in per_engine_loggers:
|
||||
logger.log_engine_initialized()
|
||||
|
||||
@ -42,7 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
|
||||
supports_eagle3,
|
||||
supports_mrope,
|
||||
supports_transcription)
|
||||
from vllm.model_executor.models.interfaces_base import (
|
||||
VllmModelForPooling, is_pooling_model, is_text_generation_model)
|
||||
@ -731,28 +730,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
if supports_mrope(self.model):
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
else:
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
def _extract_mm_kwargs(
|
||||
self,
|
||||
|
||||
@ -334,8 +334,7 @@ class Worker(WorkerBase):
|
||||
self.model_runner._dummy_run(size,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
if self.lora_config is not None:
|
||||
self.model_runner.maybe_remove_all_loras(self.lora_config)
|
||||
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
|
||||
|
||||
# Warmup and tune the kernels used during model execution before
|
||||
# cuda graph capture.
|
||||
@ -428,9 +427,6 @@ class Worker(WorkerBase):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
||||
if len(get_pp_group().ranks) == 1:
|
||||
return self.model_runner.execute_model(scheduler_output)
|
||||
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@ -449,6 +445,8 @@ class Worker(WorkerBase):
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
|
||||
@ -41,8 +41,7 @@ from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||
get_sampler)
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.model_executor.models import (supports_lora, supports_mrope,
|
||||
supports_multimodal)
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs, MultiModalPlaceholderMap,
|
||||
@ -671,33 +670,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
inter_data.seq_ids[seq_idx]]
|
||||
token_ids = seq_data.get_token_ids()
|
||||
|
||||
if supports_mrope(self.runner.model):
|
||||
mrope_input_positions, mrope_position_delta = \
|
||||
self.runner.model.get_mrope_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=inter_data.context_lens[seq_idx],
|
||||
seq_len=inter_data.seq_lens[seq_idx],
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
mrope_input_positions = mrope_input_positions.tolist()
|
||||
else:
|
||||
mrope_input_positions, mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=inter_data.context_lens[seq_idx],
|
||||
seq_len=inter_data.seq_lens[seq_idx],
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
mrope_input_positions, mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=inter_data.context_lens[seq_idx],
|
||||
seq_len=inter_data.seq_lens[seq_idx],
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
seq_data.mrope_position_delta = mrope_position_delta
|
||||
inter_data.mrope_input_positions[
|
||||
|
||||
Reference in New Issue
Block a user