Added __init__.py
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
@ -58,11 +58,14 @@ class PDController(EngineClient):
|
||||
[ Engine ] <---> [ Engine ]
|
||||
|
||||
After PR #12957, we will support xPyD, so we will
|
||||
also need to implement a scheduler.
|
||||
we will need to support multiple
|
||||
also need to implement a scheduler and service
|
||||
discovery for the workers.
|
||||
|
||||
* TODO: actually handle errors and failure.
|
||||
* TODO: support the full API (logprobs, multimodal).
|
||||
This PDController may be implemented as a K8s
|
||||
controller. This is intended to be a prototype.
|
||||
|
||||
* TODO: better error handling
|
||||
* TODO: support logprobs, multimodal, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, prefill_addr: str, decode_addr: str,
|
||||
@ -101,17 +104,16 @@ class PDController(EngineClient):
|
||||
|
||||
# Dummy: needed for EngineClient Protocol.
|
||||
# TODO: refactor OAI Server to avoid needing this.
|
||||
init_kwargs = dict(
|
||||
tokenizer_id=self.model_config.tokenizer,
|
||||
enable_lora=False,
|
||||
max_num_seqs=1024,
|
||||
max_loras=0,
|
||||
max_input_length=None,
|
||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
revision=self.model_config.tokenizer_revision,
|
||||
truncation_side=self.model_config.truncation_side)
|
||||
self.tokenizer = TokenizerGroup(**init_kwargs)
|
||||
self.tokenizer = TokenizerGroup(
|
||||
**dict(tokenizer_id=self.model_config.tokenizer,
|
||||
enable_lora=False,
|
||||
max_num_seqs=1024,
|
||||
max_loras=0,
|
||||
max_input_length=None,
|
||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
revision=self.model_config.tokenizer_revision,
|
||||
truncation_side=self.model_config.truncation_side))
|
||||
|
||||
def shutdown(self):
|
||||
if (ctx := self.ctx) is not None:
|
||||
@ -155,7 +157,7 @@ class PDController(EngineClient):
|
||||
raise Exception("Unknown response type.")
|
||||
except Exception as e:
|
||||
# TODO: distinguish between fatal and non-fatal errors.
|
||||
for _, q in self.queues.values():
|
||||
for q in self.queues.values():
|
||||
q.put_nowait(e)
|
||||
raise e
|
||||
finally:
|
||||
@ -172,17 +174,17 @@ class PDController(EngineClient):
|
||||
msg = (PDRequestType.GENERATION, req_bytes)
|
||||
await self.to_prefill.send_multipart(msg, copy=False)
|
||||
|
||||
# Wait for the prefill to be done.
|
||||
# Await completion of the prefill.
|
||||
response = await q.get()
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
logger.debug("Got Decode Response: %s", request.request_id)
|
||||
|
||||
async def _run_decode(
|
||||
self,
|
||||
request: PDGenerationRequest,
|
||||
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
||||
) -> AsyncGenerator[PDGenerationResponse]:
|
||||
|
||||
# Send request to the decode instance.
|
||||
req_bytes = self.encoder.encode(request)
|
||||
msg = (PDRequestType.GENERATION, req_bytes)
|
||||
@ -194,6 +196,7 @@ class PDController(EngineClient):
|
||||
response = await q.get()
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
logger.debug("Got Decode Response: %s", request.request_id)
|
||||
finished = response.finish_reason is not None
|
||||
yield response
|
||||
|
||||
@ -261,10 +264,10 @@ class PDController(EngineClient):
|
||||
|
||||
# (1) Perform the Prefill.
|
||||
original_max_tokens = sampling_params.max_tokens
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
request = PDGenerationRequest(request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params)
|
||||
request = PDGenerationRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt["prompt_token_ids"],
|
||||
sampling_params=sampling_params)
|
||||
request.sampling_params.max_tokens = 1
|
||||
logger.debug("Sending Prefill: %s", request.request_id)
|
||||
pd_response = await self._run_prefill(request, q)
|
||||
@ -273,8 +276,8 @@ class PDController(EngineClient):
|
||||
logger.debug("Sending Decode: %s", request.request_id)
|
||||
request.sampling_params.max_tokens = original_max_tokens
|
||||
async for pd_response in self._run_decode(request, q):
|
||||
logger.debug("Got Response: %s", request.request_id)
|
||||
yield self._to_request_output(pd_response, prompt_token_ids)
|
||||
yield self._to_request_output(pd_response,
|
||||
prompt["prompt_token_ids"])
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user