added __init__.py

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com
2025-03-23 22:44:36 +00:00
parent 66349c33a1
commit d5b0db449e

View File

@ -58,11 +58,14 @@ class PDController(EngineClient):
[ Engine ] <---> [ Engine ]
After PR #12957, we will support xPyD, so we will
also need to implement a scheduler.
we will need to support multiple
also need to implement a scheduler and service
discovery for the workers.
* TODO: actually handle errors and failures.
* TODO: support the full API (logprobs, multimodal).
This PDController may be implemented as a K8s
controller. This is intended to be a prototype.
* TODO: better error handling
* TODO: support logprobs, multimodal, etc.
"""
def __init__(self, prefill_addr: str, decode_addr: str,
@ -101,17 +104,16 @@ class PDController(EngineClient):
# Dummy: needed for EngineClient Protocol.
# TODO: refactor the OpenAI-compatible (OAI) server so this is no longer needed.
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=False,
max_num_seqs=1024,
max_loras=0,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision,
truncation_side=self.model_config.truncation_side)
self.tokenizer = TokenizerGroup(**init_kwargs)
self.tokenizer = TokenizerGroup(
**dict(tokenizer_id=self.model_config.tokenizer,
enable_lora=False,
max_num_seqs=1024,
max_loras=0,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision,
truncation_side=self.model_config.truncation_side))
def shutdown(self):
if (ctx := self.ctx) is not None:
@ -155,7 +157,7 @@ class PDController(EngineClient):
raise Exception("Unknown response type.")
except Exception as e:
# TODO: distinguish between fatal and non-fatal errors.
for _, q in self.queues.values():
for q in self.queues.values():
q.put_nowait(e)
raise e
finally:
@ -172,17 +174,17 @@ class PDController(EngineClient):
msg = (PDRequestType.GENERATION, req_bytes)
await self.to_prefill.send_multipart(msg, copy=False)
# Wait for the prefill to be done.
# Await completion of the prefill.
response = await q.get()
if isinstance(response, Exception):
raise response
logger.debug("Got Decode Response: %s", request.request_id)
async def _run_decode(
self,
request: PDGenerationRequest,
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
) -> AsyncGenerator[PDGenerationResponse]:
# Send request to the decode instance.
req_bytes = self.encoder.encode(request)
msg = (PDRequestType.GENERATION, req_bytes)
@ -194,6 +196,7 @@ class PDController(EngineClient):
response = await q.get()
if isinstance(response, Exception):
raise response
logger.debug("Got Decode Response: %s", request.request_id)
finished = response.finish_reason is not None
yield response
@ -261,10 +264,10 @@ class PDController(EngineClient):
# (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens
prompt_token_ids = prompt["prompt_token_ids"]
request = PDGenerationRequest(request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
request = PDGenerationRequest(
request_id=request_id,
prompt_token_ids=prompt["prompt_token_ids"],
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._run_prefill(request, q)
@ -273,8 +276,8 @@ class PDController(EngineClient):
logger.debug("Sending Decode: %s", request.request_id)
request.sampling_params.max_tokens = original_max_tokens
async for pd_response in self._run_decode(request, q):
logger.debug("Got Response: %s", request.request_id)
yield self._to_request_output(pd_response, prompt_token_ids)
yield self._to_request_output(pd_response,
prompt["prompt_token_ids"])
async def beam_search(
self,