Import cuda, cudart, nvrtc lazily (#2251)
* Lazy cuda import
* More lazy cuda import
* More lazy cuda imports
* minor fixes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
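The core of the change is a small proxy that postpones the actual import until the first attribute access; the full implementation is the new python/cutlass/utils/lazy_import.py at the bottom of this diff. A condensed sketch of the pattern:

import importlib

def lazy_import(mod_name):
    class Lazy:
        def __getattr__(self, name):
            # The real module is imported (and cached by importlib) only
            # when an attribute is first looked up on the proxy.
            return getattr(importlib.import_module(mod_name), name)
    return Lazy()

cuda = lazy_import("cuda.cuda")   # importing this module does not import cuda.cuda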
@@ -29,7 +29,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import logging
import os
import sys
@@ -143,6 +142,7 @@ from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
from cutlass.backend.evt.ir.tensor import Tensor
from cutlass.utils.lazy_import import lazy_import


this.memory_pool = None
@@ -156,10 +156,33 @@ def get_memory_pool():
    return this.memory_pool


from cuda import cuda, cudart
base_cuda = lazy_import("cuda")
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")

this._device_id = None
this._nvcc_version = None

def check_cuda_versions():
    # Strip any additional information from the CUDA version
    _cuda_version = base_cuda.__version__.split("rc")[0]
    # Check that Python CUDA version exceeds NVCC version
    this._nvcc_version = nvcc_version()
    _cuda_list = _cuda_version.split('.')
    _nvcc_list = this._nvcc_version.split('.')
    for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
        if int(val_cuda) < int(val_nvcc):
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")

    if len(_nvcc_list) > len(_cuda_list):
        if len(_nvcc_list) != len(_cuda_list) + 1:
            raise Exception(f"Malformatted NVCC version of {this._nvcc_version}")
        if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")


def initialize_cuda_context():
    check_cuda_versions()

    if this._device_id is not None:
        return
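The check itself is unchanged, only deferred from import time into check_cuda_versions(): each dotted component of the Python cuda package version must be at least the corresponding NVCC component, and NVCC may carry one extra trailing component only if it is 0. A standalone sketch of the same rule (hypothetical helper, for illustration only, not part of this commit):

def _cuda_at_least_nvcc(cuda_version: str, nvcc_version: str) -> bool:
    # Mirrors check_cuda_versions(): compare component-wise, like the original loop.
    cuda_parts = cuda_version.split(".")
    nvcc_parts = nvcc_version.split(".")
    for c, n in zip(cuda_parts, nvcc_parts):
        if int(c) < int(n):
            return False
    if len(nvcc_parts) > len(cuda_parts):
        if len(nvcc_parts) != len(cuda_parts) + 1:
            raise ValueError(f"Malformatted NVCC version of {nvcc_version}")
        if nvcc_parts[:-1] == cuda_parts and int(nvcc_parts[-1]) != 0:
            return False
    return True

# _cuda_at_least_nvcc("12.4", "12.3")   -> True
# _cuda_at_least_nvcc("12.2", "12.3")   -> False
# _cuda_at_least_nvcc("12.3", "12.3.0") -> True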
@@ -33,7 +33,10 @@
from math import prod
from typing import Union

from cuda import cuda, cudart
from cutlass.utils.lazy_import import lazy_import

cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np

import cutlass
@@ -37,7 +37,10 @@ import sqlite3
import subprocess
import tempfile

from cuda import cuda, nvrtc
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
nvrtc = lazy_import("cuda.nvrtc")
from cutlass_library import SubstituteTemplate

import cutlass
@@ -29,11 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations

import ctypes
from typing import Union

from cuda import cuda
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import SubstituteTemplate
import numpy as np
@@ -36,7 +36,8 @@ Epilogue Visitor interface for compiling, and running visitor-based epilogue.

import ctypes

from cuda import cuda
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import DataType
import numpy as np
@@ -29,8 +29,10 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations

from cuda import cuda
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
import numpy as np

from cutlass.backend.memory_manager import device_mem_alloc, todevice
@@ -29,12 +29,15 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations

import copy
import ctypes
import enum

from cuda import cuda, cudart
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
from cutlass_library import SubstituteTemplate
import numpy as np
@@ -34,11 +34,12 @@ import numpy as np

import cutlass
from cutlass.utils.datatypes import is_numpy_tensor
from cutlass.utils.lazy_import import lazy_import

if cutlass.use_rmm:
    import rmm
else:
    from cuda import cudart
    cudart = lazy_import("cuda.cudart")


class PoolMemoryManager:
@@ -31,16 +31,17 @@
#################################################################################################

import ctypes

from cuda import __version__, cuda
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")

from cutlass.backend.utils.device import device_cc

_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
_supports_cluster_launch = None


def supports_cluster_launch():
    from cuda import __version__
    _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
    global _supports_cluster_launch
    if _supports_cluster_launch is None:
        major, minor = _version_splits[0], _version_splits[1]
@@ -79,10 +80,12 @@ class ExecutableOperation:
    def plan(self, arguments):
        raise NotImplementedError()

    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=cuda.CUstream(0)):
    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
        raise NotImplementedError()

    def run_with_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)):
    def run_with_clusters(self, launch_config, kernel_params, stream=None):
        if not stream:
            stream = cuda.CUstream(0)
        if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
            attr = cuda.CUlaunchAttribute()
            attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
@@ -110,7 +113,9 @@ class ExecutableOperation:
            config, f=self.kernel, kernelParams=kernel_params, extra=0)
        return err

    def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)):
    def run_without_clusters(self, launch_config, kernel_params, stream=None):
        if not stream:
            stream = cuda.CUstream(0)
        err, = cuda.cuLaunchKernel(
            self.kernel,
            launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
@@ -122,7 +127,9 @@ class ExecutableOperation:

        return err

    def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)):
    def run(self, host_workspace, device_workspace, launch_config, stream=None):
        if not stream:
            stream = cuda.CUstream(0)
        cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
        packed = (ctypes.c_void_p * 1)()
        packed[0] = ctypes.addressof(cArg)
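The signature change from stream=cuda.CUstream(0) to stream=None is what actually keeps these modules lazy: a default argument is evaluated when the function is defined, i.e. at module import time, so the old default would have forced the cuda import immediately. A minimal illustration of the pattern (hypothetical function, not from the diff):

cuda = lazy_import("cuda.cuda")

def launch(stream=None):
    # cuda.CUstream is only resolved when launch() is called,
    # so importing the defining module never touches cuda.cuda.
    if not stream:
        stream = cuda.CUstream(0)
    ...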
@@ -29,11 +29,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
from __future__ import annotations

import ctypes
from typing import Union

from cuda import cuda, cudart
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np

from cutlass_library import (
@@ -33,8 +33,11 @@
"""
Utility functions for interacting with the device
"""
from __future__ import annotations

from cuda import cuda, cudart
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")

import cutlass
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
@@ -37,7 +37,6 @@ Classes containing valid operations for a given compute capability and data type
from itertools import combinations_with_replacement
import logging

from cuda import __version__
import cutlass_library
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
@@ -48,23 +47,6 @@ from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]

# Strip any additional information from the CUDA version
_cuda_version = __version__.split("rc")[0]

# Check that Python CUDA version exceeds NVCC version
_nvcc_version = cutlass.nvcc_version()
_cuda_list = _cuda_version.split('.')
_nvcc_list = _nvcc_version.split('.')
for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
    if int(val_cuda) < int(val_nvcc):
        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")

if len(_nvcc_list) > len(_cuda_list):
    if len(_nvcc_list) != len(_cuda_list) + 1:
        raise Exception(f"Malformatted NVCC version of {_nvcc_version}")
    if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")


class KernelsForDataType:
    """
@@ -292,7 +274,7 @@ class ArchOptions:
        ]
        manifest_args = cutlass_library.generator.define_parser().parse_args(args)
        manifest = cutlass_library.manifest.Manifest(manifest_args)
        generate_function(manifest, _nvcc_version)
        generate_function(manifest, cutlass._nvcc_version)

        if operation_kind not in manifest.operations:
            # No kernels generated for this architecture, this could be because the CUDA
@@ -111,8 +111,11 @@
    args.sync()
"""

from cuda import cuda
from __future__ import annotations
from typing import Optional
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
from cutlass_library import (
    ConvKind,
    ConvMode,
@@ -735,7 +738,7 @@ class Conv2d(OperationBase):
            alpha=None, beta=None,
            split_k=("serial", 1), sync: bool = True,
            print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -768,6 +771,8 @@ class Conv2d(OperationBase):
        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.Conv2dArguments
        """
        if not stream:
            stream = cuda.CUstream(0)
        super().run_setup()

        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
@@ -926,7 +931,10 @@ class Conv2dFprop(Conv2d):
            self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:

        if not stream:
            stream = cuda.CUstream(0)

        A, B, D = input, weight, output
        return super().run(
@@ -951,8 +959,11 @@ class Conv2dDgrad(Conv2d):
    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        #
        if not stream:
            stream = cuda.CUstream(0)

        A, B, D = grad_output, weight, grad_input
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
@@ -976,8 +987,10 @@ class Conv2dWgrad(Conv2d):
    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
        #
            stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
        if not stream:
            stream = cuda.CUstream(0)

        A, B, D = grad_output, input, grad_weight
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
@@ -113,10 +113,12 @@
    args.sync()
"""

from __future__ import annotations
from typing import Optional
from math import prod

from cuda import cuda
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import (
    DataType,
    DataTypeSize,
@@ -623,7 +625,7 @@ class Gemm(OperationBase):

    def run(self, A=None, B=None, C=None, D=None,
            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments:
            stream: Optional[cuda.CUstream] = None) -> GemmArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -652,6 +654,8 @@
        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.GemmArguments
        """
        if not stream:
            stream = cuda.CUstream(0)
        super().run_setup()
        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
@@ -50,10 +50,12 @@
    plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
    plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""

from __future__ import annotations
from typing import Optional
from cutlass_library import DataTypeSize

from cuda import cuda
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass.backend.gemm_operation import (
    GemmGroupedArguments,
    GemmOperationGrouped,
@@ -196,7 +198,7 @@ class GroupedGemm(Gemm):
    def run(self, A, B, C, D,
            alpha=None, beta=None, sync: bool = True,
            print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmGroupedArguments:
            stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
        """
        Runs the kernel currently specified.

@@ -225,6 +227,9 @@ class GroupedGemm(Gemm):
        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.GemmGroupedArguments
        """
        if not stream:
            stream = cuda.CUstream(0)

        super().run_setup()

        if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
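Call sites do not need to change; an explicit stream can still be passed and is used as before. An illustrative example building on the GroupedGemm snippet above, assuming the caller creates a cuda-python stream:

from cuda import cuda

err, stream = cuda.cuStreamCreate(0)
args = plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1], stream=stream)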
python/cutlass/utils/lazy_import.py (new file, 11 lines)
@@ -0,0 +1,11 @@
import importlib
from typing import Any

def lazy_import(mod_name: str) -> Any:
    class Lazy:
        def __getattr__(self, name:str) -> Any:
            module = importlib.import_module(mod_name)
            return getattr(module, name)

    return Lazy()
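A quick usage sketch of the helper (illustrative only): lazy_import() returns a proxy immediately, and the named module is only imported, via importlib, on the first attribute access:

from cutlass.utils.lazy_import import lazy_import

cudart = lazy_import("cuda.cudart")            # nothing imported yet
err, version = cudart.cudaRuntimeGetVersion()  # cuda.cudart is imported here, on first use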
@@ -37,7 +37,9 @@ Profiler based on the cuda events
import re
import subprocess

from cuda import cuda, cudart
from cutlass.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np

from cutlass import CUTLASS_PATH
@@ -54,18 +56,27 @@ class GpuTimer:
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
        ]

    def start(self, stream=cuda.CUstream(0)):
    def start(self, stream=None):
        if not stream:
            stream = cuda.CUstream(0)

        (err,) = cuda.cuEventRecord(self.events[0], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

    def stop(self, stream=cuda.CUstream(0)):
    def stop(self, stream=None):
        if not stream:
            stream = cuda.CUstream(0)

        (err,) = cuda.cuEventRecord(self.events[1], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")
        pass

    def stop_and_wait(self, stream=cuda.CUstream(0)):
    def stop_and_wait(self, stream=None):
        if not stream:
            stream = cuda.CUstream(0)

        self.stop(stream)
        if stream:
            (err,) = cuda.cuStreamSynchronize(stream)
@@ -182,4 +193,3 @@ class CUDAEventProfiler:
            flops_ += m * n * batch_count * 2

        return flops_