diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py
index 583de7dd..ebda2ff2 100644
--- a/python/cutlass/__init__.py
+++ b/python/cutlass/__init__.py
@@ -29,7 +29,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 #################################################################################################
-
 import logging
 import os
 import sys
@@ -143,6 +142,7 @@ from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
 from cutlass.op.gemm_grouped import GroupedGemm
 from cutlass.op.op import OperationBase
 from cutlass.backend.evt.ir.tensor import Tensor
+from cutlass.utils.lazy_import import lazy_import

 this.memory_pool = None

@@ -156,10 +156,33 @@ def get_memory_pool():
     return this.memory_pool


-from cuda import cuda, cudart
+base_cuda = lazy_import("cuda")
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")

 this._device_id = None
+this._nvcc_version = None
+
+def check_cuda_versions():
+    # Strip any additional information from the CUDA version
+    _cuda_version = base_cuda.__version__.split("rc")[0]
+    # Check that Python CUDA version exceeds NVCC version
+    this._nvcc_version = nvcc_version()
+    _cuda_list = _cuda_version.split('.')
+    _nvcc_list = this._nvcc_version.split('.')
+    for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
+        if int(val_cuda) < int(val_nvcc):
+            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")
+
+    if len(_nvcc_list) > len(_cuda_list):
+        if len(_nvcc_list) != len(_cuda_list) + 1:
+            raise Exception(f"Malformatted NVCC version of {this._nvcc_version}")
+        if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
+            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")
+

 def initialize_cuda_context():
+    check_cuda_versions()
+
     if this._device_id is not None:
         return
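For reference, the comparison performed by the new check_cuda_versions() can be restated as a standalone sketch. The sample versions below ("12.4.0" for cuda-python, "12.4" for nvcc) are illustrative assumptions, not values taken from this change:

# Standalone restatement of the component-wise check in check_cuda_versions().
def cuda_python_at_least_nvcc(cuda_version: str, nvcc_version: str) -> bool:
    cuda_parts = cuda_version.split("rc")[0].split(".")
    nvcc_parts = nvcc_version.split(".")
    # Each shared component of the cuda-python version must be at least as large
    # as the corresponding nvcc component.
    for c, n in zip(cuda_parts, nvcc_parts):
        if int(c) < int(n):
            return False
    # nvcc may carry one extra trailing component, but only if it is zero
    # (e.g. nvcc "12.4.0" against cuda-python "12.4").
    if len(nvcc_parts) > len(cuda_parts):
        if len(nvcc_parts) != len(cuda_parts) + 1:
            raise ValueError(f"Malformatted NVCC version of {nvcc_version}")
        if nvcc_parts[:-1] == cuda_parts and int(nvcc_parts[-1]) != 0:
            return False
    return True

assert cuda_python_at_least_nvcc("12.4.0", "12.4")    # matching versions pass
assert not cuda_python_at_least_nvcc("12.3", "12.4")  # older cuda-python fails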
diff --git a/python/cutlass/backend/arguments.py b/python/cutlass/backend/arguments.py
index eb31b762..7c2664e0 100644
--- a/python/cutlass/backend/arguments.py
+++ b/python/cutlass/backend/arguments.py
@@ -33,7 +33,10 @@
 from math import prod
 from typing import Union

-from cuda import cuda, cudart
+from cutlass.utils.lazy_import import lazy_import
+
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")
 import numpy as np

 import cutlass
diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py
index 43750d45..b4715602 100644
--- a/python/cutlass/backend/compiler.py
+++ b/python/cutlass/backend/compiler.py
@@ -37,7 +37,10 @@
 import sqlite3
 import subprocess
 import tempfile

-from cuda import cuda, nvrtc
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")
+nvrtc = lazy_import("cuda.nvrtc")
 from cutlass_library import SubstituteTemplate

 import cutlass
diff --git a/python/cutlass/backend/conv2d_operation.py b/python/cutlass/backend/conv2d_operation.py
index bf6e5754..a261ce90 100644
--- a/python/cutlass/backend/conv2d_operation.py
+++ b/python/cutlass/backend/conv2d_operation.py
@@ -29,11 +29,13 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 #################################################################################################
+from __future__ import annotations

 import ctypes
 from typing import Union

-from cuda import cuda
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
 from cutlass_library import SubstituteTemplate
 import numpy as np
diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass/backend/evt/epilogue.py
index 85c11bea..58bd5769 100644
--- a/python/cutlass/backend/evt/epilogue.py
+++ b/python/cutlass/backend/evt/epilogue.py
@@ -36,7 +36,8 @@ Epilogue Visitor interface for compiling, and running visitor-based epilogue.

 import ctypes

-from cuda import cuda
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
 from cutlass_library import DataType
 import numpy as np
diff --git a/python/cutlass/backend/frontend.py b/python/cutlass/backend/frontend.py
index fe05582d..c1fb97c3 100644
--- a/python/cutlass/backend/frontend.py
+++ b/python/cutlass/backend/frontend.py
@@ -29,8 +29,10 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 #################################################################################################
+from __future__ import annotations

-from cuda import cuda
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
 import numpy as np

 from cutlass.backend.memory_manager import device_mem_alloc, todevice
diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py
index f9d64148..0305abd5 100644
--- a/python/cutlass/backend/gemm_operation.py
+++ b/python/cutlass/backend/gemm_operation.py
@@ -29,12 +29,15 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 #################################################################################################
+from __future__ import annotations

 import copy
 import ctypes
 import enum

-from cuda import cuda, cudart
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")
 from cutlass_library import SubstituteTemplate
 import numpy as np
diff --git a/python/cutlass/backend/memory_manager.py b/python/cutlass/backend/memory_manager.py
index 414af64d..9d7daf17 100644
--- a/python/cutlass/backend/memory_manager.py
+++ b/python/cutlass/backend/memory_manager.py
@@ -34,11 +34,12 @@ import numpy as np

 import cutlass
 from cutlass.utils.datatypes import is_numpy_tensor
+from cutlass.utils.lazy_import import lazy_import

 if cutlass.use_rmm:
     import rmm
 else:
-    from cuda import cudart
+    cudart = lazy_import("cuda.cudart")


 class PoolMemoryManager:
diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py
index 7694941c..5b5400df 100644
--- a/python/cutlass/backend/operation.py
+++ b/python/cutlass/backend/operation.py
@@ -31,16 +31,17 @@
 #################################################################################################

 import ctypes
-
-from cuda import __version__, cuda
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")

 from cutlass.backend.utils.device import device_cc

-_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]

 _supports_cluster_launch = None


 def supports_cluster_launch():
+    from cuda import __version__
+    _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
     global _supports_cluster_launch
     if _supports_cluster_launch is None:
         major, minor = _version_splits[0], _version_splits[1]
@@ -79,10 +80,12 @@ class ExecutableOperation:
     def plan(self, arguments):
         raise NotImplementedError()

-    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=cuda.CUstream(0)):
+    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
         raise NotImplementedError()

-    def run_with_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)):
+    def run_with_clusters(self, launch_config, kernel_params, stream=None):
+        if not stream:
+            stream = cuda.CUstream(0)
         if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
             attr = cuda.CUlaunchAttribute()
             attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
@@ -110,7 +113,9 @@ class ExecutableOperation:
             config, f=self.kernel, kernelParams=kernel_params, extra=0)
         return err

-    def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)):
+    def run_without_clusters(self, launch_config, kernel_params, stream=None):
+        if not stream:
+            stream = cuda.CUstream(0)
         err, = cuda.cuLaunchKernel(
             self.kernel,
             launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
@@ -122,7 +127,9 @@ class ExecutableOperation:

         return err

-    def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)):
+    def run(self, host_workspace, device_workspace, launch_config, stream=None):
+        if not stream:
+            stream = cuda.CUstream(0)
         cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
         packed = (ctypes.c_void_p * 1)()
         packed[0] = ctypes.addressof(cArg)
diff --git a/python/cutlass/backend/reduction_operation.py b/python/cutlass/backend/reduction_operation.py
index 3aec9765..559d51c3 100644
--- a/python/cutlass/backend/reduction_operation.py
+++ b/python/cutlass/backend/reduction_operation.py
@@ -29,11 +29,14 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 #################################################################################################
+from __future__ import annotations

 import ctypes
 from typing import Union

-from cuda import cuda, cudart
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")
 import numpy as np

 from cutlass_library import (
diff --git a/python/cutlass/backend/utils/device.py b/python/cutlass/backend/utils/device.py
index 16c865b4..4f7620d9 100644
--- a/python/cutlass/backend/utils/device.py
+++ b/python/cutlass/backend/utils/device.py
@@ -33,8 +33,11 @@
 """
 Utility functions for interacting with the device
 """
+from __future__ import annotations

-from cuda import cuda, cudart
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")

 import cutlass
 from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py
index aab798b1..9c6f0b39 100644
--- a/python/cutlass/library_defaults.py
+++ b/python/cutlass/library_defaults.py
@@ -37,7 +37,6 @@ Classes containing valid operations for a given compute capability and data type
 from itertools import combinations_with_replacement
 import logging

-from cuda import __version__
 import cutlass_library
 from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode

@@ -48,23 +47,6 @@ from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op

 _generator_ccs = [50, 60, 61, 70, 75, 80, 90]

-# Strip any additional information from the CUDA version
-_cuda_version = __version__.split("rc")[0]
-
-# Check that Python CUDA version exceeds NVCC version
-_nvcc_version = cutlass.nvcc_version()
-_cuda_list = _cuda_version.split('.')
-_nvcc_list = _nvcc_version.split('.')
-for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
-    if int(val_cuda) < int(val_nvcc):
-        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")
-
-if len(_nvcc_list) > len(_cuda_list):
-    if len(_nvcc_list) != len(_cuda_list) + 1:
-        raise Exception(f"Malformatted NVCC version of {_nvcc_version}")
-    if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
-        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")
-

 class KernelsForDataType:
     """
@@ -292,7 +274,7 @@ class ArchOptions:
             ]
             manifest_args = cutlass_library.generator.define_parser().parse_args(args)
             manifest = cutlass_library.manifest.Manifest(manifest_args)
-            generate_function(manifest, _nvcc_version)
+            generate_function(manifest, cutlass._nvcc_version)

             if operation_kind not in manifest.operations:
                 # No kernels generated for this architecture, this could be because the CUDA
diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py
index c9fd8f9a..0e8366ab 100644
--- a/python/cutlass/op/conv.py
+++ b/python/cutlass/op/conv.py
@@ -111,8 +111,11 @@ args.sync()
 """

-
-from cuda import cuda
+from __future__ import annotations
+from typing import Optional
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")
 from cutlass_library import (
     ConvKind,
     ConvMode,
@@ -735,7 +738,7 @@ class Conv2d(OperationBase):
             alpha=None, beta=None,
             split_k=("serial", 1), sync: bool = True,
             print_module: bool = False,
-            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
+            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -768,6 +771,8 @@ class Conv2d(OperationBase):
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.Conv2dArguments
         """
+        if not stream:
+            stream = cuda.CUstream(0)
         super().run_setup()

         A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
@@ -926,7 +931,10 @@ class Conv2dFprop(Conv2d):
         self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
         stride=(1, 1), padding=(0, 0), dilation=(1, 1),
         split_k=("serial", 1), sync: bool = True, print_module: bool = False,
-        stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
+        stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
+
+        if not stream:
+            stream = cuda.CUstream(0)

         A, B, D = input, weight, output
         return super().run(
@@ -951,8 +959,11 @@ class Conv2dDgrad(Conv2d):
     def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1),
             split_k=("serial", 1), sync: bool = True, print_module: bool = False,
-            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
+            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
         #
+        if not stream:
+            stream = cuda.CUstream(0)
+
         A, B, D = grad_output, weight, grad_input
         return super().run(
             A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
@@ -976,8 +987,10 @@ class Conv2dWgrad(Conv2d):
     def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1),
             split_k=("serial", 1), sync: bool = True, print_module: bool = False,
-            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
-        #
+            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
+        if not stream:
+            stream = cuda.CUstream(0)
+
         A, B, D = grad_output, input, grad_weight
         return super().run(
             A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py
index 9d4518e7..786c565b 100644
--- a/python/cutlass/op/gemm.py
+++ b/python/cutlass/op/gemm.py
@@ -113,10 +113,12 @@ args.sync()
 """

-
+from __future__ import annotations
+from typing import Optional
 from math import prod

-from cuda import cuda
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
 from cutlass_library import (
     DataType,
     DataTypeSize,
@@ -623,7 +625,7 @@ class Gemm(OperationBase):
     def run(self, A=None, B=None, C=None, D=None,
             alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
-            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments:
+            stream: Optional[cuda.CUstream] = None) -> GemmArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -652,6 +654,8 @@ class Gemm(OperationBase):
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmArguments
         """
+        if not stream:
+            stream = cuda.CUstream(0)
         super().run_setup()

         A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
         B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py
index c68747bc..da2fc8b9 100644
--- a/python/cutlass/op/gemm_grouped.py
+++ b/python/cutlass/op/gemm_grouped.py
@@ -50,10 +50,12 @@ plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
 plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
 """
-
+from __future__ import annotations
+from typing import Optional
 from cutlass_library import DataTypeSize

-from cuda import cuda
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
 from cutlass.backend.gemm_operation import (
     GemmGroupedArguments,
     GemmOperationGrouped,
@@ -196,7 +198,7 @@ class GroupedGemm(Gemm):
     def run(self, A, B, C, D,
             alpha=None, beta=None, sync: bool = True, print_module: bool = False,
-            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmGroupedArguments:
+            stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
         """
         Runs the kernel currently specified.

@@ -225,6 +227,9 @@ class GroupedGemm(Gemm):
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmGroupedArguments
         """
+        if not stream:
+            stream = cuda.CUstream(0)
+
         super().run_setup()

         if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
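With run() now declared as stream: Optional[cuda.CUstream] = None, existing call sites keep working: omitting the argument falls back to the default stream inside run(), and an explicitly created stream is forwarded as before. A brief usage sketch; the plan configuration and the A/B/C/D operands are placeholders, not part of this change:

# Hypothetical call site: omitting `stream` resolves to cuda.CUstream(0) inside run().
plan = cutlass.op.Gemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
args = plan.run(A, B, C, D)

# Passing an explicitly created stream still works; cuda-python returns (err, stream).
err, user_stream = cuda.cuStreamCreate(0)
args = plan.run(A, B, C, D, stream=user_stream)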
diff --git a/python/cutlass/utils/lazy_import.py b/python/cutlass/utils/lazy_import.py
new file mode 100644
index 00000000..28ba6546
--- /dev/null
+++ b/python/cutlass/utils/lazy_import.py
@@ -0,0 +1,11 @@
+import importlib
+from typing import Any
+
+def lazy_import(mod_name: str) -> Any:
+    class Lazy:
+        def __getattr__(self, name: str) -> Any:
+            module = importlib.import_module(mod_name)
+            return getattr(module, name)
+
+    return Lazy()
+
diff --git a/python/cutlass/utils/profiler.py b/python/cutlass/utils/profiler.py
index 87369670..155c1d35 100644
--- a/python/cutlass/utils/profiler.py
+++ b/python/cutlass/utils/profiler.py
@@ -37,7 +37,9 @@ Profiler based on the cuda events

 import re
 import subprocess

-from cuda import cuda, cudart
+from cutlass.utils.lazy_import import lazy_import
+cuda = lazy_import("cuda.cuda")
+cudart = lazy_import("cuda.cudart")
 import numpy as np

 from cutlass import CUTLASS_PATH
@@ -54,18 +56,27 @@ class GpuTimer:
             cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
         ]

-    def start(self, stream=cuda.CUstream(0)):
+    def start(self, stream=None):
+        if not stream:
+            stream = cuda.CUstream(0)
+
         (err,) = cuda.cuEventRecord(self.events[0], stream)
         if err != cuda.CUresult.CUDA_SUCCESS:
             raise RuntimeError(f"CUDA Error {str(err)}")

-    def stop(self, stream=cuda.CUstream(0)):
+    def stop(self, stream=None):
+        if not stream:
+            stream = cuda.CUstream(0)
+
         (err,) = cuda.cuEventRecord(self.events[1], stream)
         if err != cuda.CUresult.CUDA_SUCCESS:
             raise RuntimeError(f"CUDA Error {str(err)}")
         pass

-    def stop_and_wait(self, stream=cuda.CUstream(0)):
+    def stop_and_wait(self, stream=None):
+        if not stream:
+            stream = cuda.CUstream(0)
+
         self.stop(stream)
         if stream:
             (err,) = cuda.cuStreamSynchronize(stream)
@@ -182,4 +193,3 @@ class CUDAEventProfiler:
                 flops_ += m * n * batch_count * 2

         return flops_
-
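For context, lazy_import() returns a small proxy whose attribute lookups trigger the real import, so importing cutlass itself no longer requires the cuda-python bindings to be importable eagerly. A minimal sketch of the intended behavior, assuming the cuda-python package is installed by the time the proxy is first used:

from cutlass.utils.lazy_import import lazy_import

cudart = lazy_import("cuda.cudart")  # nothing is imported at this point
# The first attribute access imports cuda.cudart and forwards the lookup.
err, runtime_version = cudart.cudaRuntimeGetVersion()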