[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2024-11-12 11:10:15 +08:00
committed by GitHub
parent eea55cca5b
commit 7f5edb5900
6 changed files with 174 additions and 80 deletions

View File

@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
class DummyLoRAManager:
def __init__(self):
def __init__(self, device: torch.device = "cuda:0"):
super().__init__()
self._loras: Dict[str, LoRALayerWeights] = {}
self._device = device
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
@ -28,16 +29,16 @@ class DummyLoRAManager:
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
dtype=weight.dtype,
device="cuda"),
device=self._device),
lora_b=torch.rand([rank, weight.shape[0]],
dtype=weight.dtype,
device="cuda"),
device=self._device),
)
if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(5,
generate_embeddings_tensor,
dtype=weight.dtype,
device="cuda")
device=self._device)
self.set_module_lora(module_name, lora)
return lora