diff --git a/tests/test_config.py b/tests/test_config.py
index f3d40a7d80..bba2fbec3d 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -292,6 +292,37 @@ def test_rope_customization():
     assert longchat_model_config.max_model_len == 4096
 
 
+def test_nested_hf_overrides():
+    """Test that nested hf_overrides work correctly."""
+    # Test with a model that has text_config
+    model_config = ModelConfig(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        hf_overrides={
+            "text_config": {
+                "hidden_size": 1024,
+            },
+        },
+    )
+    assert model_config.hf_config.text_config.hidden_size == 1024
+
+    # Test with deeply nested overrides
+    model_config = ModelConfig(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        hf_overrides={
+            "text_config": {
+                "hidden_size": 2048,
+                "num_attention_heads": 16,
+            },
+            "vision_config": {
+                "hidden_size": 512,
+            },
+        },
+    )
+    assert model_config.hf_config.text_config.hidden_size == 2048
+    assert model_config.hf_config.text_config.num_attention_heads == 16
+    assert model_config.hf_config.vision_config.hidden_size == 512
+
+
 @pytest.mark.skipif(
     current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
 )
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 146ace9782..d0c027e476 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -367,6 +367,51 @@ class ModelConfig:
         assert_hashable(str_factors)
         return hashlib.sha256(str(factors).encode()).hexdigest()
 
+    def _update_nested(
+        self,
+        target: Union["PretrainedConfig", dict[str, Any]],
+        updates: dict[str, Any],
+    ) -> None:
+        """Recursively updates a config or dict with nested updates."""
+        for key, value in updates.items():
+            if isinstance(value, dict):
+                # Get the nested target
+                if isinstance(target, dict):
+                    nested_target = target.get(key)
+                else:
+                    nested_target = getattr(target, key, None)
+
+                # If nested target exists and can be updated recursively
+                if nested_target is not None and (
+                    isinstance(nested_target, dict)
+                    or hasattr(nested_target, "__dict__")
+                ):
+                    self._update_nested(nested_target, value)
+                    continue
+
+            # Set the value (base case)
+            if isinstance(target, dict):
+                target[key] = value
+            else:
+                setattr(target, key, value)
+
+    def _apply_dict_overrides(
+        self,
+        config: "PretrainedConfig",
+        overrides: dict[str, Any],
+    ) -> None:
+        """Apply dict overrides, handling both nested configs and dict values."""
+        from transformers import PretrainedConfig
+
+        for key, value in overrides.items():
+            attr = getattr(config, key, None)
+            if attr is not None and isinstance(attr, PretrainedConfig):
+                # It's a nested config - recursively update it
+                self._update_nested(attr, value)
+            else:
+                # It's a dict-valued parameter - set it directly
+                setattr(config, key, value)
+
     def __post_init__(
         self,
         # Multimodal config init vars
@@ -419,8 +464,17 @@ class ModelConfig:
         if callable(self.hf_overrides):
             hf_overrides_kw = {}
             hf_overrides_fn = self.hf_overrides
+            dict_overrides: dict[str, Any] = {}
         else:
-            hf_overrides_kw = self.hf_overrides
+            # Separate dict overrides from flat ones
+            # We'll determine how to apply dict overrides after loading the config
+            hf_overrides_kw = {}
+            dict_overrides = {}
+            for key, value in self.hf_overrides.items():
+                if isinstance(value, dict):
+                    dict_overrides[key] = value
+                else:
+                    hf_overrides_kw[key] = value
             hf_overrides_fn = None
 
         if self.rope_scaling:
@@ -478,6 +532,8 @@ class ModelConfig:
         )
 
         self.hf_config = hf_config
+        if dict_overrides:
+            self._apply_dict_overrides(hf_config, dict_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None