Add required changes for github pipeline. (#2648)
This commit is contained in:
@ -124,7 +124,7 @@ JIT function arguments with |CUSTOM_TYPES|
|
||||
- ``__extract_mlir_values__``: Generate a dynamic expression for the current object.
|
||||
- ``__new_from_mlir_values__``: Create a new object from MLIR values.
|
||||
|
||||
Refer to `typing.py <https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL/base_dsl/typing.py>`__ for more details on these protocol APIs.
|
||||
Refer to `typing.py <https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL/cutlass/base_dsl/typing.py>`__ for more details on these protocol APIs.
|
||||
|
||||
Depending on different cases of the |CUSTOM_TYPES|, |DSL| provides easy ways to adopt |CUSTOM_TYPES| for JIT function arguments.
|
||||
|
||||
|
||||
310
python/CuTeDSL/prep_editable_install.py
Normal file
310
python/CuTeDSL/prep_editable_install.py
Normal file
@ -0,0 +1,310 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
"""
|
||||
CuTeDSL Development Package Setup
|
||||
|
||||
This setup script automatically downloads the nvidia-cutlass-dsl wheel,
|
||||
extracts required libraries and Python packages, and sets up the development
|
||||
environment for CuTeDSL.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
PACKAGE_NAME = "nvidia-cutlass-dsl"
|
||||
|
||||
|
||||
class CutlassDSLSetupError(Exception):
|
||||
"""Custom exception for setup errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def download_wheel(temp_dir: Path) -> Path:
|
||||
"""
|
||||
Download the nvidia-cutlass-dsl wheel to a temporary directory.
|
||||
|
||||
Args:
|
||||
temp_dir: Temporary directory path for downloading
|
||||
|
||||
Returns:
|
||||
Path to the downloaded wheel file
|
||||
|
||||
Raises:
|
||||
CutlassDSLSetupError: If download fails or wheel not found
|
||||
"""
|
||||
logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}")
|
||||
|
||||
try:
|
||||
subprocess.check_call(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pip",
|
||||
"download",
|
||||
"--no-deps",
|
||||
PACKAGE_NAME,
|
||||
"--dest",
|
||||
str(temp_dir),
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_msg = f"Failed to download {PACKAGE_NAME}: {e}"
|
||||
if e.stdout:
|
||||
error_msg += f"\nstdout: {e.stdout.decode()}"
|
||||
if e.stderr:
|
||||
error_msg += f"\nstderr: {e.stderr.decode()}"
|
||||
raise CutlassDSLSetupError(error_msg)
|
||||
|
||||
# Find the downloaded wheel file
|
||||
wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl"
|
||||
wheel_files = list(temp_dir.glob(wheel_pattern))
|
||||
if not wheel_files:
|
||||
raise CutlassDSLSetupError(
|
||||
f"No wheel file matching {wheel_pattern} found after download"
|
||||
)
|
||||
|
||||
wheel_path = wheel_files[0]
|
||||
logger.info(f"Successfully downloaded: {wheel_path.name}")
|
||||
return wheel_path
|
||||
|
||||
|
||||
def extract_version_from_wheel(wheel_path: Path) -> str:
|
||||
"""
|
||||
Extract version from wheel filename and convert to dev version.
|
||||
|
||||
Args:
|
||||
wheel_path: Path to the wheel file
|
||||
|
||||
Returns:
|
||||
Version string in format '{version}.dev0' or '{base_version}.dev{n+1}' if already has dev{n}
|
||||
|
||||
Raises:
|
||||
CutlassDSLSetupError: If version cannot be extracted from filename
|
||||
"""
|
||||
wheel_filename = wheel_path.name
|
||||
# Construct version regex from package name
|
||||
# Wheel filename format: {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl
|
||||
package_pattern = PACKAGE_NAME.replace("-", "_")
|
||||
version_regex = rf"{re.escape(package_pattern)}-([^-]+)-"
|
||||
version_match = re.match(version_regex, wheel_filename)
|
||||
|
||||
if version_match:
|
||||
version = version_match.group(1)
|
||||
|
||||
# Check if version already has .dev<n> pattern
|
||||
dev_pattern = r"^(.+)\.dev(\d+)"
|
||||
dev_match = re.match(dev_pattern, version)
|
||||
|
||||
if dev_match:
|
||||
base_version = dev_match.group(1)
|
||||
dev_number = int(dev_match.group(2))
|
||||
new_dev_number = dev_number + 1
|
||||
dev_version = f"{base_version}.dev{new_dev_number}"
|
||||
logger.info(
|
||||
f"Detected version with dev{dev_number}: {version} -> using {dev_version}"
|
||||
)
|
||||
else:
|
||||
dev_version = f"{version}.dev0"
|
||||
logger.info(f"Detected version: {version} -> using {dev_version}")
|
||||
|
||||
return dev_version
|
||||
else:
|
||||
raise CutlassDSLSetupError(
|
||||
f"Could not parse version from wheel filename: {wheel_filename}"
|
||||
)
|
||||
|
||||
|
||||
def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None:
|
||||
"""
|
||||
Extract wheel contents to specified directory.
|
||||
|
||||
Args:
|
||||
wheel_path: Path to the wheel file
|
||||
extract_dir: Directory to extract contents to
|
||||
|
||||
Raises:
|
||||
CutlassDSLSetupError: If extraction fails
|
||||
"""
|
||||
logger.info(f"Extracting wheel contents to {extract_dir}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(wheel_path, "r") as wheel_zip:
|
||||
wheel_zip.extractall(extract_dir)
|
||||
logger.info("Wheel extraction completed successfully")
|
||||
except zipfile.BadZipFile as e:
|
||||
raise CutlassDSLSetupError(f"Invalid wheel file {wheel_path}: {e}")
|
||||
except Exception as e:
|
||||
raise CutlassDSLSetupError(f"Failed to extract wheel: {e}")
|
||||
|
||||
|
||||
def copy_library_files(extract_dir: Path, package_root: Path) -> int:
|
||||
"""
|
||||
Copy .so library files from extracted wheel to package lib directory.
|
||||
|
||||
Args:
|
||||
extract_dir: Directory containing extracted wheel contents
|
||||
package_root: Root directory of the package
|
||||
|
||||
Returns:
|
||||
Number of files copied
|
||||
"""
|
||||
lib_pattern = extract_dir / "**" / "lib" / "*.so"
|
||||
so_files = [f for f in extract_dir.rglob("lib/*.so")]
|
||||
|
||||
if not so_files:
|
||||
logger.warning("No .so files found in the wheel")
|
||||
return 0
|
||||
|
||||
logger.info(f"Found {len(so_files)} .so files")
|
||||
|
||||
# Create lib directory
|
||||
lib_dir = package_root / "lib"
|
||||
lib_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Copy .so files
|
||||
copied_count = 0
|
||||
for so_file in so_files:
|
||||
dest_path = lib_dir / so_file.name
|
||||
logger.info(f"Copying {so_file.name} to {dest_path}")
|
||||
shutil.copy2(so_file, dest_path)
|
||||
copied_count += 1
|
||||
|
||||
logger.info(f"Successfully copied {copied_count} .so files to lib/")
|
||||
return copied_count
|
||||
|
||||
|
||||
def copy_python_packages(extract_dir: Path, package_root: Path) -> Tuple[int, int]:
|
||||
"""
|
||||
Copy python_packages/cutlass/ directory to local cutlass/ directory.
|
||||
Ignores conflicts with existing files.
|
||||
|
||||
Args:
|
||||
extract_dir: Directory containing extracted wheel contents
|
||||
package_root: Root directory of the package
|
||||
|
||||
Returns:
|
||||
Tuple of (files_copied, files_skipped)
|
||||
"""
|
||||
# Find source cutlass directory
|
||||
cutlass_source_dirs = list(extract_dir.rglob("python_packages/cutlass"))
|
||||
|
||||
if not cutlass_source_dirs:
|
||||
logger.warning("No python_packages/cutlass/ directory found in the wheel")
|
||||
return 0, 0
|
||||
|
||||
cutlass_source_dir = cutlass_source_dirs[0]
|
||||
cutlass_dest_dir = package_root / "cutlass"
|
||||
|
||||
logger.info(f"Found python_packages/cutlass/ directory")
|
||||
logger.info(f"Copying from {cutlass_source_dir} to {cutlass_dest_dir}")
|
||||
|
||||
copied_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
# Walk through source directory
|
||||
for src_file in cutlass_source_dir.rglob("*"):
|
||||
if src_file.is_file():
|
||||
# Calculate relative path and destination
|
||||
rel_path = src_file.relative_to(cutlass_source_dir)
|
||||
dest_file = cutlass_dest_dir / rel_path
|
||||
|
||||
# Create parent directories
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy file if it doesn't exist
|
||||
if dest_file.exists():
|
||||
skipped_count += 1
|
||||
logger.debug(f" Skipping {rel_path} (already exists)")
|
||||
else:
|
||||
shutil.copy2(src_file, dest_file)
|
||||
copied_count += 1
|
||||
logger.info(f" Copied {rel_path}")
|
||||
|
||||
logger.info(
|
||||
f"Cutlass directory update: {copied_count} files copied, {skipped_count} files skipped"
|
||||
)
|
||||
return copied_count, skipped_count
|
||||
|
||||
|
||||
def write_version_file(version: str, package_root: Path) -> None:
|
||||
"""
|
||||
Write version string to VERSION file in the package root directory.
|
||||
|
||||
Args:
|
||||
version: Version string to write
|
||||
package_root: Root directory of the package
|
||||
"""
|
||||
version_file = package_root / "VERSION.EDITABLE"
|
||||
logger.info(f"Writing version {version} to {version_file}")
|
||||
|
||||
try:
|
||||
with open(version_file, "w", encoding="utf-8") as f:
|
||||
f.write(version + "\n")
|
||||
logger.info(f"Successfully created VERSION file with version: {version}")
|
||||
except Exception as e:
|
||||
raise CutlassDSLSetupError(f"Failed to write VERSION file: {e}")
|
||||
|
||||
|
||||
def prep_editable_install() -> None:
|
||||
"""
|
||||
Set up the CuTeDSL development environment.
|
||||
|
||||
Downloads nvidia-cutlass-dsl wheel, extracts version, and copies required files.
|
||||
|
||||
Raises:
|
||||
CutlassDSLSetupError: If setup fails
|
||||
"""
|
||||
package_root = Path(__file__).parent
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
extract_dir = temp_dir / "extracted"
|
||||
|
||||
# Download and extract wheel
|
||||
wheel_path = download_wheel(temp_dir)
|
||||
version = extract_version_from_wheel(wheel_path)
|
||||
extract_wheel_contents(wheel_path, extract_dir)
|
||||
|
||||
# Copy files
|
||||
lib_files_copied = copy_library_files(extract_dir, package_root)
|
||||
py_files_copied, py_files_skipped = copy_python_packages(
|
||||
extract_dir, package_root
|
||||
)
|
||||
|
||||
# Write version file
|
||||
write_version_file(version, package_root)
|
||||
|
||||
logger.info("Setup completed successfully!")
|
||||
logger.info(
|
||||
f"Summary: {lib_files_copied} lib files, "
|
||||
f"{py_files_copied} Python files copied, "
|
||||
f"{py_files_skipped} Python files skipped"
|
||||
)
|
||||
logger.info(f"Detected upstream version: {version}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prep_editable_install()
|
||||
32
python/CuTeDSL/pyproject.toml
Normal file
32
python/CuTeDSL/pyproject.toml
Normal file
@ -0,0 +1,32 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "nvidia-cutlass-dsl"
|
||||
dynamic = ["version"]
|
||||
description = "NVIDIA CUTLASS Python DSL"
|
||||
authors = [
|
||||
{name = "NVIDIA Corporation"},
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["cutlass"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {file = "VERSION.EDITABLE"}
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
nvidia_cutlass_dsl = ["lib/**/*"]
|
||||
Reference in New Issue
Block a user