revert logger changes

Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-07-20 19:11:43 +00:00
parent 1b488f8d5a
commit e540aa41b8

View File

@ -36,9 +36,10 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__)
@ -95,18 +96,12 @@ class AsyncLLM(EngineClient):
self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
# HACK: asyncllm should not be aware of how many engines is it
# managing.
start_idx = vllm_config.parallel_config.data_parallel_rank
local_engines = vllm_config.parallel_config.data_parallel_size_local
engine_idxs = [
idx for idx in range(start_idx, start_idx + local_engines)
]
self.logger_manager = StatLoggerManager(
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
engine_idxs=engine_idxs,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
) if self.log_stats else None
)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
@ -134,8 +129,9 @@ class AsyncLLM(EngineClient):
client_addresses=client_addresses,
client_index=client_index,
)
if self.logger_manager:
self.logger_manager.log_engine_initialized()
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
@ -374,7 +370,7 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
stat_loggers = self.stat_loggers if log_stats else None
async def output_handler():
try:
@ -415,11 +411,11 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
# NOTE: we do not use self.log
if logger_manager:
logger_manager.record(
if stat_loggers:
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
engine_idx=outputs.engine_index,
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
@ -436,6 +432,18 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
async def encode(
self,
prompt: PromptType,
@ -540,11 +548,7 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None,
model_output=None,
) -> None:
if self.stat_loggers is None:
return
# loggers, prom_logger
per_engine_loggers, _ = self.stat_loggers
for loggers in per_engine_loggers.values():
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
@ -649,19 +653,18 @@ class AsyncLLM(EngineClient):
self.vllm_config.parallel_config.data_parallel_size = \
new_data_parallel_size
# recreate stat loggers
# if new_data_parallel_size > old_data_parallel_size:
# stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
# vllm_config=self.vllm_config,
# log_stats=self.log_stats,
# engine_num=new_data_parallel_size,
# custom_stat_loggers=None,
# )
# num_new_engines = len(stat_loggers) - len(self.stat_loggers)
# self.stat_loggers.extend(stat_loggers[-num_new_engines:])
# else:
# for _ in range(old_data_parallel_size - new_data_parallel_size):
# self.stat_loggers.pop()
if new_data_parallel_size > old_data_parallel_size:
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=self.vllm_config,
log_stats=self.log_stats,
engine_num=new_data_parallel_size,
custom_stat_loggers=None,
)
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
else:
for _ in range(old_data_parallel_size - new_data_parallel_size):
self.stat_loggers.pop()
@property
def is_running(self) -> bool: