[Core] Support dynamically loading Lora adapter from HuggingFace (#6234)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
|
||||
lora_request=LoRARequest(
|
||||
lora_name=str(i),
|
||||
lora_int_id=i + 1,
|
||||
lora_local_path="abc"))
|
||||
lora_path="abc"))
|
||||
waiting.append(seq_group)
|
||||
# Add two more requests to verify lora is prioritized.
|
||||
# 0: Lora, 1: Lora, 2: regular, 3: regular
|
||||
@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
|
||||
lora_request=LoRARequest(
|
||||
lora_name=str(i),
|
||||
lora_int_id=i + 1,
|
||||
lora_local_path="abc"))
|
||||
lora_path="abc"))
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
|
||||
@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sql_lora_files():
|
||||
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
def sql_lora_huggingface_id():
|
||||
# huggingface repo id is used to test lora runtime downloading.
|
||||
return "yard1/llama-2-7b-sql-lora-test"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sql_lora_files(sql_lora_huggingface_id):
|
||||
return snapshot_download(repo_id=sql_lora_huggingface_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
||||
@ -29,7 +29,7 @@ def _create_lora_request(lora_id, long_context_infos):
|
||||
context_len = long_context_infos[lora_id]["context_length"]
|
||||
scaling_factor = context_len_to_scaling_factor[context_len]
|
||||
return LoRARequest(context_len, lora_id,
|
||||
long_context_infos[lora_id]["lora"],
|
||||
long_context_infos[lora_id]["lora"], None,
|
||||
4096 * scaling_factor)
|
||||
|
||||
|
||||
|
||||
39
tests/lora/test_lora_huggingface.py
Normal file
39
tests/lora/test_lora_huggingface.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.lora.models import LoRAModel
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
|
||||
# Provide absolute path and huggingface lora ids
|
||||
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    """Check that a LoRA checkpoint loads from each parametrized reference.

    ``lora_fixture_name`` is the name of a pytest fixture; its value is
    looked up through ``request``, resolved to a path with
    ``get_adapter_absolute_path`` and then loaded on CPU.
    """
    adapter_ref = request.getfixturevalue(lora_fixture_name)

    model_cls = LlamaForCausalLM
    lora_targets = model_cls.supported_lora_modules
    packed = model_cls.packed_modules_mapping
    embedding_modules = model_cls.embedding_modules
    padding_modules = model_cls.embedding_padding_modules

    # A module listed in the packed mapping contributes the entries it maps
    # to; any other supported module is kept under its own name.
    expected_lora_modules: List[str] = []
    for target in lora_targets:
        expected_lora_modules.extend(
            packed[target] if target in packed else [target])

    # LoRA loading should work for either an absolute path or a
    # Hugging Face id.
    lora_model = LoRAModel.from_local_checkpoint(
        get_adapter_absolute_path(adapter_ref),
        expected_lora_modules,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=padding_modules)

    # Assertion to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"
|
||||
@ -1,9 +1,12 @@
|
||||
from collections import OrderedDict
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
from torch import nn
|
||||
|
||||
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
||||
from vllm.lora.utils import (get_adapter_absolute_path,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.utils import LRUCache
|
||||
|
||||
|
||||
@ -182,3 +185,55 @@ def test_lru_cache():
|
||||
assert 2 in cache
|
||||
assert 4 in cache
|
||||
assert 6 in cache
|
||||
|
||||
|
||||
# Unit tests for get_adapter_absolute_path
|
||||
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
    """A path that ``os.path.isabs`` reports as absolute comes back as-is."""
    mock_isabs.return_value = True
    lora_dir = '/absolute/path/to/lora'

    resolved = get_adapter_absolute_path(lora_dir)

    assert resolved == lora_dir
|
||||
|
||||
|
||||
@patch('os.path.expanduser')
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
    """A ``~``-prefixed path resolves to what ``os.path.expanduser`` gives."""
    expanded = '/home/user/relative/path/to/lora'
    mock_expanduser.return_value = expanded

    # Path with ~ that needs to be expanded
    resolved = get_adapter_absolute_path('~/relative/path/to/lora')

    assert resolved == expanded
|
||||
|
||||
|
||||
@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
    """An existing relative path resolves to its ``os.path.abspath`` value."""
    # Decorators are applied bottom-up, so the abspath mock is the first
    # argument and the exists mock the second.
    resolved_dir = '/absolute/path/to/lora'
    mock_abspath.return_value = resolved_dir
    mock_exist.return_value = True

    # Relative path that exists locally
    result = get_adapter_absolute_path('relative/path/to/lora')

    assert result == resolved_dir
|
||||
|
||||
|
||||
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface(mock_exist,
                                               mock_snapshot_download):
    """A non-local identifier resolves to the mocked snapshot directory."""
    snapshot_dir = '/mock/snapshot/path'
    mock_snapshot_download.return_value = snapshot_dir
    # Report the identifier as absent from the local filesystem.
    mock_exist.return_value = False

    # Hugging Face model identifier
    result = get_adapter_absolute_path('org/repo')

    assert result == snapshot_dir
|
||||
|
||||
|
||||
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
                                                     mock_snapshot_download):
    """A failed download makes the original identifier come back unchanged.

    The mocked ``snapshot_download`` raises ``HfHubHTTPError`` and the path
    is reported as not existing locally.
    """
    # Hugging Face model identifier with download error
    path = 'org/repo'
    mock_exist.return_value = False
    mock_snapshot_download.side_effect = HfHubHTTPError(
        "failed to query model info")

    assert get_adapter_absolute_path(path) == path
    # Returning the input unchanged would also satisfy the assertion above
    # if no download were attempted, so confirm the failing call was made.
    assert mock_snapshot_download.called
|
||||
|
||||
Reference in New Issue
Block a user