From d5b0db449e3479cbd35aa0bacaa02a3b8f4e6632 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 23 Mar 2025 22:44:36 +0000 Subject: [PATCH] added __init__.py Signed-off-by: rshaw@neuralmagic.com --- vllm/disaggregated/pd_client.py | 51 +++++++++++++++++---------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/vllm/disaggregated/pd_client.py b/vllm/disaggregated/pd_client.py index 98657d0e21..15f41e84e8 100644 --- a/vllm/disaggregated/pd_client.py +++ b/vllm/disaggregated/pd_client.py @@ -58,11 +58,14 @@ class PDController(EngineClient): [ Engine ] <---> [ Engine ] After PR #12957, we will support xPyD, so we will - also need to implement a scheduler. - we will need to support multiple + also need to implement a scheduler and service + discovery for the workers. - * TODO: actually handle errors and failure. - * TODO: support the full API (logprobs, multimodal). + This PDController may be implemented as a K8s + controller. This is intended to be a prototype. + + * TODO: better error handling + * TODO: support logprobs, multimodal, etc. """ def __init__(self, prefill_addr: str, decode_addr: str, @@ -101,17 +104,16 @@ class PDController(EngineClient): # Dummy: needed for EngineClient Protocol. # TODO: refactor OAI Server to avoid needing this. 
- init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=False, - max_num_seqs=1024, - max_loras=0, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision, - truncation_side=self.model_config.truncation_side) - self.tokenizer = TokenizerGroup(**init_kwargs) + self.tokenizer = TokenizerGroup( + **dict(tokenizer_id=self.model_config.tokenizer, + enable_lora=False, + max_num_seqs=1024, + max_loras=0, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision, + truncation_side=self.model_config.truncation_side)) def shutdown(self): if (ctx := self.ctx) is not None: @@ -155,7 +157,7 @@ class PDController(EngineClient): raise Exception("Unknown response type.") except Exception as e: # TODO: distinguish between fatal and non-fatal errors. - for _, q in self.queues.values(): + for q in self.queues.values(): q.put_nowait(e) raise e finally: @@ -172,17 +174,17 @@ class PDController(EngineClient): msg = (PDRequestType.GENERATION, req_bytes) await self.to_prefill.send_multipart(msg, copy=False) - # Wait for the prefill to be done. + # Await completion of the prefill. response = await q.get() if isinstance(response, Exception): raise response + logger.debug("Got Prefill Response: %s", request.request_id) async def _run_decode( self, request: PDGenerationRequest, q: asyncio.Queue[Union[Exception, PDGenerationResponse]], ) -> AsyncGenerator[PDGenerationResponse]: - # Send request to the decode instance. 
req_bytes = self.encoder.encode(request) msg = (PDRequestType.GENERATION, req_bytes) @@ -194,6 +196,7 @@ class PDController(EngineClient): response = await q.get() if isinstance(response, Exception): raise response + logger.debug("Got Decode Response: %s", request.request_id) finished = response.finish_reason is not None yield response @@ -261,10 +264,10 @@ class PDController(EngineClient): # (1) Perform the Prefill. original_max_tokens = sampling_params.max_tokens - prompt_token_ids = prompt["prompt_token_ids"] - request = PDGenerationRequest(request_id=request_id, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) + request = PDGenerationRequest( + request_id=request_id, + prompt_token_ids=prompt["prompt_token_ids"], + sampling_params=sampling_params) request.sampling_params.max_tokens = 1 logger.debug("Sending Prefill: %s", request.request_id) pd_response = await self._run_prefill(request, q) @@ -273,8 +276,8 @@ class PDController(EngineClient): logger.debug("Sending Decode: %s", request.request_id) request.sampling_params.max_tokens = original_max_tokens async for pd_response in self._run_decode(request, q): - logger.debug("Got Response: %s", request.request_id) - yield self._to_request_output(pd_response, prompt_token_ids) + yield self._to_request_output(pd_response, + prompt["prompt_token_ids"]) async def beam_search( self,