[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 
 class DummyLoRAManager:
 
-    def __init__(self):
+    def __init__(self, device: torch.device = "cuda:0"):
         super().__init__()
         self._loras: Dict[str, LoRALayerWeights] = {}
+        self._device = device
 
     def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
         self._loras[module_name] = lora
@@ -28,16 +29,16 @@ class DummyLoRAManager:
             lora_alpha=1,
             lora_a=torch.rand([weight.shape[1], rank],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device=self._device),
             lora_b=torch.rand([rank, weight.shape[0]],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device=self._device),
         )
         if generate_embeddings_tensor:
             lora.embeddings_tensor = torch.rand(5,
                                                 generate_embeddings_tensor,
                                                 dtype=weight.dtype,
-                                                device="cuda")
+                                                device=self._device)
         self.set_module_lora(module_name, lora)
 
         return lora
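
For reference, a minimal sketch of how the now-configurable device argument might be used in a test. The import path and the exact argument order of the random-LoRA initializer are assumptions based on vLLM's test utilities rather than shown in this diff, and the weight shape is purely illustrative:

    import torch

    # Assumed import path; the class lives in vLLM's LoRA test utilities.
    from tests.lora.utils import DummyLoRAManager

    # The device now flows through the manager instead of being pinned to
    # "cuda", so dummy LoRA weights can also be created on CPU-only machines.
    manager = DummyLoRAManager(device="cpu")

    # Illustrative module weight; lora_a/lora_b are sized from weight.shape
    # as in the diff above.
    weight = torch.rand(128, 256, dtype=torch.float32)
    lora = manager.init_random_lora("dense1", weight, rank=8)

    assert lora.lora_a.device.type == "cpu"
    assert lora.lora_b.device.type == "cpu"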