v4.2 tag release. (#2638)
This commit is contained in:
@ -14,16 +14,12 @@ This module provides a runtime utility functions that are needed for
|
||||
the DSL.
|
||||
"""
|
||||
|
||||
from . import device_tensor
|
||||
from . import dlpack_types
|
||||
from . import cuda
|
||||
from . import tensor_descriptor
|
||||
from . import jit_arg_adapters
|
||||
|
||||
__all__ = [
|
||||
"device_tensor",
|
||||
"dlpack_types",
|
||||
"cuda",
|
||||
"tensor_descriptor",
|
||||
"jit_arg_adapters",
|
||||
]
|
||||
|
||||
@ -309,7 +309,7 @@ def get_kernel_function(module, kernel_name):
|
||||
return kernel
|
||||
|
||||
|
||||
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
|
||||
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
|
||||
"""
|
||||
Launches the CUDA kernel.
|
||||
"""
|
||||
|
||||
@ -183,6 +183,13 @@ class TensorDescriptor:
|
||||
"""
|
||||
return self.device_type == _dpack.DLDeviceType.kDLGPU
|
||||
|
||||
@staticmethod
|
||||
def is_compatible(maybe_tensor_descriptor) -> bool:
|
||||
"""Check if the object is a TensorDescriptor or can be converted to one."""
|
||||
return isinstance(
|
||||
maybe_tensor_descriptor, TensorDescriptor
|
||||
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
||||
|
||||
|
||||
def from_tensor(tensor) -> TensorDescriptor:
|
||||
"""Create a TensorDescriptor from a tensor object."""
|
||||
@ -192,10 +199,3 @@ def from_tensor(tensor) -> TensorDescriptor:
|
||||
def to_tensor(tensor_descriptor: TensorDescriptor):
|
||||
"""Return tensor object from tensor descriptor."""
|
||||
return tensor_descriptor.tensor
|
||||
|
||||
|
||||
def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
|
||||
"""Check if the object is a TensorDescriptor."""
|
||||
return isinstance(
|
||||
maybe_tensor_descriptor, TensorDescriptor
|
||||
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
||||
|
||||
Reference in New Issue
Block a user