Update deprecated type hinting in vllm/lora (#18128)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-05-14 11:57:59 +01:00
Committed by: GitHub
Parent: 9ccc6ded42
Commit: 9b5b39b650
19 changed files with 245 additions and 251 deletions
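
The change itself is mechanical: the PEP 585 built-in generics (list, dict, tuple, set, type) replace their deprecated typing aliases, which ruff reports as UP006 (non-pep585-annotation) and UP035 (deprecated-import). A minimal before/after sketch with a hypothetical helper, not code taken from this diff:

# Before: deprecated typing aliases, flagged by ruff's UP006/UP035 rules
from typing import Dict, List, Tuple

def group_loras(names: List[str]) -> Dict[str, Tuple[int, ...]]:
    return {name: (len(name),) for name in names}

# After: PEP 585 built-in generics, valid in annotations on Python 3.9+
def group_loras(names: list[str]) -> dict[str, tuple[int, ...]]:
    return {name: (len(name),) for name in names}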

View File

@ -78,7 +78,6 @@ exclude = [
"vllm/distributed/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/lora/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast
import torch
import torch.nn as nn
@ -118,7 +118,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
@ -141,8 +141,8 @@ class MergedColumnParallelLinearWithShardedLoRA(
"""
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
self, lora_a: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
#NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
@ -165,7 +165,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
@ -201,7 +201,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@ -222,8 +222,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
"""
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
self, lora_a: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
@ -248,7 +248,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
@ -281,7 +281,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
shard_size = self.lora_bias_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
@ -341,7 +341,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
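
The cast(Tuple[...], ...) to cast(tuple[...], ...) edits above are purely a spelling change for the type checker: typing.cast returns its second argument unchanged at runtime, and since Python 3.9 the built-in tuple is accepted anywhere typing.Tuple was. A minimal sketch with hypothetical names, not the vLLM classes:

from typing import Optional, cast

def first_bias_dim(lora_bias_stacked: Optional[tuple[list[float], ...]]) -> int:
    # After the None check, cast() narrows the Optional away for the type
    # checker at zero runtime cost, mirroring the pattern in the hunks above.
    assert lora_bias_stacked is not None
    stacked = cast(tuple[list[float], ...], lora_bias_stacked)
    return len(stacked[0])

assert first_bias_dim(([0.0, 0.1],)) == 2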

View File

@ -3,7 +3,7 @@
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast
import torch
import torch.nn as nn
@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping):
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(
self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(
self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism."""
...
@ -128,7 +128,7 @@ class BaseLayerWithLoRA(nn.Module):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_slice: Optional[tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights(
@ -279,7 +279,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is VocabParallelEmbedding
@ -296,9 +296,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: Tuple[int, ...]
self.output_slices: tuple[int, ...]
self.tp_size: int
self.output_size: int
self.n_slices: int
@ -365,7 +365,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled:
# Make mypy happy
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
self.lora_bias_stacked[s_index][index] = 0
@ -399,7 +399,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_b.T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
@ -497,7 +497,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ReplicatedLinear
@ -597,7 +597,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ColumnParallelLinear or (
@ -674,13 +674,13 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) for output_size in self.output_slices)
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
self, lora_a: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
self, lora_b: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None:
@ -689,8 +689,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return lora_b
def slice_bias(
self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]:
self, bias: list[Union[torch.Tensor,
None]]) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (bias_i := bias[i]) is not None:
@ -725,7 +725,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_b_i.T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
for i in range(self.n_slices):
if (lora_bias_i := lora_bias[i]) is not None:
@ -740,7 +740,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is MergedColumnParallelLinear
@ -809,7 +809,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 1
@ -869,7 +869,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is QKVParallelLinear
@ -923,7 +923,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
- output
- bias
"""
# Set up backprop all-reduce.
# set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
@ -958,7 +958,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is RowParallelLinear
@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
dtype: torch.dtype, device: torch.device,
sharded_to_full_mapping: Optional[List[int]]) -> None:
sharded_to_full_mapping: Optional[list[int]]) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
@ -1189,7 +1189,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
# Special handling for the LogitsProcessor.
@ -1256,7 +1256,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
return self.base_layer(
positions,
query,
@ -1265,7 +1265,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
)
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
def scaling_factor_to_offset(self) -> dict[float, int]:
return self.base_layer.scaling_factor_to_offset
@classmethod
@ -1273,7 +1273,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import Sequence as GenericSequence
from collections.abc import Sequence as GenericSequence
from typing import Optional
import torch
import torch.types
@ -125,11 +125,11 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self,
module_name: str,
rank: int,
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
bias: Optional[List[Optional[torch.Tensor]]] = None,
scaling: Optional[List[float]] = None,
lora_alphas: list[Optional[int]],
lora_a: list[Optional[torch.Tensor]],
lora_b: list[Optional[torch.Tensor]],
bias: Optional[list[Optional[torch.Tensor]]] = None,
scaling: Optional[list[float]] = None,
) -> None:
super().__init__(
module_name=module_name,
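
The Sequence alias now comes from collections.abc instead of typing: the typing copies of the container ABCs are deprecated aliases (UP035), and the collections.abc originals have been subscriptable since Python 3.9. A small sketch, not vLLM code:

from collections.abc import Sequence as GenericSequence

def total_rank(ranks: GenericSequence[int]) -> int:
    # The ABC still works for isinstance() checks (unsubscripted), and the
    # subscripted form is valid directly in annotations on 3.9+.
    assert isinstance(ranks, GenericSequence)
    return sum(ranks)

assert total_rank([8, 16]) == 24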

View File

@ -4,9 +4,9 @@ import copy
import math
import os
import re
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
Union)
from typing import Any, Callable, Optional, Union
import safetensors.torch
import torch
@ -44,12 +44,12 @@ _GLOBAL_LORA_ID = 0
class LongContextLoRAContext:
"""Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models.
scaling_factors: List[float]
scaling_factors: list[float]
# dimension to apply rotary embedding.
rot_dim: int
# offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified.
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
def get_lora_id():
@ -65,7 +65,7 @@ class LoRAModel(AdapterModel):
self,
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
loras: dict[str, LoRALayerWeights],
scaling_factor: Optional[float] = None,
) -> None:
"""
@ -84,7 +84,7 @@ class LoRAModel(AdapterModel):
lora_model_id
> 0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
self.loras: dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
"""Return a copy of the object with different ids.
@ -113,19 +113,19 @@ class LoRAModel(AdapterModel):
def from_lora_tensors(
cls,
lora_model_id: int,
tensors: Dict[str, torch.Tensor],
tensors: dict[str, torch.Tensor],
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
embeddings: Optional[dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
embedding_modules: Optional[dict[str, str]] = None,
embedding_padding_modules: Optional[list[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name, weights_mapper)
@ -187,15 +187,15 @@ class LoRAModel(AdapterModel):
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
expected_lora_modules: list[str],
peft_helper: PEFTHelper,
*,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
embedding_modules: Optional[dict[str, str]] = None,
embedding_padding_modules: Optional[list[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
@ -220,9 +220,9 @@ class LoRAModel(AdapterModel):
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
unexpected_modules: List[Union[list[str], str]]
unexpected_modules: list[Union[list[str], str]]
if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {}
tensors: dict[str, torch.Tensor] = {}
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
@ -329,7 +329,7 @@ class LoRAModelManager(AdapterModelManager):
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = get_punica_wrapper(
@ -339,7 +339,7 @@ class LoRAModelManager(AdapterModelManager):
max_loras=self.lora_config.max_loras)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
self.scaling_factor_to_offset: dict[float, int] = {}
super().__init__(model)
self.supported_lora_modules = get_supported_lora_modules(self.model)
@ -358,9 +358,9 @@ class LoRAModelManager(AdapterModelManager):
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping"))
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a Set for compatibility with LRUCache.
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
@ -530,7 +530,7 @@ class LoRAModelManager(AdapterModelManager):
lora_id: int,
rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
@ -578,7 +578,7 @@ class LoRAModelManager(AdapterModelManager):
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras: List[Optional[LoRALayerWeights]] = []
subloras: list[Optional[LoRALayerWeights]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
@ -630,8 +630,8 @@ class LoRAModelManager(AdapterModelManager):
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras: List[Optional[LoRALayerWeights]] = []
replaced_module: Set[str] = set()
replacement_loras: list[Optional[LoRALayerWeights]] = []
replaced_module: set[str] = set()
has_replacement = False
for r in new_module_names:
lora = self._get_lora_layer_weights(lora_model, r)
@ -694,7 +694,7 @@ class LoRAModelManager(AdapterModelManager):
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
def list_adapters(self) -> dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
@ -721,7 +721,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter)
def list_adapters(self) -> Dict[int, LoRAModel]:
def list_adapters(self) -> dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_adapters.cache)
@ -786,7 +786,7 @@ def create_lora_manager(
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not hasattr(model, "packed_modules_mapping"):
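
Type[LoRAModelManager] becoming type[LoRAModelManager] works because the built-in type has been subscriptable since Python 3.9 and means the same thing: a class object (the class itself or a subclass), not an instance. A sketch that reuses the names from the hunk above as stand-in classes:

class LoRAModelManager:            # stand-in, not the real vLLM class
    pass

class LRUCacheLoRAModelManager(LoRAModelManager):   # stand-in subclass
    pass

def create_lora_manager(
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
) -> LoRAModelManager:
    # The parameter accepts the class itself or any subclass; the default
    # stays a class object, exactly as with the old typing.Type[...].
    return lora_manager_cls()

manager = create_lora_manager(LRUCacheLoRAModelManager)
assert isinstance(manager, LoRAModelManager)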

View File

@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
@ -127,7 +125,7 @@ def _lora_expand_kernel(
@torch.inference_mode()
def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: List[
lora_b_weights: list[
torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch.
Tensor, # shape [num_tokens, hidden_size * num_slices]
@ -143,7 +141,7 @@ def _lora_expand(
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (List[torch.Tensor]): lora'b weight
lora_b_weights (list[torch.Tensor]): lora'b weight
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
@ -264,7 +262,7 @@ def _lora_expand(
def _lora_expand_fake(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,

View File

@ -4,7 +4,7 @@ LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
from typing import Tuple, Union
from typing import Union
import torch
@ -125,7 +125,7 @@ class LoRAKernelMeta:
def meta_args(
self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""
This function returns the kernel metadata required for the current

View File

@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
@ -98,7 +96,7 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
@torch.inference_mode()
def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: List[
lora_a_weights: list[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
@ -112,7 +110,7 @@ def _lora_shrink(
"""
Args:
inputs (torch.Tensor): Input tensor
lora_a_weights (List[torch.Tensor]): LoRA weights
lora_a_weights (list[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
@ -219,7 +217,7 @@ def _lora_shrink(
def _lora_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,

View File

@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Tuple
import torch
_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
"""
`_LORA_A_PTR_DICT` collects the required information during `profile_run`,
After this, it remains constant and subsequent usage is through LUT.
@ -53,7 +51,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
return _LORA_A_PTR_DICT.get(key)
def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
device: torch.device):
"""
`_LORA_B_PTR_DICT` collects the required information during `profile_run`,

View File

@ -6,7 +6,7 @@ import json
import math
import os
from dataclasses import MISSING, dataclass, field, fields
from typing import List, Literal, Optional, Union
from typing import Literal, Optional, Union
from vllm.config import LoRAConfig
from vllm.logger import init_logger
@ -40,7 +40,7 @@ class PEFTHelper:
vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_long_context_scaling_factor: Optional[float] = field(default=None)
def _validate_features(self) -> List[str]:
def _validate_features(self) -> list[str]:
"""
Check if there are any unsupported LoRA features.
"""

View File

@ -7,7 +7,7 @@ https://arxiv.org/abs/2310.18547
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
@ -28,7 +28,7 @@ class PunicaWrapperABC(ABC):
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
@ -43,9 +43,9 @@ class PunicaWrapperABC(ABC):
@abstractmethod
def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> Optional[torch.Tensor]:
@ -59,10 +59,10 @@ class PunicaWrapperABC(ABC):
def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
@ -91,13 +91,13 @@ class PunicaWrapperABC(ABC):
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
output_slices: tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> Optional[torch.Tensor]:
"""
Applicable to linear-related lora.
@ -150,7 +150,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
# 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self.indices_len: List[Optional[int]] = [None] * 5
self.indices_len: list[Optional[int]] = [None] * 5
# these attributes are the information required for sgmv kernel
self._seq_start_locs = torch.empty(max_batches,
dtype=torch.long,
@ -171,7 +171,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
@ -228,8 +228,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
output_slices: tuple[int, ...],
lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
@ -259,7 +259,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
@property
def prefill_metadata(
self
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
@ -323,7 +323,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
@ -341,8 +341,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self.is_prefill = False
@abstractmethod
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM for multiple slices of lora_a.
@ -352,9 +352,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
@ -364,10 +364,10 @@ class PunicaWrapperBase(PunicaWrapperABC):
@abstractmethod
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> Optional[torch.Tensor]:
@ -384,11 +384,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
@ -422,13 +422,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
output_slices: tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> Optional[torch.Tensor]:
"""
Applicable to linear-related lora.
@ -445,12 +445,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
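
Besides the annotations, the docstring type references (Tuple[...] to tuple[...]) are updated to match. As a reference for the add_shrink contract stated in the docstring above (y[i] += (x @ lora_a_stacked[i]) * scale), the loop below spells it out with plain torch ops and deliberately simplified 2-D shapes; the real Punica wrappers dispatch to fused kernels and handle the extra stacking dimensions:

import torch

def add_shrink_reference(y: tuple[torch.Tensor, ...],
                         x: torch.Tensor,
                         lora_a_stacked: tuple[torch.Tensor, ...],
                         scale: float) -> None:
    # One GEMM per lora_a slice, accumulated into the matching output slice.
    for i, lora_a in enumerate(lora_a_stacked):
        y[i].add_((x @ lora_a) * scale)

x = torch.randn(4, 16)                    # [num_tokens, hidden_size]
lora_a_stacked = (torch.randn(16, 8),)    # one slice: [hidden_size, lora_rank]
y = (torch.zeros(4, 8),)                  # [num_tokens, lora_rank]
add_shrink_reference(y, x, lora_a_stacked, scale=1.0)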

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
@ -150,8 +150,8 @@ class PunicaWrapperCPU(PunicaWrapperBase):
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs):
"""
Performs GEMM for multiple slices of lora_a.
@ -165,9 +165,9 @@ class PunicaWrapperCPU(PunicaWrapperBase):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
@ -179,10 +179,10 @@ class PunicaWrapperCPU(PunicaWrapperBase):
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> None:
@ -198,11 +198,11 @@ class PunicaWrapperCPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
@ -250,13 +250,13 @@ class PunicaWrapperCPU(PunicaWrapperBase):
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
output_slices: tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
"""
Applicable to linear-related lora.
@ -273,12 +273,12 @@ class PunicaWrapperCPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
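
For add_expand, whose semantics block is not visible in these hunks, here is a rough reference sketch under the assumption that each input slice is projected by the matching lora_b_stacked entry and accumulated into a column block of y whose width is output_slices[i]; the actual CPU wrapper additionally handles per-token LoRA selection, bias, and the add_inputs flag:

import torch

def add_expand_reference(y: torch.Tensor,
                         x: tuple[torch.Tensor, ...],
                         lora_b_stacked: tuple[torch.Tensor, ...],
                         output_slices: tuple[int, ...],
                         offset_start: int = 0) -> None:
    offset = offset_start
    for x_i, lora_b, slice_size in zip(x, lora_b_stacked, output_slices):
        # Each slice expands from lora_rank back up to its output width.
        y[:, offset:offset + slice_size] += x_i @ lora_b
        offset += slice_size

x = (torch.randn(4, 8),)                  # one slice: [num_tokens, lora_rank]
lora_b_stacked = (torch.randn(8, 32),)    # simplified: [lora_rank, output_size]
y = torch.zeros(4, 32)                    # [num_tokens, sum(output_slices)]
add_expand_reference(y, x, lora_b_stacked, output_slices=(32,))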

View File

@ -6,7 +6,7 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
from typing import TYPE_CHECKING, Optional, Union, final
import torch
@ -57,7 +57,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]],
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
@ -74,7 +74,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor,
lora_a_stacked: tuple[torch.Tensor,
...], scale: float, **kwargs):
"""
Performs GEMM for multiple slices of lora_a.
@ -86,7 +86,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
@ -102,9 +102,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
def add_expand(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> None:
@ -121,10 +121,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
@ -181,11 +181,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
output_slices: tuple[int, ...],
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
@ -204,11 +204,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None.
"""

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
from typing import TYPE_CHECKING, Optional, Union, final
import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
@ -28,7 +28,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
@ -48,9 +48,9 @@ class PunicaWrapperHPU(PunicaWrapperBase):
# graph accumulation. Hence HPU appends `lora_offset` to a list and
# converts it to a tensor only after it is ready.
if long_lora_context:
index_mapping_indices: List[int] = list(
index_mapping_indices: list[int] = list(
mapping.index_mapping).copy()
long_lora_offsets: List[int] = []
long_lora_offsets: list[int] = []
for i in range(len(index_mapping_indices)):
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
@ -85,13 +85,13 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
output_slices: tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
y_org = y
x = x.view(-1, x.shape[-1])
@ -122,9 +122,9 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> None:
@ -133,10 +133,10 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn.functional as F
@ -77,8 +77,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self._get_token_lora_indices(x), y_offset,
y_slice_size, add_inputs)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM for multiple slices of lora_a.
@ -88,9 +88,9 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
@ -106,10 +106,10 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> torch.Tensor:
@ -125,11 +125,11 @@ class PunicaWrapperTPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
@ -177,13 +177,13 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
output_slices: tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> torch.Tensor:
"""
Applicable to linear-related lora.
@ -200,12 +200,12 @@ class PunicaWrapperTPU(PunicaWrapperBase):
Args:
y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
@ -284,8 +284,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
output_slices: tuple[int, ...],
lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
@ -12,7 +12,7 @@ if TYPE_CHECKING:
def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
@ -43,14 +43,14 @@ def compute_meta(
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]],
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], list[int]]:
"""Converts LoRAMapping to index tensors.
Args:
@ -84,7 +84,7 @@ def convert_mapping(
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
@ -92,7 +92,7 @@ def convert_mapping(
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
prompt_mapping: list[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
@ -109,7 +109,7 @@ def convert_mapping(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
indices_list: list[Union[list[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
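
The return annotation of convert_mapping also shows the two tuple spellings side by side: a fixed-length heterogeneous tuple lists one type per position, while tuple[X, ...] (used for the stacked weights elsewhere in this commit) is the variable-length homogeneous form. Both behave exactly like their typing.Tuple counterparts, as the small sketch below shows:

# Fixed-length, heterogeneous: one type per position.
Meta = tuple[int, str, list[int]]

# Variable-length, homogeneous: a single element type plus an ellipsis.
Stacked = tuple[int, ...]

meta: Meta = (3, "lora", [1, 2, 3])
stacked: Stacked = (1, 2, 3, 4)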

View File

@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from typing import AbstractSet, Dict, Optional
from typing import Optional
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -40,9 +41,9 @@ class LoRAResolver(ABC):
@dataclass
class _LoRAResolverRegistry:
resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)
resolvers: dict[str, LoRAResolver] = field(default_factory=dict)
def get_supported_resolvers(self) -> AbstractSet[str]:
def get_supported_resolvers(self) -> Set[str]:
"""Get all registered resolver names."""
return self.resolvers.keys()
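
typing.AbstractSet is itself a deprecated alias of collections.abc.Set, so the resolver change above is a rename rather than a behavioural change. collections.abc.Set (the read-only set ABC) is broader than the built-in set: it also matches set-like views such as dict.keys(), which is exactly what get_supported_resolvers returns. A small standalone sketch:

from collections.abc import Set

def get_supported_resolvers(resolvers: dict[str, object]) -> Set[str]:
    # dict.keys() is a read-only, set-like view: a collections.abc.Set,
    # but not an instance of the built-in set type.
    return resolvers.keys()

names = get_supported_resolvers({"s3": object(), "local": object()})
assert isinstance(names, Set) and not isinstance(names, set)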

View File

@ -2,7 +2,7 @@
import os
import re
from typing import List, Optional, Set, Tuple, Type, Union
from typing import Optional, Union
import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@ -37,7 +37,7 @@ from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
@ -58,7 +58,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
packed_modules_list: list,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator
@ -99,7 +99,7 @@ def replace_submodule(model: nn.Module, module_name: str,
def parse_fine_tuned_lora_name(
name: str,
weights_mapper: Optional[WeightsMapper] = None
) -> Tuple[str, bool, bool]:
) -> tuple[str, bool, bool]:
"""Parse the name of lora weights.
args:
@ -108,7 +108,7 @@ def parse_fine_tuned_lora_name(
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return:
Tuple(module_name, is_lora_a):
tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
@ -147,8 +147,8 @@ def parse_fine_tuned_lora_name(
raise ValueError(f"{name} is unsupported LoRA weight")
def is_regex_target_modules(load_modules: Union[str, List[str]],
expected_lora_modules: List[str]) -> bool:
def is_regex_target_modules(load_modules: Union[str, list[str]],
expected_lora_modules: list[str]) -> bool:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
@ -179,11 +179,11 @@ def is_regex_target_modules(load_modules: Union[str, List[str]],
return False
def get_supported_lora_modules(model: nn.Module) -> List[str]:
def get_supported_lora_modules(model: nn.Module) -> list[str]:
"""
In vLLM, all linear layers support LoRA.
"""
supported_lora_modules: Set[str] = set()
supported_lora_modules: set[str] = set()
# step1: traverse the model to get all the linear subfixes.
for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
from typing import Any, Literal, Optional, Union
import torch
@ -27,7 +27,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded."""
_manager_cls: Type[LoRAModelManager] = LoRAModelManager
_manager_cls: type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
@ -36,9 +36,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
embedding_modules: Dict[str, str],
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
embedding_modules: dict[str, str],
embedding_padding_modules: list[str],
lora_model_cls: type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
):
self._lora_model_cls = lora_model_cls
@ -88,7 +88,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
self._adapter_manager.supported_lora_modules)
packed_modules_mapping = (
self._adapter_manager.packed_modules_mapping)
expected_lora_modules: List[str] = []
expected_lora_modules: list[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(
@ -162,12 +162,12 @@ class WorkerLoRAManager(AbstractWorkerManager):
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: Set[Any],
def set_active_adapters(self, requests: set[Any],
mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
def _apply_adapters(self, adapter_requests: set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
@ -184,7 +184,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> Set[int]:
def list_adapters(self) -> set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters)
@ -195,7 +195,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
(unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity."""
_manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
_manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
@ -213,7 +213,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self._adapter_manager = lora_manager
return lora_manager.model
def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request