[Model] Add FlexOlmo model implementation (#24923)
Signed-off-by: Shane A <shanea@allenai.org>
This commit is contained in:
@ -363,6 +363,7 @@ th {
|
|||||||
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
|
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
|
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -250,6 +250,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
|
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
|
||||||
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
|
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
|
||||||
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
|
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
|
||||||
|
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
|
||||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
||||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||||
|
|||||||
157
vllm/model_executor/models/flex_olmo.py
Normal file
157
vllm/model_executor/models/flex_olmo.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only FlexOlmo model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
|
from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM
|
||||||
|
from vllm.transformers_utils.configs import FlexOlmoConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FlexOlmoAttention(OlmoeAttention):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
hf_config = vllm_config.model_config.hf_config
|
||||||
|
assert isinstance(hf_config, FlexOlmoConfig)
|
||||||
|
|
||||||
|
self.k_norm = RMSNorm(
|
||||||
|
self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.q_norm = RMSNorm(
|
||||||
|
self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlexOlmoMoE(nn.Module):
|
||||||
|
"""A tensor-parallel MoE implementation for FlexOlmo that shards each expert
|
||||||
|
across all ranks.
|
||||||
|
|
||||||
|
Each expert's weights are sharded across all ranks and a fused MoE
|
||||||
|
kernel is used for the forward pass, and finally we reduce the outputs
|
||||||
|
across ranks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
hf_config = vllm_config.model_config.hf_config
|
||||||
|
assert isinstance(hf_config, FlexOlmoConfig)
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
# Gate always runs at half / full precision for now.
|
||||||
|
self.gate = ReplicatedLinear(
|
||||||
|
hf_config.hidden_size,
|
||||||
|
hf_config.num_experts,
|
||||||
|
bias=False,
|
||||||
|
return_bias=False,
|
||||||
|
quant_config=None,
|
||||||
|
prefix=f"{prefix}.gate",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gate always runs at half / full precision for now.
|
||||||
|
self.experts = FusedMoE(
|
||||||
|
num_experts=hf_config.num_experts,
|
||||||
|
top_k=hf_config.num_experts_per_tok,
|
||||||
|
hidden_size=hf_config.hidden_size,
|
||||||
|
intermediate_size=hf_config.intermediate_size,
|
||||||
|
reduce_results=True,
|
||||||
|
renormalize=False,
|
||||||
|
quant_config=None,
|
||||||
|
tp_size=tp_size,
|
||||||
|
prefix=f"{prefix}.experts",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.top_k = hf_config.num_experts_per_tok
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
hidden_dim = hidden_states.shape[-1]
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits = self.gate(hidden_states)
|
||||||
|
# Warning: The experts mutate the hidden state input! This messes up
|
||||||
|
# basic things like the residual stream.
|
||||||
|
final_hidden_states = self.experts(
|
||||||
|
hidden_states=hidden_states.detach().clone(),
|
||||||
|
router_logits=router_logits.float(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_hidden_states.view(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class FlexOlmoDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
|
super().__init__()
|
||||||
|
hf_config = vllm_config.model_config.hf_config
|
||||||
|
assert isinstance(hf_config, FlexOlmoConfig)
|
||||||
|
|
||||||
|
self.self_attn = FlexOlmoAttention(
|
||||||
|
vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = RMSNorm(
|
||||||
|
hf_config.hidden_size, eps=hf_config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.post_feedforward_layernorm = RMSNorm(
|
||||||
|
hf_config.hidden_size, eps=hf_config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
# Attention block.
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.self_attn(positions, hidden_states)
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
# MLP block.
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states, None
|
||||||
|
|
||||||
|
|
||||||
|
class FlexOlmoForCausalLM(OlmoeForCausalLM):
|
||||||
|
fall_back_to_pt_during_load = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
layer_type: type[nn.Module] = FlexOlmoDecoderLayer,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
|
||||||
@ -17,15 +17,14 @@
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Any, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import OlmoeConfig
|
|
||||||
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
@ -117,20 +116,21 @@ class OlmoeMoE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class OlmoeAttention(nn.Module):
|
class OlmoeAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
num_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
rope_theta: float = 10000,
|
|
||||||
rope_scaling: Optional[dict[str, Any]] = None,
|
|
||||||
max_position_embeddings: int = 4096,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
|
||||||
|
|
||||||
|
num_heads = config.num_attention_heads
|
||||||
|
num_kv_heads = config.num_key_value_heads
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.total_num_heads = num_heads
|
self.total_num_heads = num_heads
|
||||||
assert self.total_num_heads % tp_size == 0
|
assert self.total_num_heads % tp_size == 0
|
||||||
@ -145,7 +145,7 @@ class OlmoeAttention(nn.Module):
|
|||||||
# the KV heads across multiple tensor parallel GPUs.
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
assert tp_size % self.total_num_kv_heads == 0
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
self.head_dim = hidden_size // self.total_num_heads
|
self.head_dim = self.hidden_size // self.total_num_heads
|
||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
@ -153,7 +153,7 @@ class OlmoeAttention(nn.Module):
|
|||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size,
|
self.hidden_size,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
@ -166,7 +166,7 @@ class OlmoeAttention(nn.Module):
|
|||||||
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
|
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
@ -218,28 +218,15 @@ class OlmoeAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class OlmoeDecoderLayer(nn.Module):
|
class OlmoeDecoderLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
self,
|
|
||||||
config: OlmoeConfig,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
|
|
||||||
|
|
||||||
self.self_attn = OlmoeAttention(
|
self.self_attn = OlmoeAttention(
|
||||||
hidden_size=self.hidden_size,
|
vllm_config=vllm_config,
|
||||||
num_heads=config.num_attention_heads,
|
|
||||||
num_kv_heads=config.num_key_value_heads,
|
|
||||||
rope_theta=rope_theta,
|
|
||||||
rope_scaling=rope_scaling,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -280,12 +267,16 @@ class OlmoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class OlmoeModel(nn.Module):
|
class OlmoeModel(nn.Module):
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
layer_type: type[nn.Module] = OlmoeDecoderLayer,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
|
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -295,9 +286,7 @@ class OlmoeModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: OlmoeDecoderLayer(
|
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
|
||||||
config, cache_config, quant_config, prefix=prefix
|
|
||||||
),
|
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||||
@ -339,7 +328,10 @@ class OlmoeModel(nn.Module):
|
|||||||
{"hidden_states": hidden_states, "residual": residual}
|
{"hidden_states": hidden_states, "residual": residual}
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
if residual is not None:
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
else:
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||||
@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
layer_type: type[nn.Module] = OlmoeDecoderLayer,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = OlmoeModel(
|
self.model = OlmoeModel(
|
||||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"),
|
||||||
|
layer_type=layer_type,
|
||||||
)
|
)
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
|
|||||||
@ -90,6 +90,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
|
"Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
|
||||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
|
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
|
||||||
|
"FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
|
||||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||||
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
||||||
|
|||||||
@ -74,6 +74,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
|||||||
deepseek_vl_v2="DeepseekVLV2Config",
|
deepseek_vl_v2="DeepseekVLV2Config",
|
||||||
deepseek_v3="DeepseekV3Config",
|
deepseek_v3="DeepseekV3Config",
|
||||||
deepseek_v32="DeepseekV3Config",
|
deepseek_v32="DeepseekV3Config",
|
||||||
|
flex_olmo="FlexOlmoConfig",
|
||||||
kimi_vl="KimiVLConfig",
|
kimi_vl="KimiVLConfig",
|
||||||
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
|
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
|
||||||
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
|
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
|||||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||||
|
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
|
||||||
from vllm.transformers_utils.configs.jais import JAISConfig
|
from vllm.transformers_utils.configs.jais import JAISConfig
|
||||||
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
|
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
|
||||||
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
|
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
|
||||||
@ -45,6 +46,7 @@ __all__ = [
|
|||||||
"DeepseekV3Config",
|
"DeepseekV3Config",
|
||||||
"DotsOCRConfig",
|
"DotsOCRConfig",
|
||||||
"EAGLEConfig",
|
"EAGLEConfig",
|
||||||
|
"FlexOlmoConfig",
|
||||||
"RWConfig",
|
"RWConfig",
|
||||||
"JAISConfig",
|
"JAISConfig",
|
||||||
"Lfm2MoeConfig",
|
"Lfm2MoeConfig",
|
||||||
|
|||||||
77
vllm/transformers_utils/configs/flex_olmo.py
Normal file
77
vllm/transformers_utils/configs/flex_olmo.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class FlexOlmoConfig(PretrainedConfig):
|
||||||
|
model_type = "flex_olmo"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=100352,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-06,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=100277,
|
||||||
|
bos_token_id=None,
|
||||||
|
eos_token_id=100257,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=500000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
num_experts_per_tok=5,
|
||||||
|
num_experts=7,
|
||||||
|
output_router_logits=False,
|
||||||
|
router_aux_loss_coef=0.01,
|
||||||
|
norm_topk_prob=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if "architectures" not in kwargs:
|
||||||
|
kwargs["architectures"] = ["FlexOlmoForCausalLM"]
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.output_router_logits = output_router_logits
|
||||||
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
self.norm_topk_prob = norm_topk_prob
|
||||||
|
# Validate the correctness of rotary position embeddings parameters
|
||||||
|
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||||
|
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
|
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||||
Reference in New Issue
Block a user