Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
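The diff below reformats the repository's Python sources with ruff's formatter and import sorter in place of yapf + isort: single quotes become double quotes, hanging indents are replaced by magic-trailing-comma style, and line-continuation backslashes become parenthesized expressions. The tooling configuration that drives this is not part of the excerpt, so the following is only a minimal sketch, assuming a pyproject.toml-based setup (the keys are real ruff options, but the values are illustrative, not the PR's actual settings):

    # pyproject.toml (hypothetical values)
    [tool.ruff]
    line-length = 88          # ruff's default, matching Black

    [tool.ruff.lint]
    extend-select = ["I"]     # "I" enables the isort-compatible import-sorting rules

    [tool.ruff.format]
    # ruff format replaces yapf; its default quote-style is "double",
    # which is why the diff converts 'single' quotes to "double" quotes.

    # Typical invocation:
    #   ruff check --fix .    # lint + import sorting
    #   ruff format .         # code formatting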
@@ -17,12 +17,16 @@ REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"

# If you need to add items to whitelist, do it here.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
"vllm.env_override",
})
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
".version",
})
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset(
{
"vllm.env_override",
}
)
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset(
{
".version",
}
)


def _is_internal(name: str | None, *, level: int = 0) -> bool:
@@ -34,8 +38,7 @@ def _is_internal(name: str | None, *, level: int = 0) -> bool:


def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n",
file=sys.stderr)
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)
@@ -48,7 +51,6 @@ def main() -> None:
violations: list[tuple[int, str]] = []

class Visitor(ast.NodeVisitor):

def __init__(self) -> None:
super().__init__()
self._in_type_checking = False
@@ -56,10 +58,10 @@ def main() -> None:
def visit_If(self, node: ast.If) -> None:
guard_is_type_checking = False
test = node.test
if isinstance(test, ast.Attribute) and isinstance(
test.value, ast.Name):
guard_is_type_checking = (test.value.id == "typing"
and test.attr == "TYPE_CHECKING")
if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name):
guard_is_type_checking = (
test.value.id == "typing" and test.attr == "TYPE_CHECKING"
)
elif isinstance(test, ast.Name):
guard_is_type_checking = test.id == "TYPE_CHECKING"

@@ -79,24 +81,28 @@ def main() -> None:
return
for alias in node.names:
module_name = alias.name
if _is_internal(
module_name) and module_name not in ALLOWED_IMPORTS:
violations.append((
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
))
if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS:
violations.append(
(
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
)
)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self._in_type_checking:
return
module_as_written = ("." * node.level) + (node.module or "")
if _is_internal(
node.module, level=node.level
) and module_as_written not in ALLOWED_FROM_MODULES:
violations.append((
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
))
if (
_is_internal(node.module, level=node.level)
and module_as_written not in ALLOWED_FROM_MODULES
):
violations.append(
(
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
)
)

Visitor().visit(tree)

@@ -7,6 +7,7 @@ from enum import Enum

class SPDXStatus(Enum):
"""SPDX header status enumeration"""

EMPTY = "empty" # empty __init__.py
COMPLETE = "complete"
MISSING_LICENSE = "missing_license" # Only has copyright line
@@ -16,7 +17,8 @@ class SPDXStatus(Enum):

FULL_SPDX_HEADER = (
"# SPDX-License-Identifier: Apache-2.0\n"
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project")
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"
)

LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0"
COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501
@@ -123,8 +125,9 @@ def main():
continue

# Collect all files that need fixing
all_files_to_fix = (files_missing_both + files_missing_copyright +
files_missing_license)
all_files_to_fix = (
files_missing_both + files_missing_copyright + files_missing_license
)
if all_files_to_fix:
print("The following files are missing the SPDX header:")
if files_missing_both:

@@ -23,8 +23,7 @@ def is_allowed_file(current_file: str) -> bool:

def is_forbidden_import(line: str) -> bool:
stripped = line.strip()
return bool(
FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES


def parse_diff(diff: str) -> list[str]:
@@ -42,24 +41,24 @@ def parse_diff(diff: str) -> list[str]:
elif line.startswith("@@"):
match = re.search(r"\+(\d+)", line)
if match:
current_lineno = int(
match.group(1)) - 1 # next "+ line" is here
current_lineno = int(match.group(1)) - 1 # next "+ line" is here
elif line.startswith("+") and not line.startswith("++"):
current_lineno += 1
code_line = line[1:]
if is_forbidden_import(code_line):
violations.append(
f"{current_file}:{current_lineno}: {code_line.strip()}")
f"{current_file}:{current_lineno}: {code_line.strip()}"
)
return violations


def get_diff(diff_type: str) -> str:
if diff_type == "staged":
return subprocess.check_output(
["git", "diff", "--cached", "--unified=0"], text=True)
["git", "diff", "--cached", "--unified=0"], text=True
)
elif diff_type == "unstaged":
return subprocess.check_output(["git", "diff", "--unified=0"],
text=True)
return subprocess.check_output(["git", "diff", "--unified=0"], text=True)
else:
raise ValueError(f"Unknown diff_type: {diff_type}")

@@ -75,8 +74,10 @@ def main():
print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)

if all_violations:
print("❌ Forbidden direct `import triton` detected."
" ➤ Use `from vllm.triton_utils import triton` instead.\n")
print(
"❌ Forbidden direct `import triton` detected."
" ➤ Use `from vllm.triton_utils import triton` instead.\n"
)
for v in all_violations:
print(f"❌ {v}")
return 1

@@ -7,24 +7,23 @@ from pathlib import Path

import regex as re

FORBIDDEN_PATTERNS = re.compile(
r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)')
FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)")
ALLOWED_PATTERNS = [
re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'),
re.compile(r'^\s*import\s+regex\s*$'),
re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"),
re.compile(r"^\s*import\s+regex\s*$"),
]


def get_staged_python_files() -> list[str]:
try:
result = subprocess.run(
['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'],
["git", "diff", "--cached", "--name-only", "--diff-filter=AM"],
capture_output=True,
text=True,
check=True)
files = result.stdout.strip().split(
'\n') if result.stdout.strip() else []
return [f for f in files if f.endswith('.py')]
check=True,
)
files = result.stdout.strip().split("\n") if result.stdout.strip() else []
return [f for f in files if f.endswith(".py")]
except subprocess.CalledProcessError:
return []

@@ -33,13 +32,14 @@ def is_forbidden_import(line: str) -> bool:
line = line.strip()
return bool(
FORBIDDEN_PATTERNS.match(line)
and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS))
and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)
)


def check_file(filepath: str) -> list[tuple[int, str]]:
violations = []
try:
with open(filepath, encoding='utf-8') as f:
with open(filepath, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if is_forbidden_import(line):
violations.append((line_num, line.strip()))
@@ -72,9 +72,7 @@ def main() -> int:
if total_violations > 0:
print(f"\n💡 Found {total_violations} violation(s).")
print("❌ Please replace 'import re' with 'import regex as re'")
print(
" Also replace 'from re import ...' with 'from regex import ...'"
) # noqa: E501
print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501
print("✅ Allowed imports:")
print(" - import regex as re")
print(" - import regex") # noqa: E501

@@ -12,8 +12,7 @@ try:
# most reliable source of truth for vLLM's build.
from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
print("Warning: PyTorch not found. "
"Falling back to CUDA_HOME environment variable.")
print("Warning: PyTorch not found. Falling back to CUDA_HOME environment variable.")
CUDA_HOME = os.environ.get("CUDA_HOME")

@@ -27,8 +26,7 @@ def get_cpu_cores():
return multiprocessing.cpu_count()


def generate_presets(output_path="CMakeUserPresets.json",
force_overwrite=False):
def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False):
"""Generates the CMakeUserPresets.json file."""

print("Attempting to detect your system configuration...")
@@ -39,8 +37,7 @@ def generate_presets(output_path="CMakeUserPresets.json",
prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc")
if os.path.exists(prospective_path):
nvcc_path = prospective_path
print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: "
f"{nvcc_path}")
print(f"Found nvcc via torch.utils.cpp_extension.CUDA_HOME: {nvcc_path}")

if not nvcc_path:
nvcc_path = which("nvcc")
@@ -50,7 +47,8 @@ def generate_presets(output_path="CMakeUserPresets.json",
if not nvcc_path:
nvcc_path_input = input(
"Could not automatically find 'nvcc'. Please provide the full "
"path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ")
"path to nvcc (e.g., /usr/local/cuda/bin/nvcc): "
)
nvcc_path = nvcc_path_input.strip()
print(f"Using NVCC path: {nvcc_path}")

@@ -63,12 +61,13 @@ def generate_presets(output_path="CMakeUserPresets.json",
"Could not automatically find Python executable. Please provide "
"the full path to your Python executable for vLLM development "
"(typically from your virtual environment, e.g., "
"/home/user/venvs/vllm/bin/python): ")
"/home/user/venvs/vllm/bin/python): "
)
python_executable = input(python_executable_prompt).strip()
if not python_executable:
raise ValueError(
"Could not determine Python executable. Please provide it "
"manually.")
"Could not determine Python executable. Please provide it manually."
)

print(f"Using Python executable: {python_executable}")

@@ -76,20 +75,23 @@ def generate_presets(output_path="CMakeUserPresets.json",
cpu_cores = get_cpu_cores()
nvcc_threads = min(4, cpu_cores)
cmake_jobs = max(1, cpu_cores // nvcc_threads)
print(f"Detected {cpu_cores} CPU cores. "
f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.")
print(
f"Detected {cpu_cores} CPU cores. "
f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}."
)

# Get vLLM project root (assuming this script is in vllm/tools/)
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), ".."))
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
print(f"VLLM project root detected as: {project_root}")

# Ensure python_executable path is absolute or resolvable
if not os.path.isabs(python_executable) and which(python_executable):
python_executable = os.path.abspath(which(python_executable))
elif not os.path.isabs(python_executable):
print(f"Warning: Python executable '{python_executable}' is not an "
"absolute path and not found in PATH. CMake might not find it.")
print(
f"Warning: Python executable '{python_executable}' is not an "
"absolute path and not found in PATH. CMake might not find it."
)

cache_variables = {
"CMAKE_CUDA_COMPILER": nvcc_path,
@@ -122,24 +124,20 @@ def generate_presets(output_path="CMakeUserPresets.json",
configure_preset["generator"] = "Ninja"
cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}"
else:
print("Ninja not found, using default generator. "
"Build may be slower.")
print("Ninja not found, using default generator. Build may be slower.")

presets = {
"version":
6,
"version": 6,
# Keep in sync with CMakeLists.txt and requirements/build.txt
"cmakeMinimumRequired": {
"major": 3,
"minor": 26,
"patch": 1
},
"cmakeMinimumRequired": {"major": 3, "minor": 26, "patch": 1},
"configurePresets": [configure_preset],
"buildPresets": [{
"name": "release",
"configurePreset": "release",
"jobs": cmake_jobs,
}],
"buildPresets": [
{
"name": "release",
"configurePreset": "release",
"jobs": cmake_jobs,
}
],
}

output_file_path = os.path.join(project_root, output_path)
@@ -148,10 +146,12 @@ def generate_presets(output_path="CMakeUserPresets.json",
if force_overwrite:
print(f"Overwriting existing file '{output_file_path}'")
else:
overwrite = input(
f"'{output_file_path}' already exists. Overwrite? (y/N): "
).strip().lower()
if overwrite != 'y':
overwrite = (
input(f"'{output_file_path}' already exists. Overwrite? (y/N): ")
.strip()
.lower()
)
if overwrite != "y":
print("Generation cancelled.")
return

@@ -160,11 +160,9 @@ def generate_presets(output_path="CMakeUserPresets.json",
json.dump(presets, f, indent=4)
print(f"Successfully generated '{output_file_path}'")
print("\nTo use this preset:")
print(
f"1. Ensure you are in the vLLM root directory: cd {project_root}")
print(f"1. Ensure you are in the vLLM root directory: cd {project_root}")
print("2. Initialize CMake: cmake --preset release")
print("3. Build+install: cmake --build --preset release "
"--target install")
print("3. Build+install: cmake --build --preset release --target install")

except OSError as e:
print(f"Error writing file: {e}")
@@ -175,7 +173,7 @@ if __name__ == "__main__":
parser.add_argument(
"--force-overwrite",
action="store_true",
help="Force overwrite existing CMakeUserPresets.json without prompting"
help="Force overwrite existing CMakeUserPresets.json without prompting",
)

args = parser.parse_args()

@@ -17,44 +17,48 @@ import regex as re
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
# pickle
'vllm/v1/serial_utils.py',
'vllm/v1/executor/multiproc_executor.py',
'vllm/multimodal/hasher.py',
'vllm/transformers_utils/config.py',
'vllm/model_executor/models/registry.py',
'tests/utils_/test_utils.py',
'tests/tokenization/test_cached_tokenizer.py',
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/distributed/device_communicators/shm_object_storage.py',
'benchmarks/kernels/graph_machete_bench.py',
'benchmarks/kernels/benchmark_lora.py',
'benchmarks/kernels/benchmark_machete.py',
'benchmarks/fused_kernels/layernorm_rms_benchmarks.py',
'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py',
'benchmarks/cutlass_benchmarks/sparse_benchmarks.py',
"vllm/v1/serial_utils.py",
"vllm/v1/executor/multiproc_executor.py",
"vllm/multimodal/hasher.py",
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
"tests/utils_/test_utils.py",
"tests/tokenization/test_cached_tokenizer.py",
"vllm/distributed/utils.py",
"vllm/distributed/parallel_state.py",
"vllm/distributed/device_communicators/all_reduce_utils.py",
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py",
"benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
# cloudpickle
'vllm/executor/mp_distributed_executor.py',
'vllm/executor/ray_distributed_executor.py',
'vllm/entrypoints/llm.py',
'tests/utils.py',
"vllm/executor/mp_distributed_executor.py",
"vllm/executor/ray_distributed_executor.py",
"vllm/entrypoints/llm.py",
"tests/utils.py",
# pickle and cloudpickle
'vllm/utils/__init__.py',
"vllm/utils/__init__.py",
}

PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)")
PICKLE_RE = re.compile(
r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)"
)


def scan_file(path: str) -> int:
with open(path, encoding='utf-8') as f:
with open(path, encoding="utf-8") as f:
for i, line in enumerate(f, 1):
if PICKLE_RE.match(line):
print(f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import")
print(
f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import"
)
return 1
return 0

@@ -92,13 +96,13 @@ def test_regex():
for i, (line, should_match) in enumerate(test_cases):
result = bool(PICKLE_RE.match(line))
assert result == should_match, (
f"Test case {i} failed: '{line}' "
f"(expected {should_match}, got {result})")
f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
)
print("All regex tests passed.")


if __name__ == '__main__':
if '--test-regex' in sys.argv:
if __name__ == "__main__":
if "--test-regex" in sys.argv:
test_regex()
else:
sys.exit(main())

@@ -94,11 +94,15 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return file_groups


def mypy(targets: list[str], python_version: Optional[str],
follow_imports: Optional[str], file_group: str) -> int:
def mypy(
targets: list[str],
python_version: Optional[str],
follow_imports: Optional[str],
file_group: str,
) -> int:
"""
Run mypy on the given targets.

Args:
targets: List of files or directories to check.
python_version: Python version to use (e.g., "3.10") or None to use
@@ -131,8 +135,9 @@ def main():
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
returncode |= mypy(changed_files, python_version, follow_imports,
file_group)
returncode |= mypy(
changed_files, python_version, follow_imports, file_group
)
return returncode


@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
"""

import argparse
import logging
import os
@@ -16,13 +17,13 @@ logger = logging.getLogger(__name__)

# helper data class for annotating kernels
def load_engine_model():
""" returns engine_model built from all json files in the current dir """
"""returns engine_model built from all json files in the current dir"""
import glob
import json

engine_model = {}

json_files = glob.glob(
os.path.join(os.path.dirname(__file__) or ".", "*.json"))
json_files = glob.glob(os.path.join(os.path.dirname(__file__) or ".", "*.json"))
for fname in json_files:
with open(fname, encoding="utf-8") as f:
engine_model.update(json.load(f))
@@ -30,54 +31,54 @@ def load_engine_model():


class GPUTrace2Graph:
"""
Parses output of nsys report, generates csv and bar chart output
"""
Parses output of nsys report, generates csv and bar chart output
"""

def __init__(self):
import pandas as pd # avoid importing till needed

self.pd = pd
self.pd.options.mode.copy_on_write = True

# helper functions for generating trace->summary csvs
def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):
logger.info('loading %s', in_file)
logger.info("loading %s", in_file)
df = self.pd.read_csv(
in_file,
usecols=['Start (ns)', 'Duration (ns)', 'Device', 'Strm', 'Name'])
df['End (ns)'] = df['Start (ns)'] + df['Duration (ns)']
in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"]
)
df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"]
df = self.sum_non_overlapping_intervals(df)
# get ready to print table with elapsed times per kernel
df['Instances'] = 1
df_sum = df.groupby('Name', as_index=False).agg({
'Elapsed Time (ns)': 'sum',
'Duration (ns)': 'sum',
'Instances': 'size'
})
df["Instances"] = 1
df_sum = df.groupby("Name", as_index=False).agg(
{"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"}
)

# generate csv
df_sum['Total Time (sec)'] = df_sum['Duration (ns)'] / 1e9
df_sum['Elapsed Time (sec)'] = df_sum['Elapsed Time (ns)'] / 1e9
df_sum = df_sum.sort_values(by='Elapsed Time (sec)', ascending=False)
df_sum[['Elapsed Time (sec)', 'Total Time (sec)', 'Instances',
'Name']].to_csv(out_file, index=False)
df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9
df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9
df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False)
df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv(
out_file, index=False
)

def sum_non_overlapping_intervals(self, df):
"""
returns new sorted df with Elapsed Time (ns) column using
vectorized operations
"""
returns new sorted df with Elapsed Time (ns) column using
vectorized operations
"""
logger.info("sorting %s trace records by start time", str(df.shape))

# Sort by start time and reset index
df = df.sort_values(by='Start (ns)').reset_index(drop=True)
df = df.sort_values(by="Start (ns)").reset_index(drop=True)

# Initialize elapsed time as duration
df['Elapsed Time (ns)'] = df['Duration (ns)']
df["Elapsed Time (ns)"] = df["Duration (ns)"]

# Get numpy arrays for faster operations
starts = df['Start (ns)'].values
ends = df['End (ns)'].values
starts = df["Start (ns)"].values
ends = df["End (ns)"].values

# Keep track of current interval end
current_end = ends[0]
@@ -85,16 +86,17 @@ class GPUTrace2Graph:
# Update current_end for overlapping intervals
for i in range(1, len(df)):
if i % display_units == 0:
print(f'processing trace: {int(i/len(df) * 100)} %', end="\r")
print(f"processing trace: {int(i / len(df) * 100)} %", end="\r")
if starts[i] <= current_end:
if ends[i] > current_end:
# Partial overlap
df.iloc[i, df.columns.get_loc('Elapsed Time (ns)'
)] = ends[i] - current_end
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = (
ends[i] - current_end
)
current_end = ends[i]
else:
# Complete overlap
df.iloc[i, df.columns.get_loc('Elapsed Time (ns)')] = 0
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0
else:
# No overlap
current_end = ends[i]
@@ -103,147 +105,167 @@ class GPUTrace2Graph:

# functions for generating html files
def make_html(self, df, output_dir, title):
""" make html graph from df """
"""make html graph from df"""
import plotly.express as px

if df.empty:
return
output_name = output_dir + '/result'
output_name = output_dir + "/result"
if not title:
title = 'Model_Engine'
x = 'Model_Engine'
y = 'Elapsed Time (sec)'
color = 'Category'
title = "Model_Engine"
x = "Model_Engine"
y = "Elapsed Time (sec)"
color = "Category"
""" generate kernel mapping table """
# Sort Model_Engine categories by last field after underscore
df['Model_Engine'] = self.pd.Categorical(
df['Model_Engine'],
sorted(df['Model_Engine'].unique(),
key=lambda x: x.split('_')[-1]))
df[['Model_Engine', color, 'Instances', 'Name',
y]].sort_values(by=color).to_csv(f'{output_name}.csv', index=False)
graph = px.histogram(df.round(2),
x=x,
y=y,
title=(f'{y} for {title}'),
color=color,
text_auto=True)
df["Model_Engine"] = self.pd.Categorical(
df["Model_Engine"],
sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]),
)
df[["Model_Engine", color, "Instances", "Name", y]].sort_values(
by=color
).to_csv(f"{output_name}.csv", index=False)
graph = px.histogram(
df.round(2),
x=x,
y=y,
title=(f"{y} for {title}"),
color=color,
text_auto=True,
)
# wrap x axis labels
graph.update_xaxes(automargin=True)
graph.write_html(f'{output_name}.html')
graph.write_html(f"{output_name}.html")
"""
Generate data table with columns per Model_Engine into result.html
"""
pivot_df = df.pivot_table(values='Elapsed Time (sec)',
index='Category',
columns='Model_Engine',
aggfunc='sum',
observed=False).round(2)
pivot_df = df.pivot_table(
values="Elapsed Time (sec)",
index="Category",
columns="Model_Engine",
aggfunc="sum",
observed=False,
).round(2)
# Add sum row at bottom
pivot_df.loc['total_elapsed_sec'] = pivot_df.sum()
pivot_df.fillna('').to_html('temp.html')
with (open(f'{output_name}.html', 'a', encoding='utf-8') as
outfile, open('temp.html', encoding='utf-8') as infile):
pivot_df.loc["total_elapsed_sec"] = pivot_df.sum()
pivot_df.fillna("").to_html("temp.html")
with (
open(f"{output_name}.html", "a", encoding="utf-8") as outfile,
open("temp.html", encoding="utf-8") as infile,
):
outfile.write(infile.read())
os.remove('temp.html')
os.remove("temp.html")

print(f'Finished generating: \n'
f' {output_name}.html for stack bar chart \n'
f' {output_name}.csv for Kernel-Category mapping')
print(
f"Finished generating: \n"
f" {output_name}.html for stack bar chart \n"
f" {output_name}.csv for Kernel-Category mapping"
)

def anno_gpu_kernname(self, df, mapping):
""" add "Category" column """
"""add "Category" column"""

def anno_gpu_kernname_helper(name):
for kern_name, val in mapping.items():
if re.search(kern_name, name):
return val

df['Category'] = df['Name'].apply(anno_gpu_kernname_helper)
df["Category"] = df["Name"].apply(anno_gpu_kernname_helper)

def make_nongpu_row(self, df, nongpu_sec):
""" this will append non-gpu time entry at end of df """
"""this will append non-gpu time entry at end of df"""
nongpu_row = self.pd.DataFrame([df.iloc[-1]])
nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)'
nongpu_row['Instances'] = 1
nongpu_row['Elapsed Time (sec)'] = nongpu_sec
return (nongpu_row)
nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)"
nongpu_row["Instances"] = 1
nongpu_row["Elapsed Time (sec)"] = nongpu_sec
return nongpu_row

def is_valid_file(self, base_file):
""" asserts if base_file is non-existent or is empty """
assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, \
f"{base_file} doesn't exist or is empty"
"""asserts if base_file is non-existent or is empty"""
assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, (
f"{base_file} doesn't exist or is empty"
)

def should_gen_file(self, new_file, base_file):
""" figure out if new file should be generated from base_file """
"""figure out if new file should be generated from base_file"""
self.is_valid_file(base_file)
if (os.path.exists(new_file)
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
and (os.path.getsize(base_file) > 0)):
logger.info('reusing %s', new_file)
if (
os.path.exists(new_file)
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
and (os.path.getsize(base_file) > 0)
):
logger.info("reusing %s", new_file)
return False
else:
logger.info('generating %s', new_file)
logger.info("generating %s", new_file)
return True

def gen_sum_file(self, file, nsys_cmd):
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
"""
import subprocess

file_dir = os.path.dirname(file)
file_name = os.path.basename(file)

if not file_dir:
file_dir = '.'
file_dir = "."
# Walk through trace and get the total non-overlapped time
nsys_stats_file = f'{file_dir}/{file_name}_cuda_gpu_trace.csv'
sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv'
nsys_stats_file = f"{file_dir}/{file_name}_cuda_gpu_trace.csv"
sum_file = f"{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv"
if self.should_gen_file(nsys_stats_file, file):
cmd = [
nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o',
f'{file_dir}/{file_name}'
nsys_cmd,
"stats",
"-r",
"cuda_gpu_trace",
file,
"-o",
f"{file_dir}/{file_name}",
]
cmd_str = ' '.join(cmd)
logger.info('+ %s', cmd_str)
cmd_str = " ".join(cmd)
logger.info("+ %s", cmd_str)
# estimate time based on calibrated 240M/min
file_size_mb = os.path.getsize(file) / 1e6
logger.info(
'nsys stats for %.2f MB file expected to take %.2f min',
file_size_mb, file_size_mb / 240)
"nsys stats for %.2f MB file expected to take %.2f min",
file_size_mb,
file_size_mb / 240,
)
try:
subprocess.run(cmd, check=True)
except Exception:
logger.error("%s failed; Use --nsys_cmd to specify nsys path",
cmd_str)
logger.error("%s failed; Use --nsys_cmd to specify nsys path", cmd_str)
exit(1)
logger.info('generating non-overalapped sum %s', sum_file)
logger.info("generating non-overalapped sum %s", sum_file)
self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
self.is_valid_file(sum_file)
logger.info('Finished generating %s', sum_file)
logger.info("Finished generating %s", sum_file)
return sum_file

def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
""" generates graph and csv file from in_file into out_dir """
"""generates graph and csv file from in_file into out_dir"""
# Initialize an empty DataFrame to store combined data
combined_df = self.pd.DataFrame()
for idx, (file, engine, model, total_sec) in enumerate(in_file):
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = '.'
file_dir = "."
sum_file = self.gen_sum_file(file, nsys_cmd)
# read kernel summary file
df = self.pd.read_csv(sum_file)
# annotate kernel to their categories
assert engine_model.get(engine), f'engine {engine} unknown'
assert engine_model[engine].get(model), f'model {model} unknown'
assert engine_model.get(engine), f"engine {engine} unknown"
assert engine_model[engine].get(model), f"model {model} unknown"
# remove nsys-rep from file_name for shorter x-label
file_name = file_name.replace('.nsys-rep', '')
df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}'
file_name = file_name.replace(".nsys-rep", "")
df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}"
self.anno_gpu_kernname(df, engine_model[engine][model])
# patch in non-gpu time
gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1)
gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1)
total_sec = round(float(total_sec), 1)
if total_sec < gpu_sec:
logger.warning(
@@ -256,7 +278,7 @@ class GPUTrace2Graph:
df = self.pd.concat([df, nongpu_row], ignore_index=True)
combined_df = self.pd.concat([combined_df, df], ignore_index=True)
if out_dir is None:
out_dir = '.'
out_dir = "."
else:
os.makedirs(out_dir, exist_ok=True)
# generate html file
@@ -264,50 +286,59 @@ class GPUTrace2Graph:


def parse_tuple(s):
return tuple(s.split(','))
return tuple(s.split(","))


def main():
logging.basicConfig(format=('%(asctime)s - %(levelname)s - %(message)s'),
level=logging.INFO)
logging.basicConfig(
format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO
)
parser = argparse.ArgumentParser(
description=(
'Process nsys rep and generate kernel non-overlapped cycles. \n'
'Example:\n'
"Process nsys rep and generate kernel non-overlapped cycles. \n"
"Example:\n"
"gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n"
"d2.nsys-rep,vllm,gpt-oss,102 "
"--out_dir results/ --title \"Model=gpt-oss vLLM chart\""),
formatter_class=argparse.RawDescriptionHelpFormatter)
'--out_dir results/ --title "Model=gpt-oss vLLM chart"'
),
formatter_class=argparse.RawDescriptionHelpFormatter,
)

# load supported engine_model
engine_model_supported = load_engine_model()
# Get a string representation of supported engine/model combinations
engine_model_supported_str = ', '.join(
engine_model_supported_str = ", ".join(
f"{engine}:[{', '.join(models.keys())}]"
for engine, models in engine_model_supported.items())
for engine, models in engine_model_supported.items()
)
parser.add_argument(
'--in_file',
"--in_file",
type=parse_tuple,
nargs='+',
nargs="+",
help=(
'list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) '
'separated by space. Elapsed_nonprofiled_sec is runtime without '
'profiling used to calculate non-gpu time. Specify 0 to use '
'elapsed time from nsys-rep but that might inflate non-gpu time. '
f'Available engine:[model] are: {engine_model_supported_str} '
f'Example: --infile d1.nsys-rep,vllm,llama,100 '
'd2.nsys-rep,vllm,gpt-oss,102'),
required=True)
parser.add_argument('--out_dir', help=('output dir for result.csv/html'))
parser.add_argument('--title', help=('title for html chart'))
parser.add_argument('--nsys_cmd',
help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'),
default="nsys")
"list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) "
"separated by space. Elapsed_nonprofiled_sec is runtime without "
"profiling used to calculate non-gpu time. Specify 0 to use "
"elapsed time from nsys-rep but that might inflate non-gpu time. "
f"Available engine:[model] are: {engine_model_supported_str} "
f"Example: --infile d1.nsys-rep,vllm,llama,100 "
"d2.nsys-rep,vllm,gpt-oss,102"
),
required=True,
)
parser.add_argument("--out_dir", help=("output dir for result.csv/html"))
parser.add_argument("--title", help=("title for html chart"))
parser.add_argument(
"--nsys_cmd",
help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"),
default="nsys",
)
args = parser.parse_args()
gputrace = GPUTrace2Graph()
gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd,
engine_model_supported)
gputrace.gen_graph(
args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported
)


if __name__ == '__main__':
if __name__ == "__main__":
main()

@@ -29,48 +29,50 @@ def flatten_entries(entry_cls, profile_dict: dict):
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--json-trace",
type=str,
required=True,
help="json trace file output by "
"examples/offline_inference/profiling.py")
parser.add_argument("--phase",
type=str,
required=True,
help="The phase to print the table for. This is either"
"prefill or decode_n, where n is the decode step "
"number")
parser.add_argument("--table",
type=str,
choices=["summary", "model"],
default="summary",
help="Which table to print, the summary table or the "
"layerwise model table")
parser.add_argument(
"--json-trace",
type=str,
required=True,
help="json trace file output by examples/offline_inference/profiling.py",
)
parser.add_argument(
"--phase",
type=str,
required=True,
help="The phase to print the table for. This is either"
"prefill or decode_n, where n is the decode step "
"number",
)
parser.add_argument(
"--table",
type=str,
choices=["summary", "model"],
default="summary",
help="Which table to print, the summary table or the layerwise model table",
)

args = parser.parse_args()

with open(args.json_trace) as f:
profile_data = json.load(f)

assert args.phase in profile_data, \
(f"Cannot find phase {args.phase} in profile data. Choose one among"
f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa
assert args.phase in profile_data, (
f"Cannot find phase {args.phase} in profile data. Choose one among"
f"{[x for x in profile_data.keys() if 'prefill' in x or 'decode' in x]}"
) # noqa

if args.table == "summary":
entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
invocations=15)
SummaryStatsEntry, profile_data[args.phase]["summary_stats"]
)
column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15)
elif args.table == "model":
entries_and_depths = flatten_entries(
ModelStatsEntry, profile_data[args.phase]["model_stats"])
column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
pct_cuda_time=12,
trace=60)
ModelStatsEntry, profile_data[args.phase]["model_stats"]
)
column_widths = dict(
name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60
)

# indent entry names based on the depth
entries = []
@@ -78,7 +80,8 @@ if __name__ == "__main__":
entry.name = indent_string(
entry.name,
indent=depth,
indent_style=lambda indent: "|" + "-" * indent + " ")
indent_style=lambda indent: "|" + "-" * indent + " ",
)
entries.append(entry)

TablePrinter(type(entries[0]), column_widths).print_table(entries)

@@ -18,17 +18,18 @@ import pandas as pd
def largest_dist_from_leaf(node: dict, depth: int = 0):
if len(node["children"]) == 0:
return depth
return max([
largest_dist_from_leaf(child, depth=depth + 1)
for child in node["children"]
])
return max(
[largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]]
)


def get_entries_at_depth(depth: int,
entries_and_traces: list[tuple[Any, Any]],
node: dict,
curr_depth: int = 0,
trace=()):
def get_entries_at_depth(
depth: int,
entries_and_traces: list[tuple[Any, Any]],
node: dict,
curr_depth: int = 0,
trace=(),
):
# assert that the query is at kernel or module level
assert depth == -1 or depth == -2

@@ -40,21 +41,18 @@ def get_entries_at_depth(depth: int,
if largest_dist_from_leaf(node) == (abs(depth) - 1):
entries_and_traces.append((node["entry"], trace))

trace = (node["entry"]["name"], ) + trace
trace = (node["entry"]["name"],) + trace
for child in node["children"]:
get_entries_at_depth(depth,
entries_and_traces,
child,
curr_depth=curr_depth + 1,
trace=trace)
get_entries_at_depth(
depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace
)


def fold_nodes(root: dict, nodes_to_fold: list[str]):

stack: list[dict] = [root]
while len(stack) != 0:
node = stack.pop()
if node['entry']['name'] in nodes_to_fold:
if node["entry"]["name"] in nodes_to_fold:
node["children"] = []
continue
for child in node["children"]:
@@ -76,9 +74,7 @@ def trim_string_back(string: str, width: int) -> str:

def shorten_plot_legend_strings(legend, max_char_len: int):
for t in legend.get_texts():
t.set_text(
trim_string_back(abbreviate_known_names(t.get_text()),
max_char_len))
t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len))


def abbreviate_known_names(name: str) -> str:
@@ -108,15 +104,21 @@ def attempt_to_make_names_unique(entries_and_traces):
names.add(entry["name"])

for name in non_unique_names:
entries_and_traces_with_name = [(entry, trace)
for entry, trace in entries_and_traces
if entry["name"] == name]
entries_and_traces_with_name = [
(entry, trace)
for entry, trace in entries_and_traces
if entry["name"] == name
]

zipped_traces = list(
zip(*[trace for _, trace in entries_and_traces_with_name]))
zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name]))
first_trace_difference = next(
(i for i, trace_eles in enumerate(zipped_traces)
if not all_the_same(trace_eles)), None)
(
i
for i, trace_eles in enumerate(zipped_traces)
if not all_the_same(trace_eles)
),
None,
)

if first_trace_difference is None:
# can't create a unique name, leave the names as they
@@ -124,34 +126,32 @@ def attempt_to_make_names_unique(entries_and_traces):
continue

for entry, trace in entries_and_traces_with_name:
entry["name"] = " <- ".join((entry["name"], ) +
trace[:first_trace_difference + 1])
entry["name"] = " <- ".join(
(entry["name"],) + trace[: first_trace_difference + 1]
)


## Operation grouping utils ####
'''
"""
Group operations in the given dataframe by some high-level ops like,
- gemms
- attention
- rms_norm
etc.
'''
"""


def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:

def is_rms_norm(op_name: str):
if "rms_norm_kernel" in op_name:
return True

def is_attention_block(op_name: str):
if "flash_fwd" in op_name or \
"reshape_and_cache_flash_kernel" in op_name:
if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name:
return True

def is_quant(op_name: str):
if "scaled_fp8_quant" in op_name or \
"scaled_int8_quant" in op_name:
if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name:
return True

# LoRA ops
@@ -168,24 +168,27 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
return "bgmv_expand" in op_name

def is_cutlass_gemm_op(op_name: str):
return "void cutlass::Kernel" in op_name or \
"void cutlass::device_kernel" in op_name
return (
"void cutlass::Kernel" in op_name
or "void cutlass::device_kernel" in op_name
)

def is_gemm_op(op_name: str):
if is_quant(op_name):
return False
return is_cutlass_gemm_op(op_name) or \
"xmma_gemm" in op_name or \
"gemv2T_kernel" in op_name or \
"splitKreduce" in op_name or \
"s16816gemm" in op_name
return (
is_cutlass_gemm_op(op_name)
or "xmma_gemm" in op_name
or "gemv2T_kernel" in op_name
or "splitKreduce" in op_name
or "s16816gemm" in op_name
)

def is_elementwise_op(op_name: str):
return "elementwise_kernel" in op_name

def is_mem_op(op_name: str):
return "memcpy" in op_name.lower() or \
"memset" in op_name.lower()
return "memcpy" in op_name.lower() or "memset" in op_name.lower()

def is_vocab_embedding_op(op_name: str):
return "vocabparallelembed" in op_name.lower()
@@ -195,17 +198,15 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
return "nccl" in op_name.lower()

def is_nccl_all_reduce(op_name: str):
return is_nccl_op(op_name) and \
("all_reduce" in op_name.lower() or \
"allreduce" in op_name.lower())
return is_nccl_op(op_name) and (
"all_reduce" in op_name.lower() or "allreduce" in op_name.lower()
)

def is_nccl_gather(op_name: str):
return is_nccl_op(op_name) and \
"gather" in op_name.lower()
return is_nccl_op(op_name) and "gather" in op_name.lower()

def is_nccl_broadcast(op_name: str):
return is_nccl_op(op_name) and \
"broadcast" in op_name.lower()
return is_nccl_op(op_name) and "broadcast" in op_name.lower()

# Reduce ops types
def is_cross_device_reduce_1stage(op_name: str):
@ -269,114 +270,122 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
ops = list(filter(lambda x: x not in nccl_other_ops, ops))
|
||||
|
||||
cross_device_reduce_1stage_ops = list(
|
||||
filter(lambda x: is_cross_device_reduce_1stage(x), ops))
|
||||
filter(lambda x: is_cross_device_reduce_1stage(x), ops)
|
||||
)
|
||||
ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))
|
||||
|
||||
cross_device_reduce_2stage_ops = list(
|
||||
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
|
||||
filter(lambda x: is_cross_device_reduce_2stage(x), ops)
|
||||
)
|
||||
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
|
||||
|
||||
custom_ar_all_reduce_ops = list(
|
||||
filter(lambda x: is_custom_ar_all_reduce(x), ops))
|
||||
custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops))
|
||||
ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
|
||||
|
||||
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
|
||||
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
|
||||
|
||||
if len(attention_ops):
|
||||
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
|
||||
trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1)
|
||||
if len(quant_ops):
|
||||
trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
|
||||
trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1)
|
||||
|
||||
if len(sgmv_shrink_ops):
|
||||
trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1)
|
||||
if len(sgmv_expand_ops):
|
||||
trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1)
|
||||
if len(bgmv_shrink_ops):
|
||||
trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1)
|
||||
if len(bgmv_expand_ops):
|
||||
trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1)
|
||||
|
||||
if len(cutlass_gemm_ops):
|
||||
trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1)
|
||||
|
||||
if len(gemm_ops):
|
||||
trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
|
||||
trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1)
|
||||
if len(rms_norm_ops):
|
||||
trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1)
|
||||
trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1)
|
||||
if len(vocab_embed_ops):
|
||||
trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1)
|
||||
if len(mem_ops):
|
||||
trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1)
|
||||
trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1)
|
||||
if len(elementwise_ops):
|
||||
trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1)
|
||||
|
||||
if len(nccl_all_reduce_ops):
|
||||
trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg(
|
||||
"sum", axis=1)
|
||||
trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg(
|
||||
"sum", axis=1
|
||||
)
|
||||
if len(nccl_gather_ops):
|
||||
trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1)
|
||||
if len(nccl_broadcast_ops):
|
||||
trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg(
|
||||
"sum", axis=1)
|
||||
trace_df["nccl_broadcast_ops"] = trace_df[nccl_broadcast_ops].agg("sum", axis=1)
|
||||
if len(nccl_other_ops):
|
||||
trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["nccl_other_ops"] = trace_df[nccl_other_ops].agg("sum", axis=1)
|
||||
|
||||
if len(cross_device_reduce_1stage_ops):
|
||||
trace_df['cross_device_reduce_1stage_ops'] = trace_df[
|
||||
cross_device_reduce_1stage_ops].agg("sum", axis=1)
|
||||
trace_df["cross_device_reduce_1stage_ops"] = trace_df[
|
||||
cross_device_reduce_1stage_ops
|
||||
].agg("sum", axis=1)
|
||||
if len(cross_device_reduce_2stage_ops):
|
||||
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
|
||||
cross_device_reduce_2stage_ops].agg("sum", axis=1)
|
||||
trace_df["cross_device_reduce_2stage_ops"] = trace_df[
|
||||
cross_device_reduce_2stage_ops
|
||||
].agg("sum", axis=1)
|
||||
if len(custom_ar_all_reduce_ops):
|
||||
trace_df['custom_ar_all_reduce_ops'] = trace_df[
|
||||
custom_ar_all_reduce_ops].agg("sum", axis=1)
|
||||
trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg(
|
||||
"sum", axis=1
|
||||
)
|
||||
if len(reduce_kernel_ops):
|
||||
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
|
||||
axis=1)
|
||||
trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1)
|
||||
|
||||
trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
|
||||
sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
|
||||
cutlass_gemm_ops + gemm_ops + rms_norm_ops +
|
||||
vocab_embed_ops + mem_ops + elementwise_ops +
|
||||
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
|
||||
nccl_other_ops + cross_device_reduce_1stage_ops +
|
||||
cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
|
||||
reduce_kernel_ops,
|
||||
axis=1,
|
||||
inplace=True)
|
||||
trace_df.drop(
|
||||
attention_ops
|
||||
+ quant_ops
|
||||
+ sgmv_shrink_ops
|
||||
+ sgmv_expand_ops
|
||||
+ bgmv_shrink_ops
|
||||
+ bgmv_expand_ops
|
||||
+ cutlass_gemm_ops
|
||||
+ gemm_ops
|
||||
+ rms_norm_ops
|
||||
+ vocab_embed_ops
|
||||
+ mem_ops
|
||||
+ elementwise_ops
|
||||
+ nccl_all_reduce_ops
|
||||
+ nccl_gather_ops
|
||||
+ nccl_broadcast_ops
|
||||
+ nccl_other_ops
|
||||
+ cross_device_reduce_1stage_ops
|
||||
+ cross_device_reduce_2stage_ops
|
||||
+ custom_ar_all_reduce_ops
|
||||
+ reduce_kernel_ops,
|
||||
axis=1,
|
||||
inplace=True,
|
||||
)
|
||||
return trace_df


## Data plotting utils ####


def plot_trace_df(traces_df: pd.DataFrame,
                  plot_metric: str,
                  plot_title: str,
                  output: Optional[Path] = None):

def plot_trace_df(
    traces_df: pd.DataFrame,
    plot_metric: str,
    plot_title: str,
    output: Optional[Path] = None,
):
    def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
        phase_df = traces_df.query(f'phase == "{phase}"')
        descs = phase_df['phase_desc'].to_list()
        descs = phase_df["phase_desc"].to_list()
        assert all([desc == descs[0] for desc in descs])
        return descs[0]

    phases = traces_df['phase'].unique()
    phases = traces_df["phase"].unique()
    phase_descs = [get_phase_description(traces_df, p) for p in phases]
    traces_df = traces_df.pivot_table(index="phase",
                                      columns="name",
                                      values=plot_metric,
                                      aggfunc="sum")
    traces_df = traces_df.pivot_table(
        index="phase", columns="name", values=plot_metric, aggfunc="sum"
    )

    traces_df = group_trace_by_operations(traces_df)
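As a side note on the pivot_table call just above: it converts the long-form trace (one row per phase/op observation) into a wide table with one row per phase and one column per op name, summing the chosen metric. A self-contained sketch with invented data:

import pandas as pd

# Long-form trace rows; phase/op names and timings are invented.
traces = pd.DataFrame({
    "phase": ["prefill", "prefill", "decode_1", "decode_1"],
    "name": ["gemm", "attn", "gemm", "attn"],
    "cuda_time_us": [10.0, 5.0, 2.0, 8.0],
})

# One row per phase, one column per op name; duplicate cells are summed.
wide = traces.pivot_table(
    index="phase", columns="name", values="cuda_time_us", aggfunc="sum"
)
print(wide)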

@ -396,20 +405,19 @@ def plot_trace_df(traces_df: pd.DataFrame,
    # Write the values as text on the bars
    for bar in ax.patches:
        if bar.get_height() != 0:
            ax.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() / 2 + bar.get_y(),
                    f"{round(bar.get_height(), 2)}",
                    ha='center',
                    color='w',
                    weight='bold',
                    size=5)
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() / 2 + bar.get_y(),
                f"{round(bar.get_height(), 2)}",
                ha="center",
                color="w",
                weight="bold",
                size=5,
            )

    # Setup legend
    handles, labels = plt.gca().get_legend_handles_labels()
    legend = fig.legend(handles,
                        labels,
                        loc='center left',
                        bbox_to_anchor=(1, 1))
    legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1))
    shorten_plot_legend_strings(legend, 50)

    # Setup labels and title
@ -417,21 +425,20 @@ def plot_trace_df(traces_df: pd.DataFrame,
    ax.set_ylabel(plot_metric)
    plt.suptitle(plot_title)

    plt.savefig(output, bbox_inches='tight')
    plt.savefig(output, bbox_inches="tight")
    print("Created: ", output)


def main(
        json_trace: Path,
        output_directory: Path,
        depth: int,  # Fetch/Plot operations at this depth of the Json tree
        plot_metric: str,
        make_names_unique: bool,
        top_k: int,
        json_nodes_to_fold: list[str]):

    json_trace: Path,
    output_directory: Path,
    depth: int,  # Fetch/Plot operations at this depth of the Json tree
    plot_metric: str,
    make_names_unique: bool,
    top_k: int,
    json_nodes_to_fold: list[str],
):
    def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:

        def get_entries_and_traces(key: str):
            entries_and_traces: list[tuple[Any, Any]] = []
            for root in profile_json[key]["summary_stats"]:
@ -441,16 +448,14 @@ def main(
                get_entries_at_depth(depth, entries_and_traces, root)
            return entries_and_traces

        def keep_only_top_entries(df: pd.DataFrame,
                                  metric: str,
                                  top_k: int = 9) -> pd.DataFrame:
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
                   ["name"]] = "others"
        def keep_only_top_entries(
            df: pd.DataFrame, metric: str, top_k: int = 9
        ) -> pd.DataFrame:
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
            return df
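keep_only_top_entries relabels everything outside the top entries as "others" so plots stay readable: with len(df) = 5 and top_k = 3, the three smallest rows collapse into "others", leaving top_k groups in total. A standalone sketch of the nsmallest trick, with invented op names and timings:

import pandas as pd

df = pd.DataFrame({
    "name": ["gemm", "attn", "norm", "misc1", "misc2"],
    "cuda_time_us": [100.0, 80.0, 10.0, 2.0, 1.0],
})

top_k = 3
# Relabel all but the (top_k - 1) largest entries, so the final table has
# top_k groups: the named leaders plus a single "others" bucket.
df.loc[df.nsmallest(len(df) - top_k + 1, "cuda_time_us").index, ["name"]] = "others"
print(df.groupby("name", sort=False).sum())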

        def get_phase_description(key: str) -> str:
            num_running_seqs = profile_json[key]['metadata'][
                'num_running_seqs']
            num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"]
            if num_running_seqs is not None:
                return f"{key}-seqs-{num_running_seqs}"
            else:
@ -466,20 +471,24 @@ def main(

        # To pandas dataframe
        trace_dfs = list(
            map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0),
                traces))
            map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces)
        )

        # Respect top_k
        if top_k:
            trace_dfs = list(
                map(
                    lambda trace_df: keep_only_top_entries(
                        trace_df, "cuda_time_us", top_k), trace_dfs))
                        trace_df, "cuda_time_us", top_k
                    ),
                    trace_dfs,
                )
            )

        # Fill in information about the step-keys
        for trace_df, step_key in zip(trace_dfs, step_keys):
            trace_df['phase'] = step_key
            trace_df['phase_desc'] = get_phase_description(step_key)
            trace_df["phase"] = step_key
            trace_df["phase_desc"] = get_phase_description(step_key)

        # Combine all data frames so they can be put in a single plot
        traces_df = pd.concat(trace_dfs)
@ -492,17 +501,23 @@ def main(

    def make_plot_title_suffix(profile_json: dict) -> str:
        context = profile_json["context"]
        sparsity = context.get('sparsity', None)
        run_type = \
            f'Run {context["num_steps"]} steps' if context['num_steps'] else \
            (f'Complete {context["complete_num_requests_per_step"]} per '
             f'step; Run till completion')
        return (f"{context['engine_args']['model']}\n"
                f"Batch={context['batch_size']}, "
                f"PromptLen={context['prompt_len']}, "
                f"NumGpus={context['engine_args']['tensor_parallel_size']}"
                f"{', Sparsity ' + sparsity if sparsity else ''}\n"
                f"Run Type: {run_type}")
        sparsity = context.get("sparsity", None)
        run_type = (
            f"Run {context['num_steps']} steps"
            if context["num_steps"]
            else (
                f"Complete {context['complete_num_requests_per_step']} per "
                f"step; Run till completion"
            )
        )
        return (
            f"{context['engine_args']['model']}\n"
            f"Batch={context['batch_size']}, "
            f"PromptLen={context['prompt_len']}, "
            f"NumGpus={context['engine_args']['tensor_parallel_size']}"
            f"{', Sparsity ' + sparsity if sparsity else ''}\n"
            f"Run Type: {run_type}"
        )

    profile_json = None
    with open(json_trace) as f:
@ -511,14 +526,14 @@ def main(

    # Get all `llm.generate.step()` profile
    step_traces = list(profile_json.keys())
    assert (step_traces[0] == 'context')
    assert step_traces[0] == "context"
    step_traces = step_traces[1:]  # have only prefill and decodes
    prefills = list(filter(lambda x: "prefill" in x, step_traces))
    all_decodes = list(filter(lambda x: "decode" in x, step_traces))
    assert len(prefills) + len(all_decodes) == len(step_traces)
    assert len(prefills) == 1

    decodes = all_decodes[::args.step_plot_interval]
    decodes = all_decodes[:: args.step_plot_interval]
    if decodes[-1] != all_decodes[-1]:
        # Always have the last decode
        decodes.append(all_decodes[-1])
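The slicing just above is a small idiom worth isolating: stride-sample the decode steps but always keep the final one. A self-contained sketch with invented step names:

# Stride-sample a list but always keep the final element.
all_decodes = [f"decode_{i}" for i in range(10)]
step_plot_interval = 4

decodes = all_decodes[::step_plot_interval]  # every 4th step
if decodes[-1] != all_decodes[-1]:
    decodes.append(all_decodes[-1])  # ensure the last decode is plotted
print(decodes)  # ['decode_0', 'decode_4', 'decode_8', 'decode_9']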
@ -528,48 +543,63 @@ def main(

    plot_title_suffix = make_plot_title_suffix(profile_json)

    plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
                  output_directory / Path("prefill.png"))
    plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix,
                  output_directory / Path("decode_steps.png"))
    plot_trace_df(
        prefill_traces,
        plot_metric,
        "prefill " + plot_title_suffix,
        output_directory / Path("prefill.png"),
    )
    plot_trace_df(
        decode_traces,
        plot_metric,
        "decodes " + plot_title_suffix,
        output_directory / Path("decode_steps.png"),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by \
                            examples/offline_inference/profiling.py")
    parser.add_argument("--output-directory",
                        type=str,
                        required=False,
                        help="Directory to output plots")
    parser.add_argument("--level",
                        type=str,
                        default="module",
                        choices=["module", "kernel"])
    parser.add_argument("--top-k",
                        type=int,
                        default=12,
                        help="Only graph the top `top_k` entries by time.")
    parser.add_argument("--fold-json-node",
                        nargs='+',
                        default=['Sampler', 'LogitsProcessor'],
                        help='Do not plot the children of these nodes. Let, \
    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by \
            examples/offline_inference/profiling.py",
    )
    parser.add_argument(
        "--output-directory", type=str, required=False, help="Directory to output plots"
    )
    parser.add_argument(
        "--level", type=str, default="module", choices=["module", "kernel"]
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=12,
        help="Only graph the top `top_k` entries by time.",
    )
    parser.add_argument(
        "--fold-json-node",
        nargs="+",
        default=["Sampler", "LogitsProcessor"],
        help="Do not plot the children of these nodes. Let, \
                the node represent the aggregate of all its \
                children')
    parser.add_argument("--plot-metric",
                        type=str,
                        default="cuda_time_ms",
                        help='Metric to plot. some options are cuda_time_ms, \
                            pct_cuda_time')
                children",
    )
    parser.add_argument(
        "--plot-metric",
        type=str,
        default="cuda_time_ms",
        help="Metric to plot. some options are cuda_time_ms, \
            pct_cuda_time",
    )
    parser.add_argument(
        "--step-plot-interval",
        type=int,
        default=4,
        help="For every `step_plot_interval` steps, plot 1 step")
        help="For every `step_plot_interval` steps, plot 1 step",
    )

    args = parser.parse_args()

@ -583,11 +613,19 @@ if __name__ == "__main__":
    else:
        raise Exception(f"Unexpected level value ({args.level})")

    output_directory = args.output_directory if args.output_directory else Path(
        args.json_trace).parent
    output_directory = (
        args.output_directory if args.output_directory else Path(args.json_trace).parent
    )

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    main(Path(args.json_trace), output_directory, depth, args.plot_metric,
         make_names_unique, args.top_k, args.fold_json_node)
    main(
        Path(args.json_trace),
        output_directory,
        depth,
        args.plot_metric,
        make_names_unique,
        args.top_k,
        args.fold_json_node,
    )

@ -83,9 +83,9 @@ class Target:
        """
        # Allow for modest floating-point errors
        epsilon = 0.000002
        if (self.weighted_duration > self.Duration() + epsilon):
            print('{} > {}?'.format(self.weighted_duration, self.Duration()))
        assert (self.weighted_duration <= self.Duration() + epsilon)
        if self.weighted_duration > self.Duration() + epsilon:
            print("{} > {}?".format(self.weighted_duration, self.Duration()))
        assert self.weighted_duration <= self.Duration() + epsilon
        return self.weighted_duration

    def DescribeTargets(self):
@ -93,10 +93,10 @@ class Target:
        # Some build steps generate dozens of outputs - handle them sanely.
        # The max_length was chosen so that it can fit most of the long
        # single-target names, while minimizing word wrapping.
        result = ', '.join(self.targets)
        result = ", ".join(self.targets)
        max_length = 65
        if len(result) > max_length:
            result = result[:max_length] + '...'
            result = result[:max_length] + "..."
        return result


@ -106,12 +106,13 @@ def ReadTargets(log, show_all):

    The result is a list of Target objects."""
    header = log.readline()
    assert header == '# ninja log v5\n', \
        'unrecognized ninja log version {!r}'.format(header)
    assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format(
        header
    )
    targets_dict = {}
    last_end_seen = 0.0
    for line in log:
        parts = line.strip().split('\t')
        parts = line.strip().split("\t")
        if len(parts) != 5:
            # If ninja.exe is rudely halted then the .ninja_log file may be
            # corrupt. Silently continue.
@ -150,17 +151,17 @@ def ReadTargets(log, show_all):
def GetExtension(target, extra_patterns):
    """Return the file extension that best represents a target.

  For targets that generate multiple outputs it is important to return a
  consistent 'canonical' extension. Ultimately the goal is to group build steps
  by type."""
    For targets that generate multiple outputs it is important to return a
    consistent 'canonical' extension. Ultimately the goal is to group build steps
    by type."""
    for output in target.targets:
        if extra_patterns:
            for fn_pattern in extra_patterns.split(';'):
                if fnmatch.fnmatch(output, '*' + fn_pattern + '*'):
            for fn_pattern in extra_patterns.split(";"):
                if fnmatch.fnmatch(output, "*" + fn_pattern + "*"):
                    return fn_pattern
        # Not a true extension, but a good grouping.
        if output.endswith('type_mappings'):
            extension = 'type_mappings'
        if output.endswith("type_mappings"):
            extension = "type_mappings"
            break

        # Capture two extensions if present. For example: file.javac.jar should
@ -170,26 +171,26 @@ def GetExtension(target, extra_patterns):
        extension = ext2 + ext1  # Preserve the order in the file name.

        if len(extension) == 0:
            extension = '(no extension found)'
            extension = "(no extension found)"

        if ext1 in ['.pdb', '.dll', '.exe']:
            extension = 'PEFile (linking)'
        if ext1 in [".pdb", ".dll", ".exe"]:
            extension = "PEFile (linking)"
            # Make sure that .dll and .exe are grouped together and that the
            # .dll.lib files don't cause these to be listed as libraries
            break
        if ext1 in ['.so', '.TOC']:
            extension = '.so (linking)'
        if ext1 in [".so", ".TOC"]:
            extension = ".so (linking)"
            # Attempt to identify linking, avoid identifying as '.TOC'
            break
        # Make sure .obj files don't get categorized as mojo files
        if ext1 in ['.obj', '.o']:
        if ext1 in [".obj", ".o"]:
            break
        # Jars are the canonical output of java targets.
        if ext1 == '.jar':
        if ext1 == ".jar":
            break
        # Normalize all mojo related outputs to 'mojo'.
        if output.count('.mojom') > 0:
            extension = 'mojo'
        if output.count(".mojom") > 0:
            extension = "mojo"
            break
    return extension
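A minimal sketch of the grouping idea GetExtension implements: user-supplied semicolon-separated fnmatch patterns take priority, with the file extension as the fallback bucket. The file names and the simplified fallback here are invented for illustration; the real function has many more special cases:

import fnmatch
import os

def bucket(output: str, extra_patterns: str | None) -> str:
    # User patterns win: the first ';'-separated pattern that matches
    # becomes the grouping key for this build step.
    if extra_patterns:
        for fn_pattern in extra_patterns.split(";"):
            if fnmatch.fnmatch(output, "*" + fn_pattern + "*"):
                return fn_pattern
    # Fall back to the file extension as the grouping key.
    ext = os.path.splitext(output)[1]
    return ext if ext else "(no extension found)"

print(bucket("obj/foo/bar.o", None))         # -> .o
print(bucket("gen/base.mojom.cc", "mojom"))  # -> mojom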

@ -214,8 +215,8 @@ def SummarizeEntries(entries, extra_step_types):
        if target.end > latest:
            latest = target.end
        total_cpu_time += target.Duration()
        task_start_stop_times.append((target.start, 'start', target))
        task_start_stop_times.append((target.end, 'stop', target))
        task_start_stop_times.append((target.start, "start", target))
        task_start_stop_times.append((target.end, "stop", target))
    length = latest - earliest
    weighted_total = 0.0

@ -241,10 +242,10 @@ def SummarizeEntries(entries, extra_step_types):
        if num_running > 0:
            # Update the total weighted time up to this moment.
            last_weighted_time += (time - last_time) / float(num_running)
        if action_name == 'start':
        if action_name == "start":
            # Record the total weighted task time when this task starts.
            running_tasks[target] = last_weighted_time
        if action_name == 'stop':
        if action_name == "stop":
            # Record the change in the total weighted task time while this task
            # ran.
            weighted_duration = last_weighted_time - running_tasks[target]
@ -252,13 +253,16 @@ def SummarizeEntries(entries, extra_step_types):
            weighted_total += weighted_duration
            del running_tasks[target]
        last_time = time
    assert (len(running_tasks) == 0)
    assert len(running_tasks) == 0
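The weighted-time sweep in these hunks credits each task with elapsed time divided by the number of tasks running concurrently, so the weighted durations sum to the wall-clock length. A compressed standalone version of the same line-sweep, with invented (start, end, name) tuples:

# Tasks as (start, end, name); times are invented.
tasks = [(0.0, 4.0, "a"), (1.0, 3.0, "b")]

events = []
for start, end, name in tasks:
    events.append((start, "start", name))
    events.append((end, "stop", name))
events.sort(key=lambda e: e[0])

last_time = events[0][0]
last_weighted_time = 0.0
running = {}   # name -> weighted time recorded at its start
weighted = {}  # name -> weighted duration
for time, action, name in events:
    if running:
        # Each running task gets an equal share of the elapsed interval.
        last_weighted_time += (time - last_time) / len(running)
    if action == "start":
        running[name] = last_weighted_time
    if action == "stop":
        weighted[name] = last_weighted_time - running.pop(name)
    last_time = time

print(weighted)  # {'b': 1.0, 'a': 3.0}; sums to the 4.0 s wall-clock length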

    # Warn if the sum of weighted times is off by more than half a second.
    if abs(length - weighted_total) > 500:
        print('Warning: Possible corrupt ninja log, results may be '
              'untrustworthy. Length = {:.3f}, weighted total = {:.3f}'.format(
                  length, weighted_total))
        print(
            "Warning: Possible corrupt ninja log, results may be "
            "untrustworthy. Length = {:.3f}, weighted total = {:.3f}".format(
                length, weighted_total
            )
        )

    entries_by_ext = defaultdict(list)
    for target in entries:
@ -266,32 +270,38 @@ def SummarizeEntries(entries, extra_step_types):
        entries_by_ext[extension].append(target)

    for key, values in entries_by_ext.items():
        print(' Longest build steps for {}:'.format(key))
        print(" Longest build steps for {}:".format(key))
        values.sort(key=lambda x: x.WeightedDuration())
        for target in values[-long_count:]:
            print(
                ' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'.
                format(target.WeightedDuration(), target.DescribeTargets(),
                       target.Duration()))
                " {:8.1f} weighted s to build {} ({:.1f} s elapsed time)".format(
                    target.WeightedDuration(),
                    target.DescribeTargets(),
                    target.Duration(),
                )
            )

    print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x '
          'parallelism)'.format(length, total_cpu_time,
                                total_cpu_time * 1.0 / length))
    print(' {} build steps completed, average of {:1.2f}/s'.format(
        len(entries),
        len(entries) / (length)))
    print(
        " {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x "
        "parallelism)".format(length, total_cpu_time, total_cpu_time * 1.0 / length)
    )
    print(
        " {} build steps completed, average of {:1.2f}/s".format(
            len(entries), len(entries) / (length)
        )
    )


def main():
    log_file = '.ninja_log'
    log_file = ".ninja_log"
    parser = argparse.ArgumentParser()
    parser.add_argument('-C', dest='build_directory', help='Build directory.')
    parser.add_argument("-C", dest="build_directory", help="Build directory.")
    parser.add_argument(
        '-s',
        '--step-types',
        help='semicolon separated fnmatch patterns for build-step grouping')
    parser.add_argument('--log-file',
                        help="specific ninja log file to analyze.")
        "-s",
        "--step-types",
        help="semicolon separated fnmatch patterns for build-step grouping",
    )
    parser.add_argument("--log-file", help="specific ninja log file to analyze.")
    args, _extra_args = parser.parse_known_args()
    if args.build_directory:
        log_file = os.path.join(args.build_directory, log_file)
@ -300,17 +310,16 @@ def main():
    if args.step_types:
        # Make room for the extra build types.
        global long_ext_count
        long_ext_count += len(args.step_types.split(';'))
        long_ext_count += len(args.step_types.split(";"))

    try:
        with open(log_file) as log:
            entries = ReadTargets(log, False)
            SummarizeEntries(entries, args.step_types)
    except OSError:
        print('Log file {!r} not found, no build summary created.'.format(
            log_file))
        print("Log file {!r} not found, no build summary created.".format(log_file))
        return errno.ENOENT


if __name__ == '__main__':
if __name__ == "__main__":
    sys.exit(main())
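A small standalone sketch of the tolerant .ninja_log parsing shown in ReadTargets: check the v5 header, tab-split each record, and silently skip lines that do not have exactly five fields. The sample log content is invented:

import io

# Invented log: header, one valid record, one truncated record.
log = io.StringIO(
    "# ninja log v5\n"
    "0\t100\t0\tfoo.o\tdeadbeef\n"
    "corrupt line\n"
)

header = log.readline()
assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format(header)

entries = []
for line in log:
    parts = line.strip().split("\t")
    if len(parts) != 5:
        # A rudely halted ninja can leave corrupt lines; skip them silently.
        continue
    start_ms, end_ms, _mtime, name, _cmdhash = parts
    entries.append((int(start_ms), int(end_ms), name))

print(entries)  # -> [(0, 100, 'foo.o')]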

@ -38,10 +38,12 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
    # Consider each pair of nodes.
    for a, b in pairwise(cls_node.body):
        # Must be an assignment then a constant string.
        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
                or not isinstance(b, ast.Expr)
                or not isinstance(b.value, ast.Constant)
                or not isinstance(b.value.value, str)):
        if (
            not isinstance(a, (ast.Assign, ast.AnnAssign))
            or not isinstance(b, ast.Expr)
            or not isinstance(b.value, ast.Constant)
            or not isinstance(b.value.value, str)
        ):
            continue

        doc = inspect.cleandoc(b.value.value)
@ -61,25 +63,27 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:


class ConfigValidator(ast.NodeVisitor):

    def __init__(self):
        ...
    def __init__(self): ...

    def visit_ClassDef(self, node):
        # Validate class with both @config and @dataclass decorators
        decorators = [
            id for d in node.decorator_list if (isinstance(d, ast.Name) and (
                (id := d.id) == 'config' or id == 'dataclass')) or
            (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and
                                          (id := d.func.id) == 'dataclass'))
            id
            for d in node.decorator_list
            if (
                isinstance(d, ast.Name)
                and ((id := d.id) == "config" or id == "dataclass")
            )
            or (
                isinstance(d, ast.Call)
                and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
            )
        ]

        if set(decorators) == {'config', 'dataclass'}:
        if set(decorators) == {"config", "dataclass"}:
            validate_class(node)
        elif set(decorators) == {'config'}:
            fail(
                f"Class {node.name} with config decorator must be a dataclass.",
                node)
        elif set(decorators) == {"config"}:
            fail(f"Class {node.name} with config decorator must be a dataclass.", node)

        self.generic_visit(node)
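A simplified, runnable sketch of the decorator detection this visitor performs: collect decorator names from both the bare form (ast.Name) and the called form (ast.Call wrapping a Name). The real check above only accepts the called form for dataclass; this sketch relaxes that for brevity, and the sample source is invented:

import ast

src = """
@config
@dataclass
class Ok:
    pass
"""

class Visitor(ast.NodeVisitor):
    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        names = []
        for d in node.decorator_list:
            if isinstance(d, ast.Name):
                names.append(d.id)           # bare @config / @dataclass
            elif isinstance(d, ast.Call) and isinstance(d.func, ast.Name):
                names.append(d.func.id)      # called form, e.g. @dataclass(...)
        print(node.name, sorted(names))
        self.generic_visit(node)

Visitor().visit(ast.parse(src))  # -> Ok ['config', 'dataclass']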

@ -93,9 +97,11 @@ def validate_class(class_node: ast.ClassDef):
        # Skip ClassVar and InitVar
        # see https://docs.python.org/3/library/dataclasses.html#class-variables
        # and https://docs.python.org/3/library/dataclasses.html#init-only-variables
        if (isinstance(stmt.annotation, ast.Subscript)
                and isinstance(stmt.annotation.value, ast.Name)
                and stmt.annotation.value.id in {"ClassVar", "InitVar"}):
        if (
            isinstance(stmt.annotation, ast.Subscript)
            and isinstance(stmt.annotation.value, ast.Name)
            and stmt.annotation.value.id in {"ClassVar", "InitVar"}
        ):
            continue

        if isinstance(stmt.target, ast.Name):
@ -103,22 +109,30 @@ def validate_class(class_node: ast.ClassDef):
            if stmt.value is None:
                fail(
                    f"Field '{field_name}' in {class_node.name} must have "
                    "a default value.", stmt)
                    "a default value.",
                    stmt,
                )

            if field_name not in attr_docs:
                fail(
                    f"Field '{field_name}' in {class_node.name} must have "
                    "a docstring.", stmt)
                    "a docstring.",
                    stmt,
                )

            if isinstance(stmt.annotation, ast.Subscript) and \
                isinstance(stmt.annotation.value, ast.Name) \
                and stmt.annotation.value.id == "Union" and \
                isinstance(stmt.annotation.slice, ast.Tuple):
            if (
                isinstance(stmt.annotation, ast.Subscript)
                and isinstance(stmt.annotation.value, ast.Name)
                and stmt.annotation.value.id == "Union"
                and isinstance(stmt.annotation.slice, ast.Tuple)
            ):
                args = stmt.annotation.slice.elts
                literal_args = [
                    arg for arg in args
                    if isinstance(arg, ast.Subscript) and isinstance(
                        arg.value, ast.Name) and arg.value.id == "Literal"
                    arg
                    for arg in args
                    if isinstance(arg, ast.Subscript)
                    and isinstance(arg.value, ast.Name)
                    and arg.value.id == "Literal"
                ]
                if len(literal_args) > 1:
                    fail(
@ -126,7 +140,9 @@ def validate_class(class_node: ast.ClassDef):
                        "use a single "
                        "Literal type. Please use 'Literal[Literal1, "
                        "Literal2]' instead of 'Union[Literal1, Literal2]'"
                        ".", stmt)
                        ".",
                        stmt,
                    )
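A standalone sketch of the Union[Literal, ...] detection in this hunk: parse an annotated assignment and count the Literal arms inside a Union annotation. The sample field is invented, and it assumes Python 3.9+ where the subscript slice is the expression itself:

import ast

# Invented field annotation to exercise the check.
src = "x: Union[Literal[1], Literal[2]] = 1"
stmt = ast.parse(src).body[0]
assert isinstance(stmt, ast.AnnAssign)

ann = stmt.annotation
if (
    isinstance(ann, ast.Subscript)
    and isinstance(ann.value, ast.Name)
    and ann.value.id == "Union"
    and isinstance(ann.slice, ast.Tuple)
):
    literal_args = [
        arg
        for arg in ann.slice.elts
        if isinstance(arg, ast.Subscript)
        and isinstance(arg.value, ast.Name)
        and arg.value.id == "Literal"
    ]
    print(len(literal_args))  # -> 2, which the validator would reject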


def validate_ast(tree: ast.stmt):