Set default dtype to half
This commit is contained in:
@@ -17,6 +17,7 @@ class Worker:
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
+ dtype: str,
|
||||
) -> None:
|
||||
self.worker_id = worker_id
|
||||
self.gpu_id = gpu_id
|
||||
@@ -26,7 +27,7 @@ class Worker:
|
||||
|
||||
# Initialize the model.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
- self.model = get_model(model_name).to(device=gpu_id)
|
||||
+ self.model = get_model(model_name, dtype=dtype).to(device=self.device)
|
||||
self.num_layers = self.model.config.num_hidden_layers
|
||||
self.num_heads = self.model.config.num_attention_heads
|
||||
self.head_size = self.model.config.hidden_size // self.num_heads
|
||||
|
||||
Reference in New Issue
Block a user