From 720af6ab791164175eca32c67de7cfe2994642fc Mon Sep 17 00:00:00 2001 From: Roger Young <42564206+rogeryoungh@users.noreply.github.com> Date: Mon, 27 Oct 2025 00:59:11 +0800 Subject: [PATCH] [Model][MiniMax-M2] Support MiniMax-M2 Model (#27535) Signed-off-by: xuebi Co-authored-by: xuebi --- tests/models/registry.py | 3 + .../openai/tool_parsers/__init__.py | 2 + .../tool_parsers/minimax_m2_tool_parser.py | 644 ++++++++++++++++++ vllm/model_executor/models/minimax_m2.py | 585 ++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/reasoning/__init__.py | 2 + vllm/reasoning/minimax_m2_reasoning_parser.py | 69 ++ 7 files changed, 1306 insertions(+) create mode 100644 vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py create mode 100644 vllm/model_executor/models/minimax_m2.py create mode 100644 vllm/reasoning/minimax_m2_reasoning_parser.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 8e11ee755b..f227edd827 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -341,6 +341,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MiniMaxM1ForCausalLM": _HfExamplesInfo( "MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True ), + "MiniMaxM2ForCausalLM": _HfExamplesInfo( + "MiniMaxAI/MiniMax-M2", trust_remote_code=True + ), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MixtralForCausalLM": _HfExamplesInfo( "mistralai/Mixtral-8x7B-Instruct-v0.1", diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index a72772f59c..4541ca5082 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -16,6 +16,7 @@ from .kimi_k2_tool_parser import KimiK2ToolParser from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .longcat_tool_parser import LongcatFlashToolParser +from .minimax_m2_tool_parser import MinimaxM2ToolParser from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser from .olmo3_tool_parser import Olmo3PythonicToolParser @@ -56,4 +57,5 @@ __all__ = [ "SeedOssToolParser", "Step3ToolParser", "OpenAIToolParser", + "MinimaxM2ToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py new file mode 100644 index 0000000000..06dd336bf9 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py @@ -0,0 +1,644 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("minimax_m2") +class MinimaxM2ToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.prev_tool_call_arr: list[dict] = [] + + # Sentinel tokens + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.invoke_start_prefix: str = "" + self.parameter_prefix: str = "" + + # Streaming state variables + self.current_tool_name_sent: bool = False + # Override base class type - we use string IDs for tool calls + self.current_tool_id: str | None = None # type: ignore + self.streamed_args_for_tool: list[str] = [] + self.is_tool_call_started: bool = False + self.failed_count: int = 0 + + # Initialize streaming state variables + self.current_tool_index: int = 0 + self.invoke_index: int = 0 + self.header_sent: bool = False + self.current_function_name: str | None = None + self.current_param_name: str | None = None + self.current_param_value: str = "" + self.param_count: int = 0 + self.in_param: bool = False + self.in_function: bool = False + self.accumulated_text: str = "" + self.json_started: bool = False + self.json_closed: bool = False + self.accumulated_params: dict = {} + self.streaming_request: ChatCompletionRequest | None = None + + # Enhanced streaming state - reset for each new message + self._reset_streaming_state() + + # Regex patterns for complete parsing + self.tool_call_complete_regex = re.compile( + r"(.*?)", re.DOTALL + ) + self.invoke_complete_regex = re.compile( + r"", re.DOTALL + ) + self.parameter_complete_regex = re.compile( + r"", re.DOTALL + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." + ) + + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: + raise RuntimeError( + "MiniMax M2 Tool parser could not locate tool call start/end " + "tokens in the tokenizer!" + ) + + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.invoke_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = None + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + # Clear previous tool call history to avoid state pollution + self.prev_tool_call_arr.clear() + + def _extract_name(self, name_str: str) -> str: + """Extract name from quoted string.""" + name_str = name_str.strip() + if ( + name_str.startswith('"') + and name_str.endswith('"') + or name_str.startswith("'") + and name_str.endswith("'") + ): + return name_str[1:-1] + return name_str + + def _convert_param_value(self, value: str, param_type: str) -> Any: + """Convert parameter value to the correct type.""" + if value.lower() == "null": + return None + + param_type = param_type.lower() + if param_type in ["string", "str", "text"]: + return value + elif param_type in ["integer", "int"]: + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ["number", "float"]: + try: + val = float(value) + return val if val != int(val) else int(val) + except (ValueError, TypeError): + return value + elif param_type in ["boolean", "bool"]: + return value.lower() in ["true", "1"] + elif param_type in ["object", "array"]: + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + # Try JSON parse first, fallback to string + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + def _parse_single_invoke( + self, invoke_str: str, tools: list | None + ) -> ToolCall | None: + """Parse a single block.""" + # Extract function name + name_match = re.search(r"^([^>]+)", invoke_str) + if not name_match: + return None + + function_name = self._extract_name(name_match.group(1)) + + # Get parameter configuration + param_config = {} + if tools: + for tool in tools: + if ( + hasattr(tool, "function") + and tool.function.name == function_name + and hasattr(tool.function, "parameters") + ): + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + param_config = params["properties"] + break + + # Extract parameters + param_dict = {} + for match in self.parameter_complete_regex.findall(invoke_str): + param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL) + if param_match: + param_name = self._extract_name(param_match.group(1)) + param_value = param_match.group(2).strip() + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Get parameter type + param_type = "string" + if ( + param_name in param_config + and isinstance(param_config[param_name], dict) + and "type" in param_config[param_name] + ): + param_type = param_config[param_name]["type"] + + # Convert value + param_dict[param_name] = self._convert_param_value( + param_value, param_type + ) + + return ToolCall( + type="function", + function=FunctionCall( + name=function_name, + arguments=json.dumps(param_dict, ensure_ascii=False), + ), + ) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """Extract tool calls from complete model output (non-streaming).""" + # Quick check + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + tool_calls = [] + + # Find all complete tool_call blocks + for tool_call_match in self.tool_call_complete_regex.findall(model_output): + # Find all invokes within this tool_call + for invoke_match in self.invoke_complete_regex.findall(tool_call_match): + tool_call = self._parse_single_invoke( + invoke_match, request.tools if request else None + ) + if tool_call: + tool_calls.append(tool_call) + + if not tool_calls: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + # Update prev_tool_call_arr + self.prev_tool_call_arr.clear() + for tool_call in tool_calls: + self.prev_tool_call_arr.append( + { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + ) + + # Extract content before first tool call + first_tool_idx = model_output.find(self.tool_call_start_token) + content = model_output[:first_tool_idx] if first_tool_idx > 0 else None + + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + + except Exception: + logger.exception("Error extracting tool calls") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], # pylint: disable=unused-argument + current_token_ids: Sequence[int], # pylint: disable=unused-argument + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + """Extract tool calls from streaming model output.""" + + # Store request for type conversion + if not previous_text or self.tool_call_start_token in delta_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids: + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text) + ) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if all tool calls are closed + open_calls = current_text.count( + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) + if open_calls == 0: + # Return empty delta for finish_reason processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Update accumulated text + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + invoke_ends = current_text.count(self.invoke_end_token) + if invoke_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + self.in_function = False # Now we can safely set this to False + self.accumulated_params = {} + # Continue processing next tool + return None + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if ( + self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text + ): + self.is_tool_call_started = True + # Return any content before the tool call + if self.tool_call_start_token in delta_text: + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + invoke_starts_count = current_text.count(self.invoke_start_prefix) + if self.current_tool_index >= invoke_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # Find the current tool call portion + invoke_start_positions: list[int] = [] + idx = 0 + while True: + idx = current_text.find(self.invoke_start_prefix, idx) + if idx == -1: + break + invoke_start_positions.append(idx) + idx += len(self.invoke_start_prefix) + + if self.current_tool_index >= len(invoke_start_positions): + # No more tool calls to process yet + return None + + invoke_start_idx = invoke_start_positions[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx) + if invoke_end_idx == -1: + tool_text = current_text[invoke_start_idx:] + else: + tool_text = current_text[ + invoke_start_idx : invoke_end_idx + len(self.invoke_end_token) + ] + + # Looking for function header + if not self.header_sent: + if self.invoke_start_prefix in tool_text: + func_start = tool_text.find(self.invoke_start_prefix) + len( + self.invoke_start_prefix + ) + # Find the end quote for the function name + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + function_name_raw = tool_text[func_start:func_end] + self.current_function_name = self._extract_name(function_name_raw) + self.current_tool_id = self._generate_tool_call_id() + self.header_sent = True + self.in_function = True + + # Add to prev_tool_call_arr immediately when we detect a tool call + # Each tool call should be recorded regardless of function name + # Ensure we don't add the same tool call index multiple times + if len(self.prev_tool_call_arr) <= self.current_tool_index: + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) + + # Send header with function info + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if self.in_function and not self.json_started: + self.json_started = True + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.invoke_end_token in tool_text: + # Count total parameters in the tool text + total_param_count = tool_text.count(self.parameter_prefix) + + # Only close JSON if all parameters have been processed + if self.param_count >= total_param_count: + # Close JSON + self.json_closed = True + + # Extract complete tool call + # Find the invoke content + invoke_start = tool_text.find(self.invoke_start_prefix) + len( + self.invoke_start_prefix + ) + invoke_content_end = tool_text.find( + self.invoke_end_token, invoke_start + ) + if invoke_content_end != -1: + invoke_content = tool_text[invoke_start:invoke_content_end] + # Parse to get the complete arguments + try: + parsed_tool = self._parse_single_invoke( + invoke_content, + self.streaming_request.tools + if self.streaming_request + else None, + ) + if parsed_tool and self.current_tool_index < len( + self.prev_tool_call_arr + ): + # Update existing entry in prev_tool_call_arr + args = parsed_tool.function.arguments + self.prev_tool_call_arr[self.current_tool_index][ + "arguments" + ] = args + except Exception: + pass # Ignore parsing errors during streaming + + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ] + ) + + # Reset state for next tool + self.json_closed = True + self.in_function = False + self.accumulated_params = {} + + logger.debug("[M2_STREAMING] Tool call completed") + + return result + else: + # Don't close JSON yet, continue processing parameters + return None + + # Look for parameters + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + # Check if we should start a new parameter + if ( + not self.in_param + and self.param_count < len(param_starts) + and len(param_starts) > self.param_count + ): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + param_name_raw = remaining[:name_end] + self.current_param_name = self._extract_name(param_name_raw) + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.invoke_end_token) + + if next_param_idx != -1 and ( + func_end_idx == -1 or next_param_idx < func_end_idx + ): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.invoke_end_token in tool_text: + # Tool call and parameter is complete + param_end_idx = len(value_text) + else: + # Still streaming, wait for more content + return None + + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Store raw value for later processing + self.accumulated_params[self.current_param_name] = param_value + + # Get parameter configuration for type conversion + param_config = {} + if self.streaming_request and self.streaming_request.tools: + for tool in self.streaming_request.tools: + if ( + hasattr(tool, "function") + and tool.function.name == self.current_function_name + and hasattr(tool.function, "parameters") + ): + params = tool.function.parameters + if ( + isinstance(params, dict) + and "properties" in params + ): + param_config = params["properties"] + break + + # Get parameter type + param_type = "string" + if ( + self.current_param_name in param_config + and isinstance(param_config[self.current_param_name], dict) + and "type" in param_config[self.current_param_name] + ): + param_type = param_config[self.current_param_name]["type"] + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, param_type + ) + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps( + converted_value, ensure_ascii=False + ) + + if self.param_count == 0: + json_fragment = ( + f'"{self.current_param_name}": {serialized_value}' + ) + else: + json_fragment = ( + f', "{self.current_param_name}": {serialized_value}' + ) + + self.param_count += 1 + + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments=json_fragment), + ) + ] + ) + + return None diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py new file mode 100644 index 0000000000..d122adfafb --- /dev/null +++ b/vllm/model_executor/models/minimax_m2.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniMaxM2 model.""" + +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class MiniMaxM2MoE(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_local_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_local_experts}." + ) + self.use_routing_bias = getattr(config, "use_routing_bias", False) + if self.use_routing_bias: + self.e_score_correction_bias = nn.Parameter( + torch.empty(config.num_local_experts, dtype=torch.float32) + ) + self.e_score_correction_bias.weight_loader = ( + MiniMaxM2MoE.ebias_weight_loader + ) + else: + self.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + scoring_func=config.scoring_func, + use_grouped_topk=True, + num_expert_group=1, + topk_group=1, + e_score_correction_bias=self.e_score_correction_bias, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=False, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=False, + params_dtype=torch.float32, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + @staticmethod + def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight.to(torch.float32)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states.to(torch.float32)) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + final_hidden_states = final_hidden_states + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class MiniMaxM2Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rotary_dim: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + attn_window_size: int | None = None, + max_position_embeddings: int = 8192, + head_dim: int | None = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + per_layer_sliding_window=attn_window_size, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + self.q_norm = MiniMaxText01RMSNormTP( + self.head_dim * self.total_num_heads, eps=rms_norm_eps + ) + self.k_norm = MiniMaxText01RMSNormTP( + self.head_dim * self.total_num_kv_heads, eps=rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q) + k = self.k_norm(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MiniMaxM2DecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): + max_position_embeddings = max( + config.max_position_embeddings, config.max_model_len + ) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep=".")[-1]) + + # TODO: support MTP + attn_window_size = getattr(config, "attn_window_size", None) + if attn_window_size is not None: + if isinstance(attn_window_size, list): + attn_window_size = attn_window_size[layer_idx] + elif isinstance(attn_window_size, int): + attn_window_size = attn_window_size + else: + raise ValueError(f"Invalid attn_window_size: {attn_window_size}") + attn_window_size = None if attn_window_size <= 0 else attn_window_size + + # different rope theta for full layer and swa layer + swa_rope_theta = getattr(config, "swa_rope_theta", -1) + # default to full rope theta + swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta + rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta + + self.layer_idx = layer_idx + self.self_attn = MiniMaxM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rotary_dim=config.rotary_dim, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + attn_window_size=attn_window_size, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.block_sparse_moe = MiniMaxM2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + hidden_states = self.block_sparse_moe(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class MiniMaxM2Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=None, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniMaxM2DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer : self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = self.get_expert_mapping() + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class MiniMaxM2ForCausalLM(nn.Module, SupportsPP): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + if hasattr(vllm_config.model_config, "max_model_len"): + self.config.max_model_len = vllm_config.model_config.max_model_len + self.model = MiniMaxM2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=None + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def get_spec_layer_idx_from_weight_name( + config: PretrainedConfig, weight_name: str +) -> int | None: + if hasattr(config, "num_mtp_modules") and (config.num_mtp_modules > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_mtp_modules): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 81d4a6bc5f..e8212ef6d7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -131,6 +131,7 @@ _TEXT_GENERATION_MODELS = { "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), + "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index ecee1af439..3d666882ef 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -11,6 +11,7 @@ from .gptoss_reasoning_parser import GptOssReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .identity_reasoning_parser import IdentityReasoningParser +from .minimax_m2_reasoning_parser import MiniMaxM2ReasoningParser from .mistral_reasoning_parser import MistralReasoningParser from .olmo3_reasoning_parser import Olmo3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser @@ -34,4 +35,5 @@ __all__ = [ "Step3ReasoningParser", "GptOssReasoningParser", "SeedOSSReasoningParser", + "MiniMaxM2ReasoningParser", ] diff --git a/vllm/reasoning/minimax_m2_reasoning_parser.py b/vllm/reasoning/minimax_m2_reasoning_parser.py new file mode 100644 index 0000000000..0d4f6cc270 --- /dev/null +++ b/vllm/reasoning/minimax_m2_reasoning_parser.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, +) +from vllm.logger import init_logger +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("minimax_m2") +class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for MiniMax M2 model. + """ + + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "" + + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "" + + +@ReasoningParserManager.register_module("minimax_m2_append_think") +class MiniMaxM2AppendThinkReasoningParser(ReasoningParser): + """ + Reasoning parser for MiniMax M2 model. + """ + + def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.end_token_id = self.vocab.get("") + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + end_token_id = self.end_token_id + return any(input_id == end_token_id for input_id in reversed(input_ids)) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return input_ids + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + if len(previous_token_ids) == 0: + delta_text = "" + delta_text + return DeltaMessage(content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest | ResponsesRequest + ) -> tuple[str | None, str | None]: + return None, "" + model_output