diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 881d5efa69..909b739331 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -221,11 +221,6 @@ def phi2_lora_files():
     return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
 
 
-@pytest.fixture(scope="session")
-def long_context_lora_files_16k_1():
-    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
-
-
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
     cleanup_dist_env_and_memory(shutdown_ray=True)
diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py
index f16589e06b..df8696cf58 100644
--- a/tests/lora/test_peft_helper.py
+++ b/tests/lora/test_peft_helper.py
@@ -38,8 +38,8 @@ ERROR_CASES = [
 ]
 
 
-def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
-    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+def test_peft_helper_pass(sql_lora_files, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                             max_position_embeddings=4096)
     lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
     peft_helper.validate_legal(lora_config)
@@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
         "embed_tokens",
         "lm_head",
     ]
-    assert peft_helper.context_length == 16384
     assert peft_helper.vllm_max_position_embeddings == 4096
-    assert peft_helper.vllm_long_context_scaling_factor == float(
-        math.ceil(peft_helper.context_length /
-                  peft_helper.vllm_max_position_embeddings))
+
     # test RSLoRA
     rslora_config = dict(use_rslora=True)
     test_dir = tmp_path / "test_rslora"
-    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+    shutil.copytree(sql_lora_files, test_dir)
 
     # Load and modify configuration
     config_path = test_dir / "adapter_config.json"
diff --git a/vllm/config.py b/vllm/config.py
index 8383a663c7..384cb584fa 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3014,12 +3014,7 @@ class LoRAConfig:
     (added to the base model vocabulary)."""
     lora_vocab_padding_size: ClassVar[int] = current_platform\
         .get_lora_vocab_padding_size()
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
-    """Specify multiple scaling factors (which can be different from base model
-    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
-    trained with those scaling factors to be used at the same time. If not
-    specified, only adapters trained with the base model scaling factor are
-    allowed."""
+
     default_mm_loras: Optional[dict[str, str]] = None
     """Dictionary mapping specific modalities to LoRA model paths; this field
     is only applicable to multimodal models and should be leveraged when a
@@ -3052,7 +3047,6 @@ class LoRAConfig:
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
         factors.append(self.lora_vocab_padding_size)
-        factors.append(self.long_lora_scaling_factors)
         factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()
@@ -3091,11 +3085,6 @@ class LoRAConfig:
         elif isinstance(self.lora_dtype, str):
             self.lora_dtype = getattr(torch, self.lora_dtype)
 
-    def verify_lora_support(self):
-        if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1:
-            raise ValueError(
-                "V1 LoRA does not support long LoRA, please use V0.")
-
 
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4564,7 +4553,6 @@ class VllmConfig:
         if self.lora_config is not None:
             self.lora_config.verify_with_cache_config(self.cache_config)
             self.lora_config.verify_with_model_config(self.model_config)
-            self.lora_config.verify_lora_support()
         if self.prompt_adapter_config is not None:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a7fcf6c354..d352a22a6d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -358,8 +358,6 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
     lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
-        LoRAConfig.long_lora_scaling_factors
     # PromptAdapter fields
     enable_prompt_adapter: bool = False
     max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
@@ -723,8 +721,6 @@ class EngineArgs:
             "--lora-dtype",
             **lora_kwargs["lora_dtype"],
         )
-        lora_group.add_argument("--long-lora-scaling-factors",
-                                **lora_kwargs["long_lora_scaling_factors"])
         lora_group.add_argument("--max-cpu-loras",
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
@@ -1245,7 +1241,6 @@ class EngineArgs:
             default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
-            long_lora_scaling_factors=self.long_lora_scaling_factors,
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 779f026468..c3512ec3db 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -28,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.rotary_embedding import (
-    LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.platforms import current_platform
@@ -1193,91 +1191,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     ) -> bool:
         # Special handling for the LogitsProcessor.
         return False
-
-
-class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
-    """Implements RoPE-scaled embeddings with linear scaling for
-    multiple LoRA adapters with a specialized kernel.
-
-    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
-    which can handle multi lora adapters in a specialized kernel.
-    """
-
-    def __init__(self, base_layer: RotaryEmbedding) -> None:
-        super().__init__()
-        self.base_layer = base_layer
-
-    @property
-    def scaling_factors(self):
-        return self.base_layer.scaling_factors
-
-    @property
-    def rotary_dim(self):
-        return self.base_layer.rotary_dim
-
-    def create_lora_weights(
-        self,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None,
-    ) -> None:
-        scaling_factors = (list(lora_config.long_lora_scaling_factors)
-                           if lora_config.long_lora_scaling_factors else [])
-        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
-            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
-        scaling_factors = sorted(
-            list(set([base_scaling_factor] + scaling_factors)))
-        self.base_layer = LinearScalingRotaryEmbedding(
-            self.base_layer.head_size,
-            self.base_layer.rotary_dim,
-            self.base_layer.max_position_embeddings,
-            self.base_layer.base,
-            self.base_layer.is_neox_style,
-            scaling_factors,
-            self.base_layer.dtype,
-        )
-
-    def reset_lora(self, index: int):
-        ...
-
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        ...
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.base_layer(
-            positions,
-            query,
-            key,
-            offsets=self.punica_wrapper.long_lora_indices,
-        )
-
-    @property
-    def scaling_factor_to_offset(self) -> dict[float, int]:
-        return self.base_layer.scaling_factor_to_offset
-
-    @classmethod
-    def can_replace_layer(
-        cls,
-        source_layer: nn.Module,
-        lora_config: LoRAConfig,
-        packed_modules_list: list,
-        model_config: Optional[PretrainedConfig],
-    ) -> bool:
-        """Returns True if the layer can be replaced by this LoRA layer."""
-        return (type(source_layer) is LinearScalingRotaryEmbedding
-                or type(source_layer) is RotaryEmbedding)
-
-    def extra_repr(self) -> str:
-        return self.base_layer.extra_repr()
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 633674d5fb..e6b19d4748 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -4,7 +4,6 @@
 import math
 import os
 from collections.abc import Sequence
-from dataclasses import dataclass, field
 from typing import Any, Callable, Optional, Union
 
 import regex as re
@@ -19,9 +18,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                         remove_adapter, set_adapter_mapping)
 from vllm.config import LoRAConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
-                              LoRAMapping)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -43,18 +40,6 @@ logger = init_logger(__name__)
 _GLOBAL_LORA_ID = 0
 
 
-@dataclass
-class LongContextLoRAContext:
-    """Context for lora adapters that support long context."""
-    # The scaling factors to support long context lora fine tuned models.
-    scaling_factors: list[float]
-    # dimension to apply rotary embedding.
-    rot_dim: int
-    # offsets to the sin_cos_cache for each lora_id loaded.
-    # This value is dynamically modified.
-    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
-
-
 def get_lora_id():
     global _GLOBAL_LORA_ID
     _GLOBAL_LORA_ID += 1
@@ -80,20 +65,16 @@ class LoRAModel(AdapterModel):
         lora_model_id: int,
         rank: int,
         loras: dict[str, LoRALayerWeights],
-        scaling_factor: Optional[float] = None,
     ) -> None:
         """
         Args:
             lora_model_id: The integer id for the lora model.
             rank: lora rank.
             loras: module name -> weights for lora-replaced layers.
-            scaling_factor: Scaling factor to support long context lora model.
-                None if the lora is not tuned for long context support.
+
         """
         self.id = lora_model_id
-        # Scaling factor for long context lora model. None if it is not
-        # fine tuned for the long context.
-        self.scaling_factor = scaling_factor
+
         assert (
             lora_model_id >
            0), f"a valid lora id should be greater than 0, got {self.id}"
@@ -192,10 +173,7 @@ class LoRAModel(AdapterModel):
         for lora in loras.values():
             lora.optimize()
 
-        return cls(lora_model_id,
-                   peft_helper.r,
-                   loras,
-                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
+        return cls(lora_model_id, peft_helper.r, loras)
 
     @classmethod
     def from_local_checkpoint(
@@ -360,24 +338,17 @@ class LoRAModelManager(AdapterModelManager):
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
-        self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras)
-        # Scaling factor -> offset to the sin_cos_cache to it.
-        # Used for long context lora.
-        self.scaling_factor_to_offset: dict[float, int] = {}
+
         super().__init__(model)
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f" {self.model.__class__.__name__}."
-        if lora_config.long_lora_scaling_factors:
-            # We need to replace rotary emb layer to do batch computation
-            # for long lora.
-            self.supported_lora_modules.append("rotary_emb")
         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
         # Used to indicate whether the model is a multimodal model
@@ -454,25 +425,9 @@ class LoRAModelManager(AdapterModelManager):
             except ValueError:
                 pass
 
-    def _set_long_lora_context(self, lora: LoRAModel):
-        if self.long_lora_context is None:
-            return
-
-        if lora.scaling_factor is None:
-            return
-
-        if (lora.scaling_factor not in self.scaling_factor_to_offset):
-            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
-                             " has not been initialized.")
-
-        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
-        if offsets:
-            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
-
     def _add_adapter(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_adapters[lora.id] = lora
-        self._set_long_lora_context(lora)
 
     def pin_adapter(self, lora_id: int) -> bool:
         """Pin a LoRAModel in the manager cache."""
@@ -488,7 +443,6 @@ class LoRAModelManager(AdapterModelManager):
             self.lora_slots + 1,
             self.vocab_size,
             self.lora_config.lora_extra_vocab_size,
-            self.long_lora_context,
         )
 
     def remove_all_adapters(self):
@@ -528,13 +482,6 @@ class LoRAModelManager(AdapterModelManager):
                 from_layer(module, self.lora_slots, self.lora_config,
                            packed_moduled_lst, self.model.config))
 
-            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
-            # long context lora. Register relevant metadata.
-            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
-                self.long_lora_context = LongContextLoRAContext(
-                    new_module.scaling_factors, new_module.rotary_dim)
-                self.scaling_factor_to_offset = \
-                    new_module.scaling_factor_to_offset
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
                 logits_processor_module_name = 'logits_processor'
@@ -574,15 +521,13 @@ class LoRAModelManager(AdapterModelManager):
             self,
             lora_id: int,
             rank: int,
-            scaling_factor: Optional[float],
             embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
         """Create zero-initialized LoRAModel for warmup."""
-        model = LoRAModel(lora_id, rank, {}, scaling_factor)
+        model = LoRAModel(lora_id, rank, {})
         for module_name, module in self.model.named_modules():
             bias_enabled = self.lora_config.bias_enabled
             if (not self._match_target_modules(module_name)
                     or not isinstance(module, BaseLayerWithLoRA)
-                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                     or self._filter_unsupported_mm_module(module_name)):
                 continue
             parts = module_name.split(".")
@@ -723,11 +668,8 @@ class LoRAModelManager(AdapterModelManager):
                                   self._deactivate_adapter)
 
     def add_adapter(self, adapter: LoRAModel) -> bool:
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", adapter.id, adapter.id,
-            adapter.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", adapter.id, adapter.id)
         return add_adapter(adapter, self._registered_adapters, self.capacity,
                            self._add_adapter)
 
@@ -772,10 +714,8 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
 
     def add_adapter(self, lora: LoRAModel) -> bool:
         """Add a LoRAModel to the manager."""
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", lora.id, lora.id)
         if lora.id not in self._registered_adapters:
             self._add_adapter(lora)
             was_added = True
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 24099bf479..8b8e5cb7d5 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -35,12 +35,9 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long context lora field
-    context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
     vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self) -> list[str]:
         """
@@ -59,12 +56,6 @@ class PEFTHelper:
             self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
         else:
             self.vllm_lora_scaling_factor = self.lora_alpha / self.r
-        if self.context_length:
-            if self.vllm_max_position_embeddings is None:
-                self.vllm_max_position_embeddings = self.context_length
-            self.vllm_long_context_scaling_factor = float(
-                math.ceil(self.context_length /
-                          self.vllm_max_position_embeddings))
 
     @classmethod
     def from_dict(cls, config_dict: dict) -> "PEFTHelper":
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index 5b4902dcbe..b3413de1c8 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -17,7 +17,6 @@ from .utils import compute_meta, convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 
 class PunicaWrapperABC(ABC):
@@ -33,7 +32,6 @@ class PunicaWrapperABC(ABC):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
         **kwargs,
     ) -> None:
         """
@@ -144,14 +142,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                                max_num_batched_tokens,
                                                dtype=torch.long,
                                                device=device)
-        self._long_lora_indices = torch.empty(max_num_batched_tokens,
-                                              dtype=torch.long,
-                                              device=device)
-        # 5 is the number of indices tensors.
+        # 4 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
-        # embeddings_indices,long_lora_indices
-        self.indices_len: list[Optional[int]] = [None] * 5
+        # embeddings_indices
+        self.indices_len: list[Optional[int]] = [None] * 4
 
         # these attributes are the information required for sgmv kernel
         self._seq_start_locs = torch.empty(max_batches,
                                            dtype=torch.long,
@@ -176,14 +171,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         (
             base_indices,
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -192,7 +185,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
             vocab_size,
             extra_vocab_size,
             self.device,
-            long_lora_context,
         )
         self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
         self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
@@ -201,11 +193,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self._embeddings_indices[:embeddings_indices.
                                  shape[0], :embeddings_indices.shape[1]].copy_(
                                      embeddings_indices)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
-                long_lora_offsets_tensor)
-        else:
-            self._long_lora_indices.zero_()
+
         self.indices_len[:] = indices_len
 
     def _update_prefill_metadata(self,
@@ -312,28 +300,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
 
-    @property
-    def long_lora_indices(self) -> torch.Tensor:
-        """
-        This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
-        """
-        long_lora_len = self.indices_len[4]
-        return self._long_lora_indices[:long_lora_len]
-
-    def update_metadata(
-            self,
-            mapping: "LoRAMapping",
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: "LoRAMapping",
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)
+
         if mapping.is_prefill:
             # Update metadata required for prefill-related operators.
             self._update_prefill_metadata(self.token_lora_indices)
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index 6b038309d5..2db0e9fee1 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -7,7 +7,7 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import TYPE_CHECKING, Optional, Union, final
+from typing import Optional, Union, final
 
 import torch
 
@@ -21,10 +21,6 @@ if HAS_TRITON:
 
 from .punica_base import PunicaWrapperBase
 
-if TYPE_CHECKING:
-    # avoid circuit import
-    from vllm.lora.models import LongContextLoRAContext
-
 
 @final
 class PunicaWrapperGPU(PunicaWrapperBase):
@@ -55,20 +51,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                                                        max_num_prompts,
                                                        device=device)
 
-    def update_metadata(
-            self,
-            mapping: LoRAMapping,
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: LoRAMapping,
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
         self.is_prefill = mapping.is_prefill
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)
 
         # Prepare cuda kernel metadata tensors
         self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py
index 6b48268c50..07dc337a1c 100644
--- a/vllm/lora/punica_wrapper/punica_tpu.py
+++ b/vllm/lora/punica_wrapper/punica_tpu.py
@@ -14,7 +14,6 @@ from vllm.lora.punica_wrapper.utils import convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 from .punica_base import PunicaWrapperBase
 
@@ -45,7 +44,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
                                                True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
                                                True)
 
@@ -323,7 +321,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         # Make sure we don't accidentally collect outside operations
         xm.mark_step()
@@ -339,7 +336,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -348,7 +344,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             vocab_size,
             extra_vocab_size,
             "cpu",
-            long_lora_context,
         )
         self._token_lora_indices = self._pad_to_shape(
             base_indices, self._token_lora_indices.shape,
@@ -362,15 +357,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         self._embeddings_indices = self._pad_to_shape(
             embeddings_indices, self._embeddings_indices.shape,
             dims=2).to(self.device)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices = self._pad_to_shape(
-                long_lora_offsets_tensor,
-                self._long_lora_indices.shape,
-                dims=1).to(self.device)
-        else:
-            zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
-                                      dtype=torch.int32)
-            self._long_lora_indices = zeroed.to(self.device)
         self.indices_len[:] = indices_len
 
     def _update_prefill_metadata(self,
diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
index 8430cb9186..d22c29da1c 100644
--- a/vllm/lora/punica_wrapper/utils.py
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -8,7 +8,6 @@ import torch
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 
 def compute_meta(
@@ -49,9 +48,7 @@ def convert_mapping(
     vocab_size: int,
     extra_vocab_size: int,
     device: torch.device,
-    long_lora_context: Optional["LongContextLoRAContext"] = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor], list[int]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
     """Converts LoRAMapping to index tensors.
 
     Args:
@@ -60,7 +57,6 @@ def convert_mapping(
         max_loras: Maximum number of LoRAs.
         vocab_size: Model vocab size.
         extra_vocab_size: Extra vocab size each LoRA can have.
-        long_lora_context: Passed if there are long context lora in a batch.
 
     Returns:
         A tuple of tensors:
@@ -78,21 +74,14 @@ def convert_mapping(
             requests to embedding indices. First row is for embeddings
             added by the LoRAs, second row is for the LoRA.lora_a
             embeddings.
-        long_lora_indices: Tensor of shape [batch_size] mapping
-            requests to RoPE offsets and rot dims for long LoRAs.
-            None if long context lora doesn't exist.
         indices_len: List of lengths of the above tensors. It contains
             (base_indices, sampler_indices, sampler_indices_padded,
-            embeddings_indices, long_lora_indices).
+            embeddings_indices).
     """
     index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
-    long_lora_offsets: Optional[torch.Tensor] = None
-    if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+
     prompt_mapping: list[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,20 +93,13 @@ def convert_mapping(
                     if index_mapping_indices[i] > 0 else -1)
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
-        if long_lora_context:
-            assert long_lora_offsets is not None
-            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
-                index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
 
     indices_list: list[Union[list[int], torch.Tensor]] = [
         index_mapping_indices,
         lora_indices,
         embedding_indices,
     ]
-    if long_lora_context:
-        assert long_lora_offsets is not None
-        indices_list.append(long_lora_offsets)
+
     indices = torch.tensor(indices_list, dtype=torch.long, device=device)
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                          dtype=torch.long,
@@ -136,11 +118,7 @@ def convert_mapping(
     sampler_indices_padded = torch.arange(
         0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
             sampler_indices_padded * len(sampler_indices_padded))
-    long_lora_indices = None
-    long_lora_indices_len: Optional[int] = None
-    if long_lora_context:
-        long_lora_indices = indices[3]
-        long_lora_indices_len = long_lora_indices.shape[-1]
+
     # Contain length of indices tensors. Used to index into each tensor.
     indices_len = [
         base_indices.shape[-1],
@@ -148,17 +126,11 @@ def convert_mapping(
         sampler_indices_padded.shape[-1],
         embeddings_indices.shape[-1],
     ]
-    if long_lora_indices_len is not None:
-        indices_len.append(long_lora_indices_len)
-    else:
-        # If long_lora doesn't exist,append None
-        indices_len.append(None)
 
     return (
         base_indices,
         sampler_indices,
         sampler_indices_padded,
         embeddings_indices,
-        long_lora_indices,
         indices_len,
     )
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 7148ffe149..ab0a9fbd25 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -22,7 +22,6 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
                               LogitsProcessorWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
@@ -56,7 +55,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLoRA,
     RowParallelLinearWithShardedLoRA,
-    LinearScalingRotaryEmbeddingWithLoRA,
 }
 
 
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 7a4af74cbe..248d2954f1 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -154,7 +154,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
                 lora_request.lora_int_id)
         else:
             dummy_lora = self._adapter_manager.create_dummy_lora(
-                lora_request.lora_int_id, rank, 1, self.embedding_modules)
+                lora_request.lora_int_id, rank, self.embedding_modules)
         if self._cached_dummy_lora is None:
             self._cached_dummy_lora = dummy_lora
         return self._adapter_manager.add_adapter(dummy_lora)