[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -1,13 +1,34 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
|
||||
class MyGemma2Embedding(Gemma2EmbeddingModel):
|
||||
class MyGemma2Embedding(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
vllm_config.model_config.pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=True,
|
||||
softmax=False,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = super().forward(
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
|
||||
|
||||
# Return all-zero embeddings
|
||||
return torch.zeros_like(hidden_states)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
return self.model.load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user