[Hardware][Neuron] Refactor neuron support (#3471)

This commit is contained in:
Zhuohan Li
2024-03-21 18:22:17 -07:00
committed by GitHub
parent ea5f14e6ff
commit e90fc21f2e
33 changed files with 615 additions and 549 deletions

View File

@@ -33,7 +33,7 @@ def test_worker_apply_lora(sql_lora_files):
max_loras=32),
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_model()
worker.init_device()
worker.load_model()
worker.model_runner.set_active_loras([], LoRAMapping([], []))

View File

@@ -71,7 +71,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
vocab_size = 32_000
@@ -151,7 +151,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -230,7 +230,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -342,7 +342,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -486,8 +486,8 @@ def test_empty_input_batch(k: int, batch_size: int):
@torch.inference_mode()
def test_init_model():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
@@ -499,11 +499,11 @@ def test_init_model():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
draft_worker.init_model.assert_called_once()
draft_worker.init_device.assert_called_once()
target_worker.init_model.assert_called_once()
target_worker.init_device.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once()

View File

@@ -123,7 +123,7 @@ def create_worker(cls: type,
is_driver_worker=is_driver_worker,
)
worker.init_model()
worker.init_device()
worker.load_model()
cache_config.num_gpu_blocks = num_gpu_blocks

View File

@@ -30,7 +30,7 @@ def test_swap() -> None:
)
# Initialize the worker.
worker.init_model()
worker.init_device()
worker.load_model()
worker.init_cache_engine(cache_config)
worker.warm_up_model()