[Bugfix] fix composite weight loading and EAGLE weight loading (#9160)
@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
@@ -687,35 +686,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load query tokens
-        for name, loaded_weight in weights_group["query_tokens"]:
-            assert name == ""
-            param = self.query_tokens
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load qformer
-        qformer_params_dict = dict(self.qformer.named_parameters())
-        for name, loaded_weight in weights_group["qformer"]:
-            param = qformer_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.language_projection.named_parameters())
-        for name, loaded_weight in weights_group["language_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -42,8 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    merge_multimodal_embeddings)
+from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -349,16 +347,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision embeddings
-        vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
-        for name, loaded_weight in weights_group["vision_embed_tokens"]:
-            param = vision_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 logger = init_logger(__name__)
@@ -447,19 +447,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            # NOTE: For now self.lm_head is not defined because
-            # tie_word_embeddings is assumed to the False
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(weights)
@@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -32,8 +31,8 @@ from vllm.utils import is_list_of
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -609,19 +608,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in weights_group["mlp1"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -51,8 +51,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -564,25 +563,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights = [
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(
             self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            for name, loaded_weight in weights)
 
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -406,19 +405,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -29,8 +28,8 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@@ -642,27 +641,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load newline
-        for name, loaded_weight in weights_group["image_newline"]:
-            assert name == ""
-            param = self.image_newline
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -15,7 +15,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -28,7 +27,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 # For profile run
@@ -458,19 +457,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(
+            self,
+            # This model doesn't support images for now
+            ignore_unexpected_prefixes=["image_newline"],
+        )
+        loader.load_weights(weights)
@@ -20,7 +20,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -35,8 +34,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -872,19 +871,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -11,7 +11,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.gemma import GemmaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import group_weights_with_prefix, merge_multimodal_embeddings
+from .utils import AutoWeightsLoader, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -292,19 +291,5 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision tower
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -42,15 +41,11 @@ from vllm.utils import is_list_of
 
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
-_KEYS_TO_MODIFY_MAPPING = {
-    "model.vision_embed_tokens": "vision_embed_tokens",
-}
-
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 32044
 
@@ -295,35 +290,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.img_processor.load_weights(weights_group["img_processor"])
-
-        # load glb_GN
-        for name, loaded_weight in weights_group["glb_GN"]:
-            assert name == ""
-            param = self.glb_GN
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load sub_GN
-        for name, loaded_weight in weights_group["sub_GN"]:
-            assert name == ""
-            param = self.sub_GN
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.img_projection.named_parameters())
-        for name, loaded_weight in weights_group["img_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
 
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
@@ -715,27 +683,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapping = {
-            "model.vision_embed_tokens.": "vision_embed_tokens.",
-            "lm_head.": "language_model.lm_head.",
-            "model.": "language_model.model.",
-        }
+        hf_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={
+                "model.vision_embed_tokens.": "vision_embed_tokens.",
+                "lm_head.": "language_model.lm_head.",
+                "model.": "language_model.model.",
+            })
 
-        def hf_to_vllm_name(key: str) -> str:
-            for hf_name, vllm_name in hf_to_vllm_mapping.items():
-                if key.startswith(hf_name):
-                    return key.replace(hf_name, vllm_name, 1)
-
-            return key
-
-        vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}
-
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(vllm_weights.items())
-
-        # load vision embeddings and encoder
-        self.vision_embed_tokens.load_weights(
-            weights_group["vision_embed_tokens"])
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
@@ -48,8 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -435,17 +434,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(weights)
@@ -16,13 +16,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsPP
 from .qwen2 import Qwen2Model
-from .utils import group_weights_with_prefix
+from .utils import AutoWeightsLoader
 
 
 class ReLU(nn.Module):
@@ -120,13 +119,5 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        score_dict = dict(self.score.named_parameters())
-        for name, loaded_weight in weights_group["score"]:
-            param = score_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
@@ -25,11 +25,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import (flatten_bn,
-                                              group_weights_with_prefix,
-                                              init_vllm_registered_model,
-                                              merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs, NestedTensors
@@ -41,6 +36,8 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 _AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
@@ -498,30 +495,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
+        hf_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
 
-        # load audio tower weights
-        audio_tower_weights = weights_group["audio_tower"]
-        audio_tower_params_dict = dict(
-            self.audio_tower.named_parameters(
-                prefix=self.audio_tower.base_model_prefix))
-        for name, loaded_weight in audio_tower_weights:
-            if name in audio_tower_params_dict:
-                param = audio_tower_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-        # load projector weights
-        projector_weights = weights_group["multi_modal_projector"]
-        projector_params_dict = dict(
-            self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in projector_weights:
-            param = projector_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["audio_tower."])
+        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
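Every per-model `load_weights` in the hunks above collapses into the same shape: build an `AutoWeightsLoader` over the whole model and let it dispatch each weight to the right child module or parameter by name prefix, instead of hand-grouping weights per component. The sketch below illustrates that pattern only; the class, its submodules, and the option values are made-up placeholders, not code from this commit.

    from typing import Iterable, Tuple

    import torch
    import torch.nn as nn

    from vllm.model_executor.models.utils import AutoWeightsLoader


    class ToyCompositeModel(nn.Module):
        """Stand-in for a composite model like the ones patched above."""

        def __init__(self, tie_word_embeddings: bool = False):
            super().__init__()
            self.tie_word_embeddings = tie_word_embeddings
            self.language_model = nn.Linear(4, 4)  # placeholder LLM backbone
            self.lm_head = nn.Linear(4, 4)  # placeholder output head

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            loader = AutoWeightsLoader(
                self,
                # Same idea as llama/gemma2/qwen2: drop lm_head.* when tied.
                skip_prefixes=(["lm_head."]
                               if self.tie_word_embeddings else None),
            )
            loader.load_weights(weights)

Because the loader's `_load_module` calls a child's own `load_weights` when it defines one, composite models such as llava or internvl still delegate their language-model weights to the inner vLLM model; that is what the plain `AutoWeightsLoader(self)` call sites above rely on.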
@@ -1,7 +1,7 @@
 import itertools
-from collections import UserDict
-from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
-                    Tuple, Union, overload)
+from dataclasses import dataclass, field
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
+WeightsMapping = Mapping[str, Optional[str]]
+"""If a key maps to a value of `None`, the corresponding weight is ignored."""
+
 
-class WeightsGroup(UserDict):
+@dataclass
+class WeightsMapper:
+    """Maps the name of each weight if they match the following patterns."""
+
+    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
+    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
+    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
+
+    def _map_name(self, key: str) -> Optional[str]:
+        for substr, new_key in self.orig_to_new_substr.items():
+            if substr in key:
+                if new_key is None:
+                    return None
+
+                key = key.replace(substr, new_key, 1)
+
+        for prefix, new_key in self.orig_to_new_prefix.items():
+            if key.startswith(prefix):
+                if new_key is None:
+                    return None
+
+                key = key.replace(prefix, new_key, 1)
+
+        for suffix, new_key in self.orig_to_new_suffix.items():
+            if key.endswith(suffix):
+                if new_key is None:
+                    return None
+
+                key = new_key.join(key.rsplit(suffix, 1))
+
+        return key
+
+    def apply(
+        self, weights: Iterable[Tuple[str, torch.Tensor]]
+    ) -> Iterable[Tuple[str, torch.Tensor]]:
+        return ((out_name, data) for name, data in weights
+                if (out_name := self._map_name(name)) is not None)
+
+
+class AutoWeightsLoader:
     """
-    Wraps grouped weights dictionary for a more informative error message
-    when attempting to access a weight component that does not exist.
+    Helper class to load weights into a :class:`torch.nn.Module`. It is able
+    to automatically detect child modules and parameters while iterating over
+    the weights only once.
+
+    The weight loading logic for individual modules can be overridden
+    by defining a ``load_weights`` method.
+
+    Similarly, the weight loading logic for individual parameters can be
+    overridden by defining a ``weight_loader`` method.
     """
 
-    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
-        try:
-            return super().__getitem__(key)
-        except KeyError as exc:
-            msg = (f"There is no weights named with the prefix: {key}. "
-                   f"Available prefix: {set(self.keys())}")
-            raise KeyError(msg) from exc
+    def __init__(
+        self,
+        module: nn.Module,
+        *,
+        skip_prefixes: Optional[List[str]] = None,
+        ignore_unexpected_prefixes: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.skip_prefixes = skip_prefixes or []
+        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
-                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
-    """
-    Helper function to load weights for inner vLLM models.
+    def _groupby_prefix(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
+        weights_by_parts = ((weight_name.split(".", 1), weight_data)
+                            for weight_name, weight_data in weights)
 
-    See also:
-        :ref:`init_vllm_registered_model`
-    """
-    for name, loaded_weight in weights:
-        name = name.split(".")
-        if prefix == name.pop(0):
-            name = ".".join(name)
-            yield name, loaded_weight
+        for prefix, group in itertools.groupby(weights_by_parts,
+                                               key=lambda x: x[0][0]):
+            yield (
+                prefix,
+                # Because maxsplit=1 in weight_name.split(...),
+                # the length of `parts` must either be 1 or 2
+                (("" if len(parts) == 1 else parts[1], weights_data)
+                 for parts, weights_data in group),
+            )
 
+    def _get_qualname(self, prefix: str, rest: str) -> str:
+        if prefix == "":
+            return rest
+        if rest == "":
+            return prefix
+
-def group_weights_with_prefix(
-        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
-    """
-    Helper function to group weights with prefix
-    """
-    init_weights, repeated_weights = itertools.tee(weights, 2)
-    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
-    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+        return ".".join((prefix, rest))
 
-    return WeightsGroup({
-        prefix: filter_weights(component, prefix)
-        for component, prefix in zip(repeated_weights, weights_prefix)
-    })
+    def _can_skip(self, qualname: str) -> bool:
+        return any(qualname.startswith(p) for p in self.skip_prefixes)
+
+    def _can_ignore_unexpected(self, qualname: str) -> bool:
+        return any(
+            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+
+    def _load_param(
+        self,
+        base_prefix: str,
+        param: nn.Parameter,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        for weight_name, weight_data in weights:
+            weight_qualname = self._get_qualname(base_prefix, weight_name)
+
+            if self._can_skip(weight_qualname):
+                continue
+
+            if weight_name != "":
+                if not self._can_ignore_unexpected(weight_qualname):
+                    raise ValueError(
+                        f"Attempted to load nested weight '{weight_qualname}' "
+                        f"into a single parameter '{base_prefix}'")
+
+                continue
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, weight_data)
+
+    def _load_module(
+        self,
+        base_prefix: str,
+        module: nn.Module,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        if isinstance(module, PPMissingLayer):
+            return
+
+        # Avoid infinite recursion since this function is typically
+        # called inside load_weights of the module itself
+        if module != self.module:
+            module_load_weights = getattr(module, "load_weights", None)
+            if callable(module_load_weights):
+                module_load_weights(weights)
+                return
+
+        child_modules = dict(module.named_children())
+        child_params = dict(module.named_parameters(recurse=False))
+
+        for child_prefix, child_weights in self._groupby_prefix(weights):
+            prefix = self._get_qualname(base_prefix, child_prefix)
+
+            if self._can_skip(prefix):
+                continue
+
+            if child_prefix in child_modules:
+                self._load_module(prefix, child_modules[child_prefix],
+                                  child_weights)
+            elif child_prefix in child_params:
+                self._load_param(prefix, child_params[child_prefix],
+                                 child_weights)
+            else:
+                if not self._can_ignore_unexpected(prefix):
+                    msg = f"There is no module or parameter named '{prefix}'"
+                    raise ValueError(msg)
+
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        *,
+        mapper: Optional[WeightsMapper] = None,
+    ) -> None:
+        if mapper is not None:
+            weights = mapper.apply(weights)
+
+        self._load_module("", self.module, weights)
 
 
 def init_vllm_registered_model(
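The `WeightsMapper` added in this file is what the `mapper=` argument of `AutoWeightsLoader.load_weights` consumes: it renames checkpoint keys by substring, prefix, or suffix before the loader walks the module tree, and a key mapped to `None` is dropped (per its one-line docstring above). A small sketch of that behavior, using invented checkpoint key names purely for illustration:

    import torch

    from vllm.model_executor.models.utils import WeightsMapper

    # Hypothetical checkpoint entries; only the names matter here.
    weights = [
        ("model.layers.0.self_attn.qkv_proj.weight", torch.zeros(1)),
        ("audio_tower.model.encoder.layers.0.fc1.weight", torch.zeros(1)),
        ("rotary_emb.inv_freq", torch.zeros(1)),
    ]

    mapper = WeightsMapper(orig_to_new_prefix={
        "model.": "language_model.model.",
        "audio_tower.model.encoder.": "audio_tower.",
        "rotary_emb.inv_freq": None,  # a None target discards the weight
    })

    print([name for name, _ in mapper.apply(weights)])
    # ['language_model.model.layers.0.self_attn.qkv_proj.weight',
    #  'audio_tower.layers.0.fc1.weight']

This mirrors how the phi3v hunk re-roots `model.*` under `language_model.model.*` and how the ultravox hunk strips the `audio_tower.model.encoder.` prefix before loading.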