Set default dtype to half

This commit is contained in:
Woosuk Kwon
2023-02-23 21:31:39 +00:00
parent de0fabbc5c
commit 1ce1333573
3 changed files with 23 additions and 3 deletions

View File

@@ -17,6 +17,7 @@ class Worker:
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
+dtype: str,
) -> None:
self.worker_id = worker_id
self.gpu_id = gpu_id
@@ -26,7 +27,7 @@ class Worker:
# Initialize the model.
# FIXME(woosuk): This is a hack.
-self.model = get_model(model_name).to(device=gpu_id)
+self.model = get_model(model_name, dtype=dtype).to(device=self.device)
self.num_layers = self.model.config.num_hidden_layers
self.num_heads = self.model.config.num_attention_heads
self.head_size = self.model.config.hidden_size // self.num_heads