diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b8a40f77df..7710fb5fbe 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -36,9 +36,10 @@ from vllm.v1.engine.output_processor import (OutputProcessor, from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager +from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, + setup_default_loggers) from vllm.v1.metrics.prometheus import shutdown_prometheus -from vllm.v1.metrics.stats import IterationStats +from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -95,18 +96,12 @@ class AsyncLLM(EngineClient): self.log_stats = log_stats # Set up stat loggers; independent set for each DP rank. - # HACK: asyncllm should not be aware of how many engines is it - # managing. - start_idx = vllm_config.parallel_config.data_parallel_rank - local_engines = vllm_config.parallel_config.data_parallel_size_local - engine_idxs = [ - idx for idx in range(start_idx, start_idx + local_engines) - ] - self.logger_manager = StatLoggerManager( + self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( vllm_config=vllm_config, - engine_idxs=engine_idxs, + log_stats=self.log_stats, + engine_num=vllm_config.parallel_config.data_parallel_size, custom_stat_loggers=stat_loggers, - ) if self.log_stats else None + ) # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -134,8 +129,9 @@ class AsyncLLM(EngineClient): client_addresses=client_addresses, client_index=client_index, ) - if self.logger_manager: - self.logger_manager.log_engine_initialized() + if self.stat_loggers: + for stat_logger in self.stat_loggers[0]: + stat_logger.log_engine_initialized() self.output_handler: Optional[asyncio.Task] = None try: # Start output handler eagerly if we are in the asyncio eventloop. @@ -374,7 +370,7 @@ class AsyncLLM(EngineClient): engine_core = self.engine_core output_processor = self.output_processor log_stats = self.log_stats - logger_manager = self.logger_manager + stat_loggers = self.stat_loggers if log_stats else None async def output_handler(): try: @@ -415,11 +411,11 @@ class AsyncLLM(EngineClient): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. # NOTE: we do not use self.log - if logger_manager: - logger_manager.record( + if stat_loggers: + AsyncLLM._record_stats( + stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, - engine_idx=outputs.engine_index, ) except Exception as e: logger.exception("AsyncLLM output_handler failed.") @@ -436,6 +432,18 @@ class AsyncLLM(EngineClient): if self.log_requests: logger.info("Aborted request %s.", request_id) + @staticmethod + def _record_stats( + stat_loggers: list[StatLoggerBase], + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + ): + """static so that it can be used from the output_handler task + without a circular ref to AsyncLLM.""" + for stat_logger in stat_loggers: + stat_logger.record(scheduler_stats=scheduler_stats, + iteration_stats=iteration_stats) + async def encode( self, prompt: PromptType, @@ -540,11 +548,7 @@ class AsyncLLM(EngineClient): scheduler_outputs=None, model_output=None, ) -> None: - if self.stat_loggers is None: - return - # loggers, prom_logger - per_engine_loggers, _ = self.stat_loggers - for loggers in per_engine_loggers.values(): + for loggers in self.stat_loggers: for stat_logger in loggers: stat_logger.log() @@ -649,19 +653,18 @@ class AsyncLLM(EngineClient): self.vllm_config.parallel_config.data_parallel_size = \ new_data_parallel_size - # recreate stat loggers - # if new_data_parallel_size > old_data_parallel_size: - # stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( - # vllm_config=self.vllm_config, - # log_stats=self.log_stats, - # engine_num=new_data_parallel_size, - # custom_stat_loggers=None, - # ) - # num_new_engines = len(stat_loggers) - len(self.stat_loggers) - # self.stat_loggers.extend(stat_loggers[-num_new_engines:]) - # else: - # for _ in range(old_data_parallel_size - new_data_parallel_size): - # self.stat_loggers.pop() + if new_data_parallel_size > old_data_parallel_size: + stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( + vllm_config=self.vllm_config, + log_stats=self.log_stats, + engine_num=new_data_parallel_size, + custom_stat_loggers=None, + ) + num_new_engines = len(stat_loggers) - len(self.stat_loggers) + self.stat_loggers.extend(stat_loggers[-num_new_engines:]) + else: + for _ in range(old_data_parallel_size - new_data_parallel_size): + self.stat_loggers.pop() @property def is_running(self) -> bool: