fix
This commit is contained in:
@ -16,7 +16,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="facebook/opt-125m", compilation_config={"level": 0, "cudagraph_mode": "full_decode_only"})
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
|
||||
@ -412,6 +412,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
)
|
||||
return self._workspace_buffer
|
||||
|
||||
def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
|
||||
self._workspace_buffer = workspace_buffer
|
||||
|
||||
def _get_prefill_wrapper(self):
|
||||
if self._prefill_wrapper is None:
|
||||
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
|
||||
@ -77,7 +77,7 @@ def init_attn_backend(
|
||||
|
||||
if "FLASHINFER" in attn_backend.get_name():
|
||||
if flashinfer_workspace is None:
|
||||
flashinfer_workspace = attn_metadata_builder.get_workspace_buffer()
|
||||
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
|
||||
else:
|
||||
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
|
||||
return attn_backends, attn_metadata_builders
|
||||
|
||||
@ -792,9 +792,8 @@ class Worker(WorkerBase):
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
# if runner := getattr(self, "model_runner", None):
|
||||
# runner.ensure_kv_transfer_shutdown()
|
||||
pass
|
||||
if runner := getattr(self, "model_runner", None):
|
||||
runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
|
||||
Reference in New Issue
Block a user