[Model] use AutoWeightsLoader for gpt2 (#18625)
Signed-off-by: zt2370 <ztang2370@gmail.com>
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -235,6 +235,35 @@ class GPT2Model(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPT2LMHeadModel(nn.Module, SupportsPP):
 
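A note on the transpose kept in the method above (the logic is moved unchanged from the old `GPT2LMHeadModel.load_weights`): HF's GPT-2 checkpoints use `Conv1D` modules whose weights are stored as (in_features, out_features), the transpose of what a Linear-style weight loader expects. A minimal shape sketch, using GPT-2 small's hidden size of 768 purely as an illustrative example:

```python
import torch

# Illustrative shapes only: in GPT-2 small, c_attn maps hidden_size=768
# to 3 * 768 (query, key and value concatenated).
hf_conv1d_weight = torch.randn(768, 3 * 768)  # HF Conv1D layout: (in, out)
linear_style_weight = hf_conv1d_weight.t()    # Linear layout: (out, in)
assert linear_style_weight.shape == (3 * 768, 768)
```

Only the `.weight` tensors are transposed; the matching `c_attn.bias` is loaded as-is, which is why the loop checks `name.endswith(".weight")`.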
@@ -283,32 +312,16 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                              torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if ".attn.bias" in name or ".attn.masked_bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if not name.startswith("transformer.") and not name.startswith(
-                    "lm_head"):
-                name = "transformer." + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-            # Because of this, we need to transpose the weights.
-            # Note(zhuohan): the logic below might break quantized models.
-            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                if conv1d_weight_name not in name:
-                    continue
-                if not name.endswith(".weight"):
-                    continue
-                loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
+
+
+def _add_transformer_prefix(
+    weights: Iterable[tuple[str, torch.Tensor]]
+) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith('transformer.') and not name.startswith(
+                "lm_head"):
+            name = 'transformer.' + name
+        yield name, tensor
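For reference, a standalone sketch of what the new helper does to checkpoint names before `AutoWeightsLoader` maps them onto the `transformer` submodule of `GPT2LMHeadModel`; the key names below are illustrative GPT-2 style keys, not taken from a real checkpoint:

```python
from collections.abc import Iterable

import torch


def _add_transformer_prefix(
    weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    # Same generator as in the diff: prepend "transformer." to every name
    # that is neither already prefixed nor an lm_head weight.
    for name, tensor in weights:
        if not name.startswith('transformer.') and not name.startswith(
                "lm_head"):
            name = 'transformer.' + name
        yield name, tensor


# Illustrative GPT-2 checkpoint keys before and after prefixing.
names = ["wte.weight", "h.0.attn.c_attn.weight", "lm_head.weight"]
renamed = [n for n, _ in _add_transformer_prefix(
    (n, torch.empty(0)) for n in names)]
print(renamed)
# ['transformer.wte.weight', 'transformer.h.0.attn.c_attn.weight',
#  'lm_head.weight']
```

With the names prefixed this way, `AutoWeightsLoader` can route them to the `GPT2Model.load_weights` method added in the earlier hunk, so the Conv1D transpose handling stays in one place.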