Add required changes for github pipeline. (#2648)

This commit is contained in:
Junkai-Wu
2025-09-18 10:22:45 +08:00
committed by GitHub
parent 7817e47154
commit 8825e8be4f
32 changed files with 343 additions and 1 deletions

View File

@ -124,7 +124,7 @@ JIT function arguments with |CUSTOM_TYPES|
- ``__extract_mlir_values__``: Generate a dynamic expression for the current object.
- ``__new_from_mlir_values__``: Create a new object from MLIR values.
Refer to `typing.py <https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL/base_dsl/typing.py>`__ for more details on these protocol APIs.
Refer to `typing.py <https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL/cutlass/base_dsl/typing.py>`__ for more details on these protocol APIs.
Depending on different cases of the |CUSTOM_TYPES|, |DSL| provides easy ways to adopt |CUSTOM_TYPES| for JIT function arguments.

View File

@ -0,0 +1,310 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
CuTeDSL Development Package Setup
This setup script automatically downloads the nvidia-cutlass-dsl wheel,
extracts required libraries and Python packages, and sets up the development
environment for CuTeDSL.
"""
import subprocess
import sys
import shutil
import tempfile
import zipfile
import re
from pathlib import Path
from typing import Optional, Tuple, List
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
# Constants
PACKAGE_NAME = "nvidia-cutlass-dsl"
class CutlassDSLSetupError(Exception):
"""Custom exception for setup errors."""
pass
def download_wheel(temp_dir: Path) -> Path:
"""
Download the nvidia-cutlass-dsl wheel to a temporary directory.
Args:
temp_dir: Temporary directory path for downloading
Returns:
Path to the downloaded wheel file
Raises:
CutlassDSLSetupError: If download fails or wheel not found
"""
logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}")
try:
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"download",
"--no-deps",
PACKAGE_NAME,
"--dest",
str(temp_dir),
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
error_msg = f"Failed to download {PACKAGE_NAME}: {e}"
if e.stdout:
error_msg += f"\nstdout: {e.stdout.decode()}"
if e.stderr:
error_msg += f"\nstderr: {e.stderr.decode()}"
raise CutlassDSLSetupError(error_msg)
# Find the downloaded wheel file
wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl"
wheel_files = list(temp_dir.glob(wheel_pattern))
if not wheel_files:
raise CutlassDSLSetupError(
f"No wheel file matching {wheel_pattern} found after download"
)
wheel_path = wheel_files[0]
logger.info(f"Successfully downloaded: {wheel_path.name}")
return wheel_path
def extract_version_from_wheel(wheel_path: Path) -> str:
"""
Extract version from wheel filename and convert to dev version.
Args:
wheel_path: Path to the wheel file
Returns:
Version string in format '{version}.dev0' or '{base_version}.dev{n+1}' if already has dev{n}
Raises:
CutlassDSLSetupError: If version cannot be extracted from filename
"""
wheel_filename = wheel_path.name
# Construct version regex from package name
# Wheel filename format: {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl
package_pattern = PACKAGE_NAME.replace("-", "_")
version_regex = rf"{re.escape(package_pattern)}-([^-]+)-"
version_match = re.match(version_regex, wheel_filename)
if version_match:
version = version_match.group(1)
# Check if version already has .dev<n> pattern
dev_pattern = r"^(.+)\.dev(\d+)"
dev_match = re.match(dev_pattern, version)
if dev_match:
base_version = dev_match.group(1)
dev_number = int(dev_match.group(2))
new_dev_number = dev_number + 1
dev_version = f"{base_version}.dev{new_dev_number}"
logger.info(
f"Detected version with dev{dev_number}: {version} -> using {dev_version}"
)
else:
dev_version = f"{version}.dev0"
logger.info(f"Detected version: {version} -> using {dev_version}")
return dev_version
else:
raise CutlassDSLSetupError(
f"Could not parse version from wheel filename: {wheel_filename}"
)
def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None:
"""
Extract wheel contents to specified directory.
Args:
wheel_path: Path to the wheel file
extract_dir: Directory to extract contents to
Raises:
CutlassDSLSetupError: If extraction fails
"""
logger.info(f"Extracting wheel contents to {extract_dir}")
try:
with zipfile.ZipFile(wheel_path, "r") as wheel_zip:
wheel_zip.extractall(extract_dir)
logger.info("Wheel extraction completed successfully")
except zipfile.BadZipFile as e:
raise CutlassDSLSetupError(f"Invalid wheel file {wheel_path}: {e}")
except Exception as e:
raise CutlassDSLSetupError(f"Failed to extract wheel: {e}")
def copy_library_files(extract_dir: Path, package_root: Path) -> int:
"""
Copy .so library files from extracted wheel to package lib directory.
Args:
extract_dir: Directory containing extracted wheel contents
package_root: Root directory of the package
Returns:
Number of files copied
"""
lib_pattern = extract_dir / "**" / "lib" / "*.so"
so_files = [f for f in extract_dir.rglob("lib/*.so")]
if not so_files:
logger.warning("No .so files found in the wheel")
return 0
logger.info(f"Found {len(so_files)} .so files")
# Create lib directory
lib_dir = package_root / "lib"
lib_dir.mkdir(exist_ok=True)
# Copy .so files
copied_count = 0
for so_file in so_files:
dest_path = lib_dir / so_file.name
logger.info(f"Copying {so_file.name} to {dest_path}")
shutil.copy2(so_file, dest_path)
copied_count += 1
logger.info(f"Successfully copied {copied_count} .so files to lib/")
return copied_count
def copy_python_packages(extract_dir: Path, package_root: Path) -> Tuple[int, int]:
"""
Copy python_packages/cutlass/ directory to local cutlass/ directory.
Ignores conflicts with existing files.
Args:
extract_dir: Directory containing extracted wheel contents
package_root: Root directory of the package
Returns:
Tuple of (files_copied, files_skipped)
"""
# Find source cutlass directory
cutlass_source_dirs = list(extract_dir.rglob("python_packages/cutlass"))
if not cutlass_source_dirs:
logger.warning("No python_packages/cutlass/ directory found in the wheel")
return 0, 0
cutlass_source_dir = cutlass_source_dirs[0]
cutlass_dest_dir = package_root / "cutlass"
logger.info(f"Found python_packages/cutlass/ directory")
logger.info(f"Copying from {cutlass_source_dir} to {cutlass_dest_dir}")
copied_count = 0
skipped_count = 0
# Walk through source directory
for src_file in cutlass_source_dir.rglob("*"):
if src_file.is_file():
# Calculate relative path and destination
rel_path = src_file.relative_to(cutlass_source_dir)
dest_file = cutlass_dest_dir / rel_path
# Create parent directories
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Copy file if it doesn't exist
if dest_file.exists():
skipped_count += 1
logger.debug(f" Skipping {rel_path} (already exists)")
else:
shutil.copy2(src_file, dest_file)
copied_count += 1
logger.info(f" Copied {rel_path}")
logger.info(
f"Cutlass directory update: {copied_count} files copied, {skipped_count} files skipped"
)
return copied_count, skipped_count
def write_version_file(version: str, package_root: Path) -> None:
"""
Write version string to VERSION file in the package root directory.
Args:
version: Version string to write
package_root: Root directory of the package
"""
version_file = package_root / "VERSION.EDITABLE"
logger.info(f"Writing version {version} to {version_file}")
try:
with open(version_file, "w", encoding="utf-8") as f:
f.write(version + "\n")
logger.info(f"Successfully created VERSION file with version: {version}")
except Exception as e:
raise CutlassDSLSetupError(f"Failed to write VERSION file: {e}")
def prep_editable_install() -> None:
"""
Set up the CuTeDSL development environment.
Downloads nvidia-cutlass-dsl wheel, extracts version, and copies required files.
Raises:
CutlassDSLSetupError: If setup fails
"""
package_root = Path(__file__).parent
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
extract_dir = temp_dir / "extracted"
# Download and extract wheel
wheel_path = download_wheel(temp_dir)
version = extract_version_from_wheel(wheel_path)
extract_wheel_contents(wheel_path, extract_dir)
# Copy files
lib_files_copied = copy_library_files(extract_dir, package_root)
py_files_copied, py_files_skipped = copy_python_packages(
extract_dir, package_root
)
# Write version file
write_version_file(version, package_root)
logger.info("Setup completed successfully!")
logger.info(
f"Summary: {lib_files_copied} lib files, "
f"{py_files_copied} Python files copied, "
f"{py_files_skipped} Python files skipped"
)
logger.info(f"Detected upstream version: {version}")
if __name__ == "__main__":
prep_editable_install()

View File

@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "nvidia-cutlass-dsl"
dynamic = ["version"]
description = "NVIDIA CUTLASS Python DSL"
authors = [
{name = "NVIDIA Corporation"},
]
[tool.setuptools]
packages = ["cutlass"]
include-package-data = true
[tool.setuptools.dynamic]
version = {file = "VERSION.EDITABLE"}
[tool.setuptools.package-data]
nvidia_cutlass_dsl = ["lib/**/*"]