Release v4.0.0 (#2294)
This commit is contained in:
19
python/CuTeDSL/base_dsl/utils/__init__.py
Normal file
19
python/CuTeDSL/base_dsl/utils/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from . import stacktrace
|
||||
from . import logger
|
||||
from . import timer
|
||||
__all__ = [
|
||||
"logger",
|
||||
"timer",
|
||||
"stacktrace",
|
||||
]
|
||||
80
python/CuTeDSL/base_dsl/utils/logger.py
Normal file
80
python/CuTeDSL/base_dsl/utils/logger.py
Normal file
@ -0,0 +1,80 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides logging helper functions
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = None
|
||||
|
||||
|
||||
def log():
|
||||
return logger
|
||||
|
||||
|
||||
def setup_log(
|
||||
name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1
|
||||
):
|
||||
"""Set up and configure a logger with console and/or file handlers.
|
||||
|
||||
:param name: Name of the logger to create
|
||||
:type name: str
|
||||
:param log_to_console: Whether to enable logging to console, defaults to False
|
||||
:type log_to_console: bool, optional
|
||||
:param log_to_file: Whether to enable logging to file, defaults to False
|
||||
:type log_to_file: bool, optional
|
||||
:param log_file_path: Path to the log file, required if log_to_file is True
|
||||
:type log_file_path: str, optional
|
||||
:param log_level: Logging level to set, defaults to 1
|
||||
:type log_level: int, optional
|
||||
:raises ValueError: If log_to_file is True but log_file_path is not provided
|
||||
:return: Configured logger instance
|
||||
:rtype: logging.Logger
|
||||
"""
|
||||
# Create a custom logger
|
||||
global logger
|
||||
logger = logging.getLogger(name)
|
||||
if log_to_console or log_to_file:
|
||||
logger.setLevel(log_level)
|
||||
else:
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
# Clear existing handlers to prevent duplicate logs
|
||||
if logger.hasHandlers():
|
||||
logger.handlers.clear()
|
||||
|
||||
# Define formatter
|
||||
formatter = logging.Formatter(
|
||||
f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s"
|
||||
)
|
||||
|
||||
# Add console handler if enabled
|
||||
if log_to_console:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(log_level)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# Add file handler if enabled
|
||||
if log_to_file:
|
||||
if not log_file_path:
|
||||
raise ValueError("log_file_path must be provided when enable_file is True")
|
||||
file_handler = logging.FileHandler(log_file_path)
|
||||
file_handler.setLevel(log_level)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = setup_log("generic")
|
||||
165
python/CuTeDSL/base_dsl/utils/stacktrace.py
Normal file
165
python/CuTeDSL/base_dsl/utils/stacktrace.py
Normal file
@ -0,0 +1,165 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides stacktrace helper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def walk_to_top_module(start_path):
|
||||
"""
|
||||
Walk up from the start_path to find the top-level Python module.
|
||||
|
||||
:param start_path: The path to start from.
|
||||
:return: The path of the top-level module.
|
||||
"""
|
||||
current_path = start_path
|
||||
|
||||
while True:
|
||||
# Check if we are at the root directory
|
||||
if os.path.dirname(current_path) == current_path:
|
||||
break
|
||||
|
||||
# Check for __init__.py
|
||||
init_file_path = os.path.join(current_path, "__init__.py")
|
||||
if os.path.isfile(init_file_path):
|
||||
# If __init__.py exists, move up one level
|
||||
current_path = os.path.dirname(current_path)
|
||||
else:
|
||||
# If no __init__.py, we are not in a module; stop
|
||||
break
|
||||
|
||||
# If we reached the root without finding a module, return None
|
||||
if os.path.dirname(current_path) == current_path and not os.path.isfile(
|
||||
os.path.join(current_path, "__init__.py")
|
||||
):
|
||||
return None
|
||||
|
||||
# Return the path of the top-level module
|
||||
return current_path
|
||||
|
||||
|
||||
def _filter_internal_frames(traceback, internal_path):
|
||||
"""
|
||||
Filter out stack frames from the traceback that belong to the specified module path.
|
||||
|
||||
This function removes stack frames from the traceback whose file paths start with
|
||||
the given prefix_path, effectively hiding internal implementation details from
|
||||
the error traceback shown to users.
|
||||
"""
|
||||
iter_prev = None
|
||||
iter_tb = traceback
|
||||
while iter_tb is not None:
|
||||
if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith(
|
||||
internal_path
|
||||
):
|
||||
if iter_tb.tb_next:
|
||||
if iter_prev:
|
||||
iter_prev.tb_next = iter_tb.tb_next
|
||||
else:
|
||||
traceback = iter_tb.tb_next
|
||||
else:
|
||||
iter_prev = iter_tb
|
||||
iter_tb = iter_tb.tb_next
|
||||
return traceback
|
||||
|
||||
|
||||
_generated_function_names = re.compile(
|
||||
r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$"
|
||||
)
|
||||
|
||||
|
||||
def _filter_duplicated_frames(traceback):
|
||||
"""
|
||||
Filter out duplicated stack frames from the traceback.
|
||||
The function filters out consecutive frames that are in the same file and have the same line number.
|
||||
In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame.
|
||||
"""
|
||||
iter_prev = None
|
||||
iter_tb = traceback
|
||||
while iter_tb is not None:
|
||||
skip_current = False
|
||||
skip_next = False
|
||||
if iter_tb.tb_next:
|
||||
current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename)
|
||||
next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename)
|
||||
# if in the same file, check if the line number is the same
|
||||
if current_filename == next_filename:
|
||||
current_lineno = iter_tb.tb_lineno
|
||||
next_lineno = iter_tb.tb_next.tb_lineno
|
||||
if current_lineno == next_lineno:
|
||||
# Same file and line number, check name, if current is generated, skip current, otherwise skip next
|
||||
name = iter_tb.tb_frame.f_code.co_name
|
||||
is_generated = bool(_generated_function_names.match(name))
|
||||
if is_generated:
|
||||
# Skip current
|
||||
skip_current = True
|
||||
else:
|
||||
# Skip next if it's generated, otherwise keep both
|
||||
next_name = iter_tb.tb_next.tb_frame.f_code.co_name
|
||||
skip_next = bool(_generated_function_names.match(next_name))
|
||||
if skip_current:
|
||||
if iter_prev:
|
||||
iter_prev.tb_next = iter_tb.tb_next
|
||||
else:
|
||||
traceback = iter_tb.tb_next
|
||||
elif skip_next:
|
||||
# if next is last frame, don't skip
|
||||
if iter_tb.tb_next.tb_next:
|
||||
iter_tb.tb_next = iter_tb.tb_next.tb_next
|
||||
iter_prev = iter_tb
|
||||
else:
|
||||
iter_prev = iter_tb
|
||||
iter_tb = iter_tb.tb_next
|
||||
|
||||
return traceback
|
||||
|
||||
|
||||
def filter_stackframe(traceback, prefix_path):
|
||||
"""
|
||||
Filter out stack frames from the traceback that belong to the specified module path.
|
||||
|
||||
This function removes stack frames from the traceback whose file paths start with
|
||||
the given prefix_path, effectively hiding internal implementation details from
|
||||
the error traceback shown to users.
|
||||
|
||||
:param traceback: The traceback object to filter.
|
||||
:param prefix_path: The path prefix to filter out from the traceback.
|
||||
:return: The filtered traceback with internal frames removed.
|
||||
"""
|
||||
# Step 1: filter internal frames
|
||||
traceback = _filter_internal_frames(traceback, prefix_path)
|
||||
|
||||
# Step 2: consolidate duplicated frames
|
||||
return _filter_duplicated_frames(traceback)
|
||||
|
||||
|
||||
def filter_exception(value, module_dir):
|
||||
"""
|
||||
Filter out internal implementation details from exception traceback.
|
||||
|
||||
This function recursively processes an exception and its cause chain,
|
||||
removing stack frames that belong to the specified module directory.
|
||||
This helps to present cleaner error messages to users by hiding
|
||||
implementation details.
|
||||
|
||||
:param value: The exception object to filter.
|
||||
:param module_dir: The module directory path to filter out from tracebacks.
|
||||
:return: The filtered exception with internal frames removed.
|
||||
"""
|
||||
if hasattr(value, "__cause__") and value.__cause__:
|
||||
filter_exception(value.__cause__, module_dir)
|
||||
|
||||
if hasattr(value, "__traceback__"):
|
||||
filter_stackframe(value.__traceback__, module_dir)
|
||||
56
python/CuTeDSL/base_dsl/utils/timer.py
Normal file
56
python/CuTeDSL/base_dsl/utils/timer.py
Normal file
@ -0,0 +1,56 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides a timing helper functions
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
from .logger import log
|
||||
|
||||
|
||||
# TODO: revisit this part when mlir timing manager is ready for pybind.
|
||||
def timer(*dargs, **kwargs):
|
||||
enable = kwargs.get("enable", True)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def func_wrapper(*args, **kwargs):
|
||||
if not enable:
|
||||
return func(*args, **kwargs)
|
||||
from time import time
|
||||
|
||||
start = time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time()
|
||||
|
||||
# Convert time from seconds to us
|
||||
spend_us = (end - start) * 1e6
|
||||
|
||||
# Determine the function type and format the log message
|
||||
if hasattr(func, "__name__"):
|
||||
func_name = func.__name__
|
||||
log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs"
|
||||
elif "CFunctionType" in str(type(func)):
|
||||
log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs"
|
||||
else:
|
||||
log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs"
|
||||
|
||||
log().info(log_message)
|
||||
|
||||
return result
|
||||
|
||||
return func_wrapper
|
||||
|
||||
if len(dargs) == 1 and callable(dargs[0]):
|
||||
return decorator(dargs[0])
|
||||
else:
|
||||
return decorator
|
||||
Reference in New Issue
Block a user