diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c16bdeeecd..13ad3af97d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -164,9 +164,7 @@ repos: name: Validate configuration has default values and that each field has a docstring entry: python tools/validate_config.py language: python - types: [python] - pass_filenames: true - files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py + additional_dependencies: [regex] # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/tools/validate_config.py b/tools/validate_config.py index 8b1e955c65..f6439fa9ad 100644 --- a/tools/validate_config.py +++ b/tools/validate_config.py @@ -9,6 +9,8 @@ import ast import inspect import sys +import regex as re + def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: """ @@ -88,11 +90,12 @@ def validate_class(class_node: ast.ClassDef): for stmt in class_node.body: # A field is defined as a class variable that has a type annotation. if isinstance(stmt, ast.AnnAssign): - # Skip ClassVar + # Skip ClassVar and InitVar # see https://docs.python.org/3/library/dataclasses.html#class-variables - if isinstance(stmt.annotation, ast.Subscript) and isinstance( - stmt.annotation.value, - ast.Name) and stmt.annotation.value.id == "ClassVar": + # and https://docs.python.org/3/library/dataclasses.html#init-only-variables + if (isinstance(stmt.annotation, ast.Subscript) + and isinstance(stmt.annotation.value, ast.Name) + and stmt.annotation.value.id in {"ClassVar", "InitVar"}): continue if isinstance(stmt.target, ast.Name): @@ -132,7 +135,7 @@ def validate_ast(tree: ast.stmt): def validate_file(file_path: str): try: - print(f"validating {file_path} config dataclasses ", end="") + print(f"Validating {file_path} config dataclasses ", end="") with open(file_path, encoding="utf-8") as f: source = f.read() @@ -140,7 +143,7 @@ def validate_file(file_path: str): validate_ast(tree) except ValueError as e: print(e) - SystemExit(2) + raise SystemExit(1) from e else: print("✅") @@ -151,7 +154,13 @@ def fail(message: str, node: ast.stmt): def main(): for filename in sys.argv[1:]: - validate_file(filename) + # Only run for Python files in vllm/ or tests/ + if not re.match(r"^(vllm|tests)/.*\.py$", filename): + continue + # Only run if the file contains @config + with open(filename, encoding="utf-8") as f: + if "@config" in f.read(): + validate_file(filename) if __name__ == "__main__": diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 69ab5712d4..25daca00c0 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -450,6 +450,8 @@ class ModelConfig: # Multimodal config and init vars multimodal_config: Optional[MultiModalConfig] = None + """Configuration for multimodal model. If `None`, this will be inferred + from the architecture of `self.model`.""" limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None