Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@ -17,12 +17,16 @@ REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"
# If you need to add items to whitelist, do it here.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
"vllm.env_override",
})
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
".version",
})
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset(
{
"vllm.env_override",
}
)
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset(
{
".version",
}
)
def _is_internal(name: str | None, *, level: int = 0) -> bool:
@ -34,8 +38,7 @@ def _is_internal(name: str | None, *, level: int = 0) -> bool:
def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n",
file=sys.stderr)
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)
@ -48,7 +51,6 @@ def main() -> None:
violations: list[tuple[int, str]] = []
class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self._in_type_checking = False
@ -56,10 +58,10 @@ def main() -> None:
def visit_If(self, node: ast.If) -> None:
guard_is_type_checking = False
test = node.test
if isinstance(test, ast.Attribute) and isinstance(
test.value, ast.Name):
guard_is_type_checking = (test.value.id == "typing"
and test.attr == "TYPE_CHECKING")
if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name):
guard_is_type_checking = (
test.value.id == "typing" and test.attr == "TYPE_CHECKING"
)
elif isinstance(test, ast.Name):
guard_is_type_checking = test.id == "TYPE_CHECKING"
@ -79,24 +81,28 @@ def main() -> None:
return
for alias in node.names:
module_name = alias.name
if _is_internal(
module_name) and module_name not in ALLOWED_IMPORTS:
violations.append((
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
))
if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS:
violations.append(
(
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
)
)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self._in_type_checking:
return
module_as_written = ("." * node.level) + (node.module or "")
if _is_internal(
node.module, level=node.level
) and module_as_written not in ALLOWED_FROM_MODULES:
violations.append((
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
))
if (
_is_internal(node.module, level=node.level)
and module_as_written not in ALLOWED_FROM_MODULES
):
violations.append(
(
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
)
)
Visitor().visit(tree)

View File

@ -7,6 +7,7 @@ from enum import Enum
class SPDXStatus(Enum):
"""SPDX header status enumeration"""
EMPTY = "empty" # empty __init__.py
COMPLETE = "complete"
MISSING_LICENSE = "missing_license" # Only has copyright line
@ -16,7 +17,8 @@ class SPDXStatus(Enum):
FULL_SPDX_HEADER = (
"# SPDX-License-Identifier: Apache-2.0\n"
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project")
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"
)
LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0"
COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501
@ -123,8 +125,9 @@ def main():
continue
# Collect all files that need fixing
all_files_to_fix = (files_missing_both + files_missing_copyright +
files_missing_license)
all_files_to_fix = (
files_missing_both + files_missing_copyright + files_missing_license
)
if all_files_to_fix:
print("The following files are missing the SPDX header:")
if files_missing_both:

View File

@ -23,8 +23,7 @@ def is_allowed_file(current_file: str) -> bool:
def is_forbidden_import(line: str) -> bool:
stripped = line.strip()
return bool(
FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
def parse_diff(diff: str) -> list[str]:
@ -42,24 +41,24 @@ def parse_diff(diff: str) -> list[str]:
elif line.startswith("@@"):
match = re.search(r"\+(\d+)", line)
if match:
current_lineno = int(
match.group(1)) - 1 # next "+ line" is here
current_lineno = int(match.group(1)) - 1 # next "+ line" is here
elif line.startswith("+") and not line.startswith("++"):
current_lineno += 1
code_line = line[1:]
if is_forbidden_import(code_line):
violations.append(
f"{current_file}:{current_lineno}: {code_line.strip()}")
f"{current_file}:{current_lineno}: {code_line.strip()}"
)
return violations
def get_diff(diff_type: str) -> str:
if diff_type == "staged":
return subprocess.check_output(
["git", "diff", "--cached", "--unified=0"], text=True)
["git", "diff", "--cached", "--unified=0"], text=True
)
elif diff_type == "unstaged":
return subprocess.check_output(["git", "diff", "--unified=0"],
text=True)
return subprocess.check_output(["git", "diff", "--unified=0"], text=True)
else:
raise ValueError(f"Unknown diff_type: {diff_type}")
@ -75,8 +74,10 @@ def main():
print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)
if all_violations:
print("❌ Forbidden direct `import triton` detected."
" ➤ Use `from vllm.triton_utils import triton` instead.\n")
print(
"❌ Forbidden direct `import triton` detected."
" ➤ Use `from vllm.triton_utils import triton` instead.\n"
)
for v in all_violations:
print(f"{v}")
return 1

View File

@ -7,24 +7,23 @@ from pathlib import Path
import regex as re
FORBIDDEN_PATTERNS = re.compile(
r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)')
FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)")
ALLOWED_PATTERNS = [
re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'),
re.compile(r'^\s*import\s+regex\s*$'),
re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"),
re.compile(r"^\s*import\s+regex\s*$"),
]
def get_staged_python_files() -> list[str]:
try:
result = subprocess.run(
['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'],
["git", "diff", "--cached", "--name-only", "--diff-filter=AM"],
capture_output=True,
text=True,
check=True)
files = result.stdout.strip().split(
'\n') if result.stdout.strip() else []
return [f for f in files if f.endswith('.py')]
check=True,
)
files = result.stdout.strip().split("\n") if result.stdout.strip() else []
return [f for f in files if f.endswith(".py")]
except subprocess.CalledProcessError:
return []
@ -33,13 +32,14 @@ def is_forbidden_import(line: str) -> bool:
line = line.strip()
return bool(
FORBIDDEN_PATTERNS.match(line)
and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS))
and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)
)
def check_file(filepath: str) -> list[tuple[int, str]]:
violations = []
try:
with open(filepath, encoding='utf-8') as f:
with open(filepath, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if is_forbidden_import(line):
violations.append((line_num, line.strip()))
@ -72,9 +72,7 @@ def main() -> int:
if total_violations > 0:
print(f"\n💡 Found {total_violations} violation(s).")
print("❌ Please replace 'import re' with 'import regex as re'")
print(
" Also replace 'from re import ...' with 'from regex import ...'"
) # noqa: E501
print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501
print("✅ Allowed imports:")
print(" - import regex as re")
print(" - import regex") # noqa: E501

View File

@ -12,8 +12,7 @@ try:
# most reliable source of truth for vLLM's build.
from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
print("Warning: PyTorch not found. "
"Falling back to CUDA_HOME environment variable.")
print("Warning: PyTorch not found. Falling back to CUDA_HOME environment variable.")
CUDA_HOME = os.environ.get("CUDA_HOME")
@ -27,8 +26,7 @@ def get_cpu_cores():
return multiprocessing.cpu_count()
def generate_presets(output_path="CMakeUserPresets.json",
force_overwrite=False):
def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False):
"""Generates the CMakeUserPresets.json file."""
print("Attempting to detect your system configuration...")
@ -39,8 +37,7 @@ def generate_presets(output_path="CMakeUserPresets.json",
prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc")
if os.path.exists(prospective_path):
nvcc_path = prospective_path
print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: "
f"{nvcc_path}")
print(f"Found nvcc via torch.utils.cpp_extension.CUDA_HOME: {nvcc_path}")
if not nvcc_path:
nvcc_path = which("nvcc")
@ -50,7 +47,8 @@ def generate_presets(output_path="CMakeUserPresets.json",
if not nvcc_path:
nvcc_path_input = input(
"Could not automatically find 'nvcc'. Please provide the full "
"path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ")
"path to nvcc (e.g., /usr/local/cuda/bin/nvcc): "
)
nvcc_path = nvcc_path_input.strip()
print(f"Using NVCC path: {nvcc_path}")
@ -63,12 +61,13 @@ def generate_presets(output_path="CMakeUserPresets.json",
"Could not automatically find Python executable. Please provide "
"the full path to your Python executable for vLLM development "
"(typically from your virtual environment, e.g., "
"/home/user/venvs/vllm/bin/python): ")
"/home/user/venvs/vllm/bin/python): "
)
python_executable = input(python_executable_prompt).strip()
if not python_executable:
raise ValueError(
"Could not determine Python executable. Please provide it "
"manually.")
"Could not determine Python executable. Please provide it manually."
)
print(f"Using Python executable: {python_executable}")
@ -76,20 +75,23 @@ def generate_presets(output_path="CMakeUserPresets.json",
cpu_cores = get_cpu_cores()
nvcc_threads = min(4, cpu_cores)
cmake_jobs = max(1, cpu_cores // nvcc_threads)
print(f"Detected {cpu_cores} CPU cores. "
f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.")
print(
f"Detected {cpu_cores} CPU cores. "
f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}."
)
# Get vLLM project root (assuming this script is in vllm/tools/)
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), ".."))
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
print(f"VLLM project root detected as: {project_root}")
# Ensure python_executable path is absolute or resolvable
if not os.path.isabs(python_executable) and which(python_executable):
python_executable = os.path.abspath(which(python_executable))
elif not os.path.isabs(python_executable):
print(f"Warning: Python executable '{python_executable}' is not an "
"absolute path and not found in PATH. CMake might not find it.")
print(
f"Warning: Python executable '{python_executable}' is not an "
"absolute path and not found in PATH. CMake might not find it."
)
cache_variables = {
"CMAKE_CUDA_COMPILER": nvcc_path,
@ -122,24 +124,20 @@ def generate_presets(output_path="CMakeUserPresets.json",
configure_preset["generator"] = "Ninja"
cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}"
else:
print("Ninja not found, using default generator. "
"Build may be slower.")
print("Ninja not found, using default generator. Build may be slower.")
presets = {
"version":
6,
"version": 6,
# Keep in sync with CMakeLists.txt and requirements/build.txt
"cmakeMinimumRequired": {
"major": 3,
"minor": 26,
"patch": 1
},
"cmakeMinimumRequired": {"major": 3, "minor": 26, "patch": 1},
"configurePresets": [configure_preset],
"buildPresets": [{
"name": "release",
"configurePreset": "release",
"jobs": cmake_jobs,
}],
"buildPresets": [
{
"name": "release",
"configurePreset": "release",
"jobs": cmake_jobs,
}
],
}
output_file_path = os.path.join(project_root, output_path)
@ -148,10 +146,12 @@ def generate_presets(output_path="CMakeUserPresets.json",
if force_overwrite:
print(f"Overwriting existing file '{output_file_path}'")
else:
overwrite = input(
f"'{output_file_path}' already exists. Overwrite? (y/N): "
).strip().lower()
if overwrite != 'y':
overwrite = (
input(f"'{output_file_path}' already exists. Overwrite? (y/N): ")
.strip()
.lower()
)
if overwrite != "y":
print("Generation cancelled.")
return
@ -160,11 +160,9 @@ def generate_presets(output_path="CMakeUserPresets.json",
json.dump(presets, f, indent=4)
print(f"Successfully generated '{output_file_path}'")
print("\nTo use this preset:")
print(
f"1. Ensure you are in the vLLM root directory: cd {project_root}")
print(f"1. Ensure you are in the vLLM root directory: cd {project_root}")
print("2. Initialize CMake: cmake --preset release")
print("3. Build+install: cmake --build --preset release "
"--target install")
print("3. Build+install: cmake --build --preset release --target install")
except OSError as e:
print(f"Error writing file: {e}")
@ -175,7 +173,7 @@ if __name__ == "__main__":
parser.add_argument(
"--force-overwrite",
action="store_true",
help="Force overwrite existing CMakeUserPresets.json without prompting"
help="Force overwrite existing CMakeUserPresets.json without prompting",
)
args = parser.parse_args()

View File

@ -17,44 +17,48 @@ import regex as re
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
# pickle
'vllm/v1/serial_utils.py',
'vllm/v1/executor/multiproc_executor.py',
'vllm/multimodal/hasher.py',
'vllm/transformers_utils/config.py',
'vllm/model_executor/models/registry.py',
'tests/utils_/test_utils.py',
'tests/tokenization/test_cached_tokenizer.py',
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/distributed/device_communicators/shm_object_storage.py',
'benchmarks/kernels/graph_machete_bench.py',
'benchmarks/kernels/benchmark_lora.py',
'benchmarks/kernels/benchmark_machete.py',
'benchmarks/fused_kernels/layernorm_rms_benchmarks.py',
'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py',
'benchmarks/cutlass_benchmarks/sparse_benchmarks.py',
"vllm/v1/serial_utils.py",
"vllm/v1/executor/multiproc_executor.py",
"vllm/multimodal/hasher.py",
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
"tests/utils_/test_utils.py",
"tests/tokenization/test_cached_tokenizer.py",
"vllm/distributed/utils.py",
"vllm/distributed/parallel_state.py",
"vllm/distributed/device_communicators/all_reduce_utils.py",
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py",
"benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
# cloudpickle
'vllm/executor/mp_distributed_executor.py',
'vllm/executor/ray_distributed_executor.py',
'vllm/entrypoints/llm.py',
'tests/utils.py',
"vllm/executor/mp_distributed_executor.py",
"vllm/executor/ray_distributed_executor.py",
"vllm/entrypoints/llm.py",
"tests/utils.py",
# pickle and cloudpickle
'vllm/utils/__init__.py',
"vllm/utils/__init__.py",
}
PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)")
PICKLE_RE = re.compile(
r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)"
)
def scan_file(path: str) -> int:
with open(path, encoding='utf-8') as f:
with open(path, encoding="utf-8") as f:
for i, line in enumerate(f, 1):
if PICKLE_RE.match(line):
print(f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import")
print(
f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import"
)
return 1
return 0
@ -92,13 +96,13 @@ def test_regex():
for i, (line, should_match) in enumerate(test_cases):
result = bool(PICKLE_RE.match(line))
assert result == should_match, (
f"Test case {i} failed: '{line}' "
f"(expected {should_match}, got {result})")
f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
)
print("All regex tests passed.")
if __name__ == '__main__':
if '--test-regex' in sys.argv:
if __name__ == "__main__":
if "--test-regex" in sys.argv:
test_regex()
else:
sys.exit(main())

View File

@ -94,11 +94,15 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return file_groups
def mypy(targets: list[str], python_version: Optional[str],
follow_imports: Optional[str], file_group: str) -> int:
def mypy(
targets: list[str],
python_version: Optional[str],
follow_imports: Optional[str],
file_group: str,
) -> int:
"""
Run mypy on the given targets.
Args:
targets: List of files or directories to check.
python_version: Python version to use (e.g., "3.10") or None to use
@ -131,8 +135,9 @@ def main():
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
returncode |= mypy(changed_files, python_version, follow_imports,
file_group)
returncode |= mypy(
changed_files, python_version, follow_imports, file_group
)
return returncode

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
This generates gpu kernel analysis output from nsys rep. Will call nsys
stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate
csv and html output for analysis
"""
import argparse
import logging
import os
@ -16,13 +17,13 @@ logger = logging.getLogger(__name__)
# helper data class for annotating kernels
def load_engine_model():
""" returns engine_model built from all json files in the current dir """
"""returns engine_model built from all json files in the current dir"""
import glob
import json
engine_model = {}
json_files = glob.glob(
os.path.join(os.path.dirname(__file__) or ".", "*.json"))
json_files = glob.glob(os.path.join(os.path.dirname(__file__) or ".", "*.json"))
for fname in json_files:
with open(fname, encoding="utf-8") as f:
engine_model.update(json.load(f))
@ -30,54 +31,54 @@ def load_engine_model():
class GPUTrace2Graph:
"""
Parses output of nsys report, generates csv and bar chart output
"""
Parses output of nsys report, generates csv and bar chart output
"""
def __init__(self):
import pandas as pd # avoid importing till needed
self.pd = pd
self.pd.options.mode.copy_on_write = True
# helper functions for generating trace->summary csvs
def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file):
logger.info('loading %s', in_file)
logger.info("loading %s", in_file)
df = self.pd.read_csv(
in_file,
usecols=['Start (ns)', 'Duration (ns)', 'Device', 'Strm', 'Name'])
df['End (ns)'] = df['Start (ns)'] + df['Duration (ns)']
in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"]
)
df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"]
df = self.sum_non_overlapping_intervals(df)
# get ready to print table with elapsed times per kernel
df['Instances'] = 1
df_sum = df.groupby('Name', as_index=False).agg({
'Elapsed Time (ns)': 'sum',
'Duration (ns)': 'sum',
'Instances': 'size'
})
df["Instances"] = 1
df_sum = df.groupby("Name", as_index=False).agg(
{"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"}
)
# generate csv
df_sum['Total Time (sec)'] = df_sum['Duration (ns)'] / 1e9
df_sum['Elapsed Time (sec)'] = df_sum['Elapsed Time (ns)'] / 1e9
df_sum = df_sum.sort_values(by='Elapsed Time (sec)', ascending=False)
df_sum[['Elapsed Time (sec)', 'Total Time (sec)', 'Instances',
'Name']].to_csv(out_file, index=False)
df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9
df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9
df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False)
df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv(
out_file, index=False
)
def sum_non_overlapping_intervals(self, df):
"""
returns new sorted df with Elapsed Time (ns) column using
vectorized operations
"""
returns new sorted df with Elapsed Time (ns) column using
vectorized operations
"""
logger.info("sorting %s trace records by start time", str(df.shape))
# Sort by start time and reset index
df = df.sort_values(by='Start (ns)').reset_index(drop=True)
df = df.sort_values(by="Start (ns)").reset_index(drop=True)
# Initialize elapsed time as duration
df['Elapsed Time (ns)'] = df['Duration (ns)']
df["Elapsed Time (ns)"] = df["Duration (ns)"]
# Get numpy arrays for faster operations
starts = df['Start (ns)'].values
ends = df['End (ns)'].values
starts = df["Start (ns)"].values
ends = df["End (ns)"].values
# Keep track of current interval end
current_end = ends[0]
@ -85,16 +86,17 @@ class GPUTrace2Graph:
# Update current_end for overlapping intervals
for i in range(1, len(df)):
if i % display_units == 0:
print(f'processing trace: {int(i/len(df) * 100)} %', end="\r")
print(f"processing trace: {int(i / len(df) * 100)} %", end="\r")
if starts[i] <= current_end:
if ends[i] > current_end:
# Partial overlap
df.iloc[i, df.columns.get_loc('Elapsed Time (ns)'
)] = ends[i] - current_end
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = (
ends[i] - current_end
)
current_end = ends[i]
else:
# Complete overlap
df.iloc[i, df.columns.get_loc('Elapsed Time (ns)')] = 0
df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0
else:
# No overlap
current_end = ends[i]
@ -103,147 +105,167 @@ class GPUTrace2Graph:
# functions for generating html files
def make_html(self, df, output_dir, title):
""" make html graph from df """
"""make html graph from df"""
import plotly.express as px
if df.empty:
return
output_name = output_dir + '/result'
output_name = output_dir + "/result"
if not title:
title = 'Model_Engine'
x = 'Model_Engine'
y = 'Elapsed Time (sec)'
color = 'Category'
title = "Model_Engine"
x = "Model_Engine"
y = "Elapsed Time (sec)"
color = "Category"
""" generate kernel mapping table """
# Sort Model_Engine categories by last field after underscore
df['Model_Engine'] = self.pd.Categorical(
df['Model_Engine'],
sorted(df['Model_Engine'].unique(),
key=lambda x: x.split('_')[-1]))
df[['Model_Engine', color, 'Instances', 'Name',
y]].sort_values(by=color).to_csv(f'{output_name}.csv', index=False)
graph = px.histogram(df.round(2),
x=x,
y=y,
title=(f'{y} for {title}'),
color=color,
text_auto=True)
df["Model_Engine"] = self.pd.Categorical(
df["Model_Engine"],
sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]),
)
df[["Model_Engine", color, "Instances", "Name", y]].sort_values(
by=color
).to_csv(f"{output_name}.csv", index=False)
graph = px.histogram(
df.round(2),
x=x,
y=y,
title=(f"{y} for {title}"),
color=color,
text_auto=True,
)
# wrap x axis labels
graph.update_xaxes(automargin=True)
graph.write_html(f'{output_name}.html')
graph.write_html(f"{output_name}.html")
"""
Generate data table with columns per Model_Engine into result.html
"""
pivot_df = df.pivot_table(values='Elapsed Time (sec)',
index='Category',
columns='Model_Engine',
aggfunc='sum',
observed=False).round(2)
pivot_df = df.pivot_table(
values="Elapsed Time (sec)",
index="Category",
columns="Model_Engine",
aggfunc="sum",
observed=False,
).round(2)
# Add sum row at bottom
pivot_df.loc['total_elapsed_sec'] = pivot_df.sum()
pivot_df.fillna('').to_html('temp.html')
with (open(f'{output_name}.html', 'a', encoding='utf-8') as
outfile, open('temp.html', encoding='utf-8') as infile):
pivot_df.loc["total_elapsed_sec"] = pivot_df.sum()
pivot_df.fillna("").to_html("temp.html")
with (
open(f"{output_name}.html", "a", encoding="utf-8") as outfile,
open("temp.html", encoding="utf-8") as infile,
):
outfile.write(infile.read())
os.remove('temp.html')
os.remove("temp.html")
print(f'Finished generating: \n'
f' {output_name}.html for stack bar chart \n'
f' {output_name}.csv for Kernel-Category mapping')
print(
f"Finished generating: \n"
f" {output_name}.html for stack bar chart \n"
f" {output_name}.csv for Kernel-Category mapping"
)
def anno_gpu_kernname(self, df, mapping):
""" add "Category" column """
"""add "Category" column"""
def anno_gpu_kernname_helper(name):
for kern_name, val in mapping.items():
if re.search(kern_name, name):
return val
df['Category'] = df['Name'].apply(anno_gpu_kernname_helper)
df["Category"] = df["Name"].apply(anno_gpu_kernname_helper)
def make_nongpu_row(self, df, nongpu_sec):
""" this will append non-gpu time entry at end of df """
"""this will append non-gpu time entry at end of df"""
nongpu_row = self.pd.DataFrame([df.iloc[-1]])
nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)'
nongpu_row['Instances'] = 1
nongpu_row['Elapsed Time (sec)'] = nongpu_sec
return (nongpu_row)
nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)"
nongpu_row["Instances"] = 1
nongpu_row["Elapsed Time (sec)"] = nongpu_sec
return nongpu_row
def is_valid_file(self, base_file):
""" asserts if base_file is non-existent or is empty """
assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, \
f"{base_file} doesn't exist or is empty"
"""asserts if base_file is non-existent or is empty"""
assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, (
f"{base_file} doesn't exist or is empty"
)
def should_gen_file(self, new_file, base_file):
""" figure out if new file should be generated from base_file """
"""figure out if new file should be generated from base_file"""
self.is_valid_file(base_file)
if (os.path.exists(new_file)
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
and (os.path.getsize(base_file) > 0)):
logger.info('reusing %s', new_file)
if (
os.path.exists(new_file)
and (os.path.getmtime(new_file) > os.path.getmtime(base_file))
and (os.path.getsize(base_file) > 0)
):
logger.info("reusing %s", new_file)
return False
else:
logger.info('generating %s', new_file)
logger.info("generating %s", new_file)
return True
def gen_sum_file(self, file, nsys_cmd):
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
"""
import subprocess
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = '.'
file_dir = "."
# Walk through trace and get the total non-overlapped time
nsys_stats_file = f'{file_dir}/{file_name}_cuda_gpu_trace.csv'
sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv'
nsys_stats_file = f"{file_dir}/{file_name}_cuda_gpu_trace.csv"
sum_file = f"{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv"
if self.should_gen_file(nsys_stats_file, file):
cmd = [
nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o',
f'{file_dir}/{file_name}'
nsys_cmd,
"stats",
"-r",
"cuda_gpu_trace",
file,
"-o",
f"{file_dir}/{file_name}",
]
cmd_str = ' '.join(cmd)
logger.info('+ %s', cmd_str)
cmd_str = " ".join(cmd)
logger.info("+ %s", cmd_str)
# estimate time based on calibrated 240M/min
file_size_mb = os.path.getsize(file) / 1e6
logger.info(
'nsys stats for %.2f MB file expected to take %.2f min',
file_size_mb, file_size_mb / 240)
"nsys stats for %.2f MB file expected to take %.2f min",
file_size_mb,
file_size_mb / 240,
)
try:
subprocess.run(cmd, check=True)
except Exception:
logger.error("%s failed; Use --nsys_cmd to specify nsys path",
cmd_str)
logger.error("%s failed; Use --nsys_cmd to specify nsys path", cmd_str)
exit(1)
logger.info('generating non-overalapped sum %s', sum_file)
logger.info("generating non-overalapped sum %s", sum_file)
self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
self.is_valid_file(sum_file)
logger.info('Finished generating %s', sum_file)
logger.info("Finished generating %s", sum_file)
return sum_file
def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
""" generates graph and csv file from in_file into out_dir """
"""generates graph and csv file from in_file into out_dir"""
# Initialize an empty DataFrame to store combined data
combined_df = self.pd.DataFrame()
for idx, (file, engine, model, total_sec) in enumerate(in_file):
file_dir = os.path.dirname(file)
file_name = os.path.basename(file)
if not file_dir:
file_dir = '.'
file_dir = "."
sum_file = self.gen_sum_file(file, nsys_cmd)
# read kernel summary file
df = self.pd.read_csv(sum_file)
# annotate kernel to their categories
assert engine_model.get(engine), f'engine {engine} unknown'
assert engine_model[engine].get(model), f'model {model} unknown'
assert engine_model.get(engine), f"engine {engine} unknown"
assert engine_model[engine].get(model), f"model {model} unknown"
# remove nsys-rep from file_name for shorter x-label
file_name = file_name.replace('.nsys-rep', '')
df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}'
file_name = file_name.replace(".nsys-rep", "")
df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}"
self.anno_gpu_kernname(df, engine_model[engine][model])
# patch in non-gpu time
gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1)
gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1)
total_sec = round(float(total_sec), 1)
if total_sec < gpu_sec:
logger.warning(
@ -256,7 +278,7 @@ class GPUTrace2Graph:
df = self.pd.concat([df, nongpu_row], ignore_index=True)
combined_df = self.pd.concat([combined_df, df], ignore_index=True)
if out_dir is None:
out_dir = '.'
out_dir = "."
else:
os.makedirs(out_dir, exist_ok=True)
# generate html file
@ -264,50 +286,59 @@ class GPUTrace2Graph:
def parse_tuple(s):
return tuple(s.split(','))
return tuple(s.split(","))
def main():
logging.basicConfig(format=('%(asctime)s - %(levelname)s - %(message)s'),
level=logging.INFO)
logging.basicConfig(
format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO
)
parser = argparse.ArgumentParser(
description=(
'Process nsys rep and generate kernel non-overlapped cycles. \n'
'Example:\n'
"Process nsys rep and generate kernel non-overlapped cycles. \n"
"Example:\n"
"gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n"
"d2.nsys-rep,vllm,gpt-oss,102 "
"--out_dir results/ --title \"Model=gpt-oss vLLM chart\""),
formatter_class=argparse.RawDescriptionHelpFormatter)
'--out_dir results/ --title "Model=gpt-oss vLLM chart"'
),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# load supported engine_model
engine_model_supported = load_engine_model()
# Get a string representation of supported engine/model combinations
engine_model_supported_str = ', '.join(
engine_model_supported_str = ", ".join(
f"{engine}:[{', '.join(models.keys())}]"
for engine, models in engine_model_supported.items())
for engine, models in engine_model_supported.items()
)
parser.add_argument(
'--in_file',
"--in_file",
type=parse_tuple,
nargs='+',
nargs="+",
help=(
'list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) '
'separated by space. Elapsed_nonprofiled_sec is runtime without '
'profiling used to calculate non-gpu time. Specify 0 to use '
'elapsed time from nsys-rep but that might inflate non-gpu time. '
f'Available engine:[model] are: {engine_model_supported_str} '
f'Example: --infile d1.nsys-rep,vllm,llama,100 '
'd2.nsys-rep,vllm,gpt-oss,102'),
required=True)
parser.add_argument('--out_dir', help=('output dir for result.csv/html'))
parser.add_argument('--title', help=('title for html chart'))
parser.add_argument('--nsys_cmd',
help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'),
default="nsys")
"list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) "
"separated by space. Elapsed_nonprofiled_sec is runtime without "
"profiling used to calculate non-gpu time. Specify 0 to use "
"elapsed time from nsys-rep but that might inflate non-gpu time. "
f"Available engine:[model] are: {engine_model_supported_str} "
f"Example: --infile d1.nsys-rep,vllm,llama,100 "
"d2.nsys-rep,vllm,gpt-oss,102"
),
required=True,
)
parser.add_argument("--out_dir", help=("output dir for result.csv/html"))
parser.add_argument("--title", help=("title for html chart"))
parser.add_argument(
"--nsys_cmd",
help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"),
default="nsys",
)
args = parser.parse_args()
gputrace = GPUTrace2Graph()
gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd,
engine_model_supported)
gputrace.gen_graph(
args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported
)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -29,48 +29,50 @@ def flatten_entries(entry_cls, profile_dict: dict):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--json-trace",
type=str,
required=True,
help="json trace file output by "
"examples/offline_inference/profiling.py")
parser.add_argument("--phase",
type=str,
required=True,
help="The phase to print the table for. This is either"
"prefill or decode_n, where n is the decode step "
"number")
parser.add_argument("--table",
type=str,
choices=["summary", "model"],
default="summary",
help="Which table to print, the summary table or the "
"layerwise model table")
parser.add_argument(
"--json-trace",
type=str,
required=True,
help="json trace file output by examples/offline_inference/profiling.py",
)
parser.add_argument(
"--phase",
type=str,
required=True,
help="The phase to print the table for. This is either"
"prefill or decode_n, where n is the decode step "
"number",
)
parser.add_argument(
"--table",
type=str,
choices=["summary", "model"],
default="summary",
help="Which table to print, the summary table or the layerwise model table",
)
args = parser.parse_args()
with open(args.json_trace) as f:
profile_data = json.load(f)
assert args.phase in profile_data, \
(f"Cannot find phase {args.phase} in profile data. Choose one among"
f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa
assert args.phase in profile_data, (
f"Cannot find phase {args.phase} in profile data. Choose one among"
f"{[x for x in profile_data.keys() if 'prefill' in x or 'decode' in x]}"
) # noqa
if args.table == "summary":
entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
invocations=15)
SummaryStatsEntry, profile_data[args.phase]["summary_stats"]
)
column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15)
elif args.table == "model":
entries_and_depths = flatten_entries(
ModelStatsEntry, profile_data[args.phase]["model_stats"])
column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
pct_cuda_time=12,
trace=60)
ModelStatsEntry, profile_data[args.phase]["model_stats"]
)
column_widths = dict(
name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60
)
# indent entry names based on the depth
entries = []
@ -78,7 +80,8 @@ if __name__ == "__main__":
entry.name = indent_string(
entry.name,
indent=depth,
indent_style=lambda indent: "|" + "-" * indent + " ")
indent_style=lambda indent: "|" + "-" * indent + " ",
)
entries.append(entry)
TablePrinter(type(entries[0]), column_widths).print_table(entries)

View File

@ -18,17 +18,18 @@ import pandas as pd
def largest_dist_from_leaf(node: dict, depth: int = 0):
if len(node["children"]) == 0:
return depth
return max([
largest_dist_from_leaf(child, depth=depth + 1)
for child in node["children"]
])
return max(
[largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]]
)
def get_entries_at_depth(depth: int,
entries_and_traces: list[tuple[Any, Any]],
node: dict,
curr_depth: int = 0,
trace=()):
def get_entries_at_depth(
depth: int,
entries_and_traces: list[tuple[Any, Any]],
node: dict,
curr_depth: int = 0,
trace=(),
):
# assert that the query is at kernel or module level
assert depth == -1 or depth == -2
@ -40,21 +41,18 @@ def get_entries_at_depth(depth: int,
if largest_dist_from_leaf(node) == (abs(depth) - 1):
entries_and_traces.append((node["entry"], trace))
trace = (node["entry"]["name"], ) + trace
trace = (node["entry"]["name"],) + trace
for child in node["children"]:
get_entries_at_depth(depth,
entries_and_traces,
child,
curr_depth=curr_depth + 1,
trace=trace)
get_entries_at_depth(
depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace
)
def fold_nodes(root: dict, nodes_to_fold: list[str]):
stack: list[dict] = [root]
while len(stack) != 0:
node = stack.pop()
if node['entry']['name'] in nodes_to_fold:
if node["entry"]["name"] in nodes_to_fold:
node["children"] = []
continue
for child in node["children"]:
@ -76,9 +74,7 @@ def trim_string_back(string: str, width: int) -> str:
def shorten_plot_legend_strings(legend, max_char_len: int):
for t in legend.get_texts():
t.set_text(
trim_string_back(abbreviate_known_names(t.get_text()),
max_char_len))
t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len))
def abbreviate_known_names(name: str) -> str:
@ -108,15 +104,21 @@ def attempt_to_make_names_unique(entries_and_traces):
names.add(entry["name"])
for name in non_unique_names:
entries_and_traces_with_name = [(entry, trace)
for entry, trace in entries_and_traces
if entry["name"] == name]
entries_and_traces_with_name = [
(entry, trace)
for entry, trace in entries_and_traces
if entry["name"] == name
]
zipped_traces = list(
zip(*[trace for _, trace in entries_and_traces_with_name]))
zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name]))
first_trace_difference = next(
(i for i, trace_eles in enumerate(zipped_traces)
if not all_the_same(trace_eles)), None)
(
i
for i, trace_eles in enumerate(zipped_traces)
if not all_the_same(trace_eles)
),
None,
)
if first_trace_difference is None:
# can't create a unique name, leave the names as they
@ -124,34 +126,32 @@ def attempt_to_make_names_unique(entries_and_traces):
continue
for entry, trace in entries_and_traces_with_name:
entry["name"] = " <- ".join((entry["name"], ) +
trace[:first_trace_difference + 1])
entry["name"] = " <- ".join(
(entry["name"],) + trace[: first_trace_difference + 1]
)
## Operation grouping utils ####
'''
"""
Group operations in the given dataframe by some high-level ops like,
- gemms
- attention
- rms_norm
etc.
'''
"""
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def is_rms_norm(op_name: str):
if "rms_norm_kernel" in op_name:
return True
def is_attention_block(op_name: str):
if "flash_fwd" in op_name or \
"reshape_and_cache_flash_kernel" in op_name:
if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name:
return True
def is_quant(op_name: str):
if "scaled_fp8_quant" in op_name or \
"scaled_int8_quant" in op_name:
if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name:
return True
# LoRA ops
@ -168,24 +168,27 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
return "bgmv_expand" in op_name
def is_cutlass_gemm_op(op_name: str):
return "void cutlass::Kernel" in op_name or \
"void cutlass::device_kernel" in op_name
return (
"void cutlass::Kernel" in op_name
or "void cutlass::device_kernel" in op_name
)
def is_gemm_op(op_name: str):
if is_quant(op_name):
return False
return is_cutlass_gemm_op(op_name) or \
"xmma_gemm" in op_name or \
"gemv2T_kernel" in op_name or \
"splitKreduce" in op_name or \
"s16816gemm" in op_name
return (
is_cutlass_gemm_op(op_name)
or "xmma_gemm" in op_name
or "gemv2T_kernel" in op_name
or "splitKreduce" in op_name
or "s16816gemm" in op_name
)
def is_elementwise_op(op_name: str):
return "elementwise_kernel" in op_name
def is_mem_op(op_name: str):
return "memcpy" in op_name.lower() or \
"memset" in op_name.lower()
return "memcpy" in op_name.lower() or "memset" in op_name.lower()
def is_vocab_embedding_op(op_name: str):
return "vocabparallelembed" in op_name.lower()
@ -195,17 +198,15 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
return "nccl" in op_name.lower()
def is_nccl_all_reduce(op_name: str):
return is_nccl_op(op_name) and \
("all_reduce" in op_name.lower() or \
"allreduce" in op_name.lower())
return is_nccl_op(op_name) and (
"all_reduce" in op_name.lower() or "allreduce" in op_name.lower()
)
def is_nccl_gather(op_name: str):
return is_nccl_op(op_name) and \
"gather" in op_name.lower()
return is_nccl_op(op_name) and "gather" in op_name.lower()
def is_nccl_broadcast(op_name: str):
return is_nccl_op(op_name) and \
"broadcast" in op_name.lower()
return is_nccl_op(op_name) and "broadcast" in op_name.lower()
# Reduce ops types
def is_cross_device_reduce_1stage(op_name: str):
@ -269,114 +270,122 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
ops = list(filter(lambda x: x not in nccl_other_ops, ops))
cross_device_reduce_1stage_ops = list(
filter(lambda x: is_cross_device_reduce_1stage(x), ops))
filter(lambda x: is_cross_device_reduce_1stage(x), ops)
)
ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))
cross_device_reduce_2stage_ops = list(
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
filter(lambda x: is_cross_device_reduce_2stage(x), ops)
)
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
custom_ar_all_reduce_ops = list(
filter(lambda x: is_custom_ar_all_reduce(x), ops))
custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops))
ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
if len(attention_ops):
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1)
if len(quant_ops):
trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1)
if len(sgmv_shrink_ops):
trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
axis=1)
trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1)
if len(sgmv_expand_ops):
trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
axis=1)
trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1)
if len(bgmv_shrink_ops):
trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
axis=1)
trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1)
if len(bgmv_expand_ops):
trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
axis=1)
trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1)
if len(cutlass_gemm_ops):
trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
axis=1)
trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1)
if len(gemm_ops):
trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1)
if len(rms_norm_ops):
trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1)
trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1)
if len(vocab_embed_ops):
trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum",
axis=1)
trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1)
if len(mem_ops):
trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1)
trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1)
if len(elementwise_ops):
trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
axis=1)
trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1)
if len(nccl_all_reduce_ops):
trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg(
"sum", axis=1)
trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg(
"sum", axis=1
)
if len(nccl_gather_ops):
trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum",
axis=1)
trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1)
if len(nccl_broadcast_ops):
trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg(
"sum", axis=1)
trace_df["nccl_broadcast_ops"] = trace_df[nccl_broadcast_ops].agg("sum", axis=1)
if len(nccl_other_ops):
trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum",
axis=1)
trace_df["nccl_other_ops"] = trace_df[nccl_other_ops].agg("sum", axis=1)
if len(cross_device_reduce_1stage_ops):
trace_df['cross_device_reduce_1stage_ops'] = trace_df[
cross_device_reduce_1stage_ops].agg("sum", axis=1)
trace_df["cross_device_reduce_1stage_ops"] = trace_df[
cross_device_reduce_1stage_ops
].agg("sum", axis=1)
if len(cross_device_reduce_2stage_ops):
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
cross_device_reduce_2stage_ops].agg("sum", axis=1)
trace_df["cross_device_reduce_2stage_ops"] = trace_df[
cross_device_reduce_2stage_ops
].agg("sum", axis=1)
if len(custom_ar_all_reduce_ops):
trace_df['custom_ar_all_reduce_ops'] = trace_df[
custom_ar_all_reduce_ops].agg("sum", axis=1)
trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg(
"sum", axis=1
)
if len(reduce_kernel_ops):
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
axis=1)
trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1)
trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
cutlass_gemm_ops + gemm_ops + rms_norm_ops +
vocab_embed_ops + mem_ops + elementwise_ops +
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
nccl_other_ops + cross_device_reduce_1stage_ops +
cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
reduce_kernel_ops,
axis=1,
inplace=True)
trace_df.drop(
attention_ops
+ quant_ops
+ sgmv_shrink_ops
+ sgmv_expand_ops
+ bgmv_shrink_ops
+ bgmv_expand_ops
+ cutlass_gemm_ops
+ gemm_ops
+ rms_norm_ops
+ vocab_embed_ops
+ mem_ops
+ elementwise_ops
+ nccl_all_reduce_ops
+ nccl_gather_ops
+ nccl_broadcast_ops
+ nccl_other_ops
+ cross_device_reduce_1stage_ops
+ cross_device_reduce_2stage_ops
+ custom_ar_all_reduce_ops
+ reduce_kernel_ops,
axis=1,
inplace=True,
)
return trace_df
## Data plotting utils ####
def plot_trace_df(traces_df: pd.DataFrame,
plot_metric: str,
plot_title: str,
output: Optional[Path] = None):
def plot_trace_df(
traces_df: pd.DataFrame,
plot_metric: str,
plot_title: str,
output: Optional[Path] = None,
):
def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
phase_df = traces_df.query(f'phase == "{phase}"')
descs = phase_df['phase_desc'].to_list()
descs = phase_df["phase_desc"].to_list()
assert all([desc == descs[0] for desc in descs])
return descs[0]
phases = traces_df['phase'].unique()
phases = traces_df["phase"].unique()
phase_descs = [get_phase_description(traces_df, p) for p in phases]
traces_df = traces_df.pivot_table(index="phase",
columns="name",
values=plot_metric,
aggfunc="sum")
traces_df = traces_df.pivot_table(
index="phase", columns="name", values=plot_metric, aggfunc="sum"
)
traces_df = group_trace_by_operations(traces_df)
@ -396,20 +405,19 @@ def plot_trace_df(traces_df: pd.DataFrame,
# Write the values as text on the bars
for bar in ax.patches:
if bar.get_height() != 0:
ax.text(bar.get_x() + bar.get_width() / 2,
bar.get_height() / 2 + bar.get_y(),
f"{round(bar.get_height(), 2)}",
ha='center',
color='w',
weight='bold',
size=5)
ax.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height() / 2 + bar.get_y(),
f"{round(bar.get_height(), 2)}",
ha="center",
color="w",
weight="bold",
size=5,
)
# Setup legend
handles, labels = plt.gca().get_legend_handles_labels()
legend = fig.legend(handles,
labels,
loc='center left',
bbox_to_anchor=(1, 1))
legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1))
shorten_plot_legend_strings(legend, 50)
# Setup labels and title
@ -417,21 +425,20 @@ def plot_trace_df(traces_df: pd.DataFrame,
ax.set_ylabel(plot_metric)
plt.suptitle(plot_title)
plt.savefig(output, bbox_inches='tight')
plt.savefig(output, bbox_inches="tight")
print("Created: ", output)
def main(
json_trace: Path,
output_directory: Path,
depth: int, # Fetch/Plot operations at this depth of the Json tree
plot_metric: str,
make_names_unique: bool,
top_k: int,
json_nodes_to_fold: list[str]):
json_trace: Path,
output_directory: Path,
depth: int, # Fetch/Plot operations at this depth of the Json tree
plot_metric: str,
make_names_unique: bool,
top_k: int,
json_nodes_to_fold: list[str],
):
def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:
def get_entries_and_traces(key: str):
entries_and_traces: list[tuple[Any, Any]] = []
for root in profile_json[key]["summary_stats"]:
@ -441,16 +448,14 @@ def main(
get_entries_at_depth(depth, entries_and_traces, root)
return entries_and_traces
def keep_only_top_entries(df: pd.DataFrame,
metric: str,
top_k: int = 9) -> pd.DataFrame:
df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
["name"]] = "others"
def keep_only_top_entries(
df: pd.DataFrame, metric: str, top_k: int = 9
) -> pd.DataFrame:
df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
return df
def get_phase_description(key: str) -> str:
num_running_seqs = profile_json[key]['metadata'][
'num_running_seqs']
num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"]
if num_running_seqs is not None:
return f"{key}-seqs-{num_running_seqs}"
else:
@ -466,20 +471,24 @@ def main(
# To pandas dataframe
trace_dfs = list(
map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0),
traces))
map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces)
)
# Respect top_k
if top_k:
trace_dfs = list(
map(
lambda trace_df: keep_only_top_entries(
trace_df, "cuda_time_us", top_k), trace_dfs))
trace_df, "cuda_time_us", top_k
),
trace_dfs,
)
)
# Fill in information about the step-keys
for trace_df, step_key in zip(trace_dfs, step_keys):
trace_df['phase'] = step_key
trace_df['phase_desc'] = get_phase_description(step_key)
trace_df["phase"] = step_key
trace_df["phase_desc"] = get_phase_description(step_key)
# Combine all data frames so they can be put in a single plot
traces_df = pd.concat(trace_dfs)
@ -492,17 +501,23 @@ def main(
def make_plot_title_suffix(profile_json: dict) -> str:
context = profile_json["context"]
sparsity = context.get('sparsity', None)
run_type = \
f'Run {context["num_steps"]} steps' if context['num_steps'] else \
(f'Complete {context["complete_num_requests_per_step"]} per '
f'step; Run till completion')
return (f"{context['engine_args']['model']}\n"
f"Batch={context['batch_size']}, "
f"PromptLen={context['prompt_len']}, "
f"NumGpus={context['engine_args']['tensor_parallel_size']}"
f"{', Sparsity ' + sparsity if sparsity else ''}\n"
f"Run Type: {run_type}")
sparsity = context.get("sparsity", None)
run_type = (
f"Run {context['num_steps']} steps"
if context["num_steps"]
else (
f"Complete {context['complete_num_requests_per_step']} per "
f"step; Run till completion"
)
)
return (
f"{context['engine_args']['model']}\n"
f"Batch={context['batch_size']}, "
f"PromptLen={context['prompt_len']}, "
f"NumGpus={context['engine_args']['tensor_parallel_size']}"
f"{', Sparsity ' + sparsity if sparsity else ''}\n"
f"Run Type: {run_type}"
)
profile_json = None
with open(json_trace) as f:
@ -511,14 +526,14 @@ def main(
# Get all `llm.generate.step()` profile
step_traces = list(profile_json.keys())
assert (step_traces[0] == 'context')
assert step_traces[0] == "context"
step_traces = step_traces[1:] # have only prefill and decodes
prefills = list(filter(lambda x: "prefill" in x, step_traces))
all_decodes = list(filter(lambda x: "decode" in x, step_traces))
assert len(prefills) + len(all_decodes) == len(step_traces)
assert len(prefills) == 1
decodes = all_decodes[::args.step_plot_interval]
decodes = all_decodes[:: args.step_plot_interval]
if decodes[-1] != all_decodes[-1]:
# Always have the last decode
decodes.append(all_decodes[-1])
@ -528,48 +543,63 @@ def main(
plot_title_suffix = make_plot_title_suffix(profile_json)
plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
output_directory / Path("prefill.png"))
plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix,
output_directory / Path("decode_steps.png"))
plot_trace_df(
prefill_traces,
plot_metric,
"prefill " + plot_title_suffix,
output_directory / Path("prefill.png"),
)
plot_trace_df(
decode_traces,
plot_metric,
"decodes " + plot_title_suffix,
output_directory / Path("decode_steps.png"),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--json-trace",
type=str,
required=True,
help="json trace file output by \
examples/offline_inference/profiling.py")
parser.add_argument("--output-directory",
type=str,
required=False,
help="Directory to output plots")
parser.add_argument("--level",
type=str,
default="module",
choices=["module", "kernel"])
parser.add_argument("--top-k",
type=int,
default=12,
help="Only graph the top `top_k` entries by time.")
parser.add_argument("--fold-json-node",
nargs='+',
default=['Sampler', 'LogitsProcessor'],
help='Do not plot the children of these nodes. Let, \
parser.add_argument(
"--json-trace",
type=str,
required=True,
help="json trace file output by \
examples/offline_inference/profiling.py",
)
parser.add_argument(
"--output-directory", type=str, required=False, help="Directory to output plots"
)
parser.add_argument(
"--level", type=str, default="module", choices=["module", "kernel"]
)
parser.add_argument(
"--top-k",
type=int,
default=12,
help="Only graph the top `top_k` entries by time.",
)
parser.add_argument(
"--fold-json-node",
nargs="+",
default=["Sampler", "LogitsProcessor"],
help="Do not plot the children of these nodes. Let, \
the node represent the aggregate of all its \
children')
parser.add_argument("--plot-metric",
type=str,
default="cuda_time_ms",
help='Metric to plot. some options are cuda_time_ms, \
pct_cuda_time')
children",
)
parser.add_argument(
"--plot-metric",
type=str,
default="cuda_time_ms",
help="Metric to plot. some options are cuda_time_ms, \
pct_cuda_time",
)
parser.add_argument(
"--step-plot-interval",
type=int,
default=4,
help="For every `step_plot_interval` steps, plot 1 step")
help="For every `step_plot_interval` steps, plot 1 step",
)
args = parser.parse_args()
@ -583,11 +613,19 @@ if __name__ == "__main__":
else:
raise Exception(f"Unexpected level value ({args.level})")
output_directory = args.output_directory if args.output_directory else Path(
args.json_trace).parent
output_directory = (
args.output_directory if args.output_directory else Path(args.json_trace).parent
)
if not os.path.exists(output_directory):
os.makedirs(output_directory)
main(Path(args.json_trace), output_directory, depth, args.plot_metric,
make_names_unique, args.top_k, args.fold_json_node)
main(
Path(args.json_trace),
output_directory,
depth,
args.plot_metric,
make_names_unique,
args.top_k,
args.fold_json_node,
)

View File

@ -83,9 +83,9 @@ class Target:
"""
# Allow for modest floating-point errors
epsilon = 0.000002
if (self.weighted_duration > self.Duration() + epsilon):
print('{} > {}?'.format(self.weighted_duration, self.Duration()))
assert (self.weighted_duration <= self.Duration() + epsilon)
if self.weighted_duration > self.Duration() + epsilon:
print("{} > {}?".format(self.weighted_duration, self.Duration()))
assert self.weighted_duration <= self.Duration() + epsilon
return self.weighted_duration
def DescribeTargets(self):
@ -93,10 +93,10 @@ class Target:
# Some build steps generate dozens of outputs - handle them sanely.
# The max_length was chosen so that it can fit most of the long
# single-target names, while minimizing word wrapping.
result = ', '.join(self.targets)
result = ", ".join(self.targets)
max_length = 65
if len(result) > max_length:
result = result[:max_length] + '...'
result = result[:max_length] + "..."
return result
@ -106,12 +106,13 @@ def ReadTargets(log, show_all):
The result is a list of Target objects."""
header = log.readline()
assert header == '# ninja log v5\n', \
'unrecognized ninja log version {!r}'.format(header)
assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format(
header
)
targets_dict = {}
last_end_seen = 0.0
for line in log:
parts = line.strip().split('\t')
parts = line.strip().split("\t")
if len(parts) != 5:
# If ninja.exe is rudely halted then the .ninja_log file may be
# corrupt. Silently continue.
@ -150,17 +151,17 @@ def ReadTargets(log, show_all):
def GetExtension(target, extra_patterns):
"""Return the file extension that best represents a target.
For targets that generate multiple outputs it is important to return a
consistent 'canonical' extension. Ultimately the goal is to group build steps
by type."""
For targets that generate multiple outputs it is important to return a
consistent 'canonical' extension. Ultimately the goal is to group build steps
by type."""
for output in target.targets:
if extra_patterns:
for fn_pattern in extra_patterns.split(';'):
if fnmatch.fnmatch(output, '*' + fn_pattern + '*'):
for fn_pattern in extra_patterns.split(";"):
if fnmatch.fnmatch(output, "*" + fn_pattern + "*"):
return fn_pattern
# Not a true extension, but a good grouping.
if output.endswith('type_mappings'):
extension = 'type_mappings'
if output.endswith("type_mappings"):
extension = "type_mappings"
break
# Capture two extensions if present. For example: file.javac.jar should
@ -170,26 +171,26 @@ def GetExtension(target, extra_patterns):
extension = ext2 + ext1 # Preserve the order in the file name.
if len(extension) == 0:
extension = '(no extension found)'
extension = "(no extension found)"
if ext1 in ['.pdb', '.dll', '.exe']:
extension = 'PEFile (linking)'
if ext1 in [".pdb", ".dll", ".exe"]:
extension = "PEFile (linking)"
# Make sure that .dll and .exe are grouped together and that the
# .dll.lib files don't cause these to be listed as libraries
break
if ext1 in ['.so', '.TOC']:
extension = '.so (linking)'
if ext1 in [".so", ".TOC"]:
extension = ".so (linking)"
# Attempt to identify linking, avoid identifying as '.TOC'
break
# Make sure .obj files don't get categorized as mojo files
if ext1 in ['.obj', '.o']:
if ext1 in [".obj", ".o"]:
break
# Jars are the canonical output of java targets.
if ext1 == '.jar':
if ext1 == ".jar":
break
# Normalize all mojo related outputs to 'mojo'.
if output.count('.mojom') > 0:
extension = 'mojo'
if output.count(".mojom") > 0:
extension = "mojo"
break
return extension
@ -214,8 +215,8 @@ def SummarizeEntries(entries, extra_step_types):
if target.end > latest:
latest = target.end
total_cpu_time += target.Duration()
task_start_stop_times.append((target.start, 'start', target))
task_start_stop_times.append((target.end, 'stop', target))
task_start_stop_times.append((target.start, "start", target))
task_start_stop_times.append((target.end, "stop", target))
length = latest - earliest
weighted_total = 0.0
@ -241,10 +242,10 @@ def SummarizeEntries(entries, extra_step_types):
if num_running > 0:
# Update the total weighted time up to this moment.
last_weighted_time += (time - last_time) / float(num_running)
if action_name == 'start':
if action_name == "start":
# Record the total weighted task time when this task starts.
running_tasks[target] = last_weighted_time
if action_name == 'stop':
if action_name == "stop":
# Record the change in the total weighted task time while this task
# ran.
weighted_duration = last_weighted_time - running_tasks[target]
@ -252,13 +253,16 @@ def SummarizeEntries(entries, extra_step_types):
weighted_total += weighted_duration
del running_tasks[target]
last_time = time
assert (len(running_tasks) == 0)
assert len(running_tasks) == 0
# Warn if the sum of weighted times is off by more than half a second.
if abs(length - weighted_total) > 500:
print('Warning: Possible corrupt ninja log, results may be '
'untrustworthy. Length = {:.3f}, weighted total = {:.3f}'.format(
length, weighted_total))
print(
"Warning: Possible corrupt ninja log, results may be "
"untrustworthy. Length = {:.3f}, weighted total = {:.3f}".format(
length, weighted_total
)
)
entries_by_ext = defaultdict(list)
for target in entries:
@ -266,32 +270,38 @@ def SummarizeEntries(entries, extra_step_types):
entries_by_ext[extension].append(target)
for key, values in entries_by_ext.items():
print(' Longest build steps for {}:'.format(key))
print(" Longest build steps for {}:".format(key))
values.sort(key=lambda x: x.WeightedDuration())
for target in values[-long_count:]:
print(
' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'.
format(target.WeightedDuration(), target.DescribeTargets(),
target.Duration()))
" {:8.1f} weighted s to build {} ({:.1f} s elapsed time)".format(
target.WeightedDuration(),
target.DescribeTargets(),
target.Duration(),
)
)
print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x '
'parallelism)'.format(length, total_cpu_time,
total_cpu_time * 1.0 / length))
print(' {} build steps completed, average of {:1.2f}/s'.format(
len(entries),
len(entries) / (length)))
print(
" {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x "
"parallelism)".format(length, total_cpu_time, total_cpu_time * 1.0 / length)
)
print(
" {} build steps completed, average of {:1.2f}/s".format(
len(entries), len(entries) / (length)
)
)
def main():
log_file = '.ninja_log'
log_file = ".ninja_log"
parser = argparse.ArgumentParser()
parser.add_argument('-C', dest='build_directory', help='Build directory.')
parser.add_argument("-C", dest="build_directory", help="Build directory.")
parser.add_argument(
'-s',
'--step-types',
help='semicolon separated fnmatch patterns for build-step grouping')
parser.add_argument('--log-file',
help="specific ninja log file to analyze.")
"-s",
"--step-types",
help="semicolon separated fnmatch patterns for build-step grouping",
)
parser.add_argument("--log-file", help="specific ninja log file to analyze.")
args, _extra_args = parser.parse_known_args()
if args.build_directory:
log_file = os.path.join(args.build_directory, log_file)
@ -300,17 +310,16 @@ def main():
if args.step_types:
# Make room for the extra build types.
global long_ext_count
long_ext_count += len(args.step_types.split(';'))
long_ext_count += len(args.step_types.split(";"))
try:
with open(log_file) as log:
entries = ReadTargets(log, False)
SummarizeEntries(entries, args.step_types)
except OSError:
print('Log file {!r} not found, no build summary created.'.format(
log_file))
print("Log file {!r} not found, no build summary created.".format(log_file))
return errno.ENOENT
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -38,10 +38,12 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)):
if (
not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)
):
continue
doc = inspect.cleandoc(b.value.value)
@ -61,25 +63,27 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
class ConfigValidator(ast.NodeVisitor):
def __init__(self):
...
def __init__(self): ...
def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators
decorators = [
id for d in node.decorator_list if (isinstance(d, ast.Name) and (
(id := d.id) == 'config' or id == 'dataclass')) or
(isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and
(id := d.func.id) == 'dataclass'))
id
for d in node.decorator_list
if (
isinstance(d, ast.Name)
and ((id := d.id) == "config" or id == "dataclass")
)
or (
isinstance(d, ast.Call)
and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
)
]
if set(decorators) == {'config', 'dataclass'}:
if set(decorators) == {"config", "dataclass"}:
validate_class(node)
elif set(decorators) == {'config'}:
fail(
f"Class {node.name} with config decorator must be a dataclass.",
node)
elif set(decorators) == {"config"}:
fail(f"Class {node.name} with config decorator must be a dataclass.", node)
self.generic_visit(node)
@ -93,9 +97,11 @@ def validate_class(class_node: ast.ClassDef):
# Skip ClassVar and InitVar
# see https://docs.python.org/3/library/dataclasses.html#class-variables
# and https://docs.python.org/3/library/dataclasses.html#init-only-variables
if (isinstance(stmt.annotation, ast.Subscript)
and isinstance(stmt.annotation.value, ast.Name)
and stmt.annotation.value.id in {"ClassVar", "InitVar"}):
if (
isinstance(stmt.annotation, ast.Subscript)
and isinstance(stmt.annotation.value, ast.Name)
and stmt.annotation.value.id in {"ClassVar", "InitVar"}
):
continue
if isinstance(stmt.target, ast.Name):
@ -103,22 +109,30 @@ def validate_class(class_node: ast.ClassDef):
if stmt.value is None:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a default value.", stmt)
"a default value.",
stmt,
)
if field_name not in attr_docs:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a docstring.", stmt)
"a docstring.",
stmt,
)
if isinstance(stmt.annotation, ast.Subscript) and \
isinstance(stmt.annotation.value, ast.Name) \
and stmt.annotation.value.id == "Union" and \
isinstance(stmt.annotation.slice, ast.Tuple):
if (
isinstance(stmt.annotation, ast.Subscript)
and isinstance(stmt.annotation.value, ast.Name)
and stmt.annotation.value.id == "Union"
and isinstance(stmt.annotation.slice, ast.Tuple)
):
args = stmt.annotation.slice.elts
literal_args = [
arg for arg in args
if isinstance(arg, ast.Subscript) and isinstance(
arg.value, ast.Name) and arg.value.id == "Literal"
arg
for arg in args
if isinstance(arg, ast.Subscript)
and isinstance(arg.value, ast.Name)
and arg.value.id == "Literal"
]
if len(literal_args) > 1:
fail(
@ -126,7 +140,9 @@ def validate_class(class_node: ast.ClassDef):
"use a single "
"Literal type. Please use 'Literal[Literal1, "
"Literal2]' instead of 'Union[Literal1, Literal2]'"
".", stmt)
".",
stmt,
)
def validate_ast(tree: ast.stmt):