Files
cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py
2025-05-13 15:55:29 -04:00

202 lines
6.1 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
# Helpers
import itertools, operator
import ctypes
from . import dlpack_types as _dpack
from .dlpack_runtime import (
dlpack_to_tensor_desc,
get_tensor_desc_data_ptr,
get_tensor_desc_is_in_device,
get_tensor_desc_element_type,
get_tensor_desc_shape,
get_tensor_desc_stride,
get_tensor_desc_element_size_in_bytes,
get_tensor_desc_ndim,
get_tensor_desc_dtype_code,
get_tensor_desc_dtype_bits,
get_tensor_desc_device_type,
get_tensor_desc_device_id,
)
from ..utils.logger import log
from ..common import *
from ..typing import (
Boolean,
Float8E5M2,
Int64,
Int32,
Int16,
Int8,
Uint64,
Uint32,
Uint16,
Uint8,
Float64,
Float32,
Float16,
BFloat16,
)
class TensorDescriptor:
    """Wraps a DLPack-capable tensor and exposes its metadata.

    The tensor is converted once into an internal DLPack capsule; all
    properties (shape, strides, element type, device info) are read from
    that capsule on demand via the ``dlpack_runtime`` helpers.
    """

    def __init__(self, tensor):
        """Initialize with a tensor that supports the DLPack protocol.

        Args:
            tensor: Any tensor object that implements ``__dlpack__`` and
                ``__dlpack_device__``.

        Raises:
            DSLRuntimeError: If the tensor resides on a device type other
                than CPU or GPU.
        """
        self.tensor = tensor
        self._capsule = dlpack_to_tensor_desc(tensor)
        self.data_ptr = get_tensor_desc_data_ptr(self._capsule)
        self.device_type = _dpack.DLDeviceType(
            get_tensor_desc_device_type(self._capsule)
        )
        if self.device_type == _dpack.DLDeviceType.kDLGPU:
            self.device_pointer = self.data_ptr
        elif self.device_type == _dpack.DLDeviceType.kDLCPU:
            # Host tensors carry no device pointer; is_in_device keys off this.
            self.device_pointer = None
        else:
            # Bug fix: the original formatted `self.dl_tensor.device.device_type`,
            # an attribute that does not exist on this class, so the intended
            # diagnostic was masked by an AttributeError.
            raise DSLRuntimeError(
                f"DLPack device type is not supported {self.device_type}"
            )
        log().info("TensorDescriptor is created = [%s]", self)

    @staticmethod
    def can_transformed_to_dlpack(dl_tensor):
        """Return True if `dl_tensor` implements the DLPack protocol."""
        return hasattr(dl_tensor, "__dlpack__") and hasattr(
            dl_tensor, "__dlpack_device__"
        )

    @property
    def is_in_device(self):
        """Check if the tensor is stored on a device (GPU)."""
        return self.device_pointer is not None

    @property
    def device_id(self):
        """Return the device id where the tensor resides, or -1 for host tensors."""
        if self.is_in_device:
            return get_tensor_desc_device_id(self._capsule)
        return -1

    @property
    def element_type(self):
        """Return the corresponding Python type based on DLPack dtype metadata.

        Raises:
            KeyError: If the DLPack element type has no mapped DSL type.
        """
        str_element_type = get_tensor_desc_element_type(self._capsule)
        dtype_map = {
            # bool is 8bit from numpy and torch
            "Bool": Boolean,
            "Int64": Int64,
            "Int32": Int32,
            "Int16": Int16,
            "Int8": Int8,
            "UInt64": Uint64,
            "UInt32": Uint32,
            "UInt16": Uint16,
            "UInt8": Uint8,
            "Float64": Float64,
            "Float32": Float32,
            "Float16": Float16,
            "BFloat16": BFloat16,
            "Float8E5M2": Float8E5M2,
        }
        if str_element_type not in dtype_map:
            raise KeyError(
                f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}"
            )
        return dtype_map[str_element_type]

    @property
    def shape(self):
        """Return the shape of the tensor."""
        return get_tensor_desc_shape(self._capsule)

    @property
    def rank(self):
        """Return the rank (number of dimensions) of the tensor."""
        return get_tensor_desc_ndim(self._capsule)

    @property
    def strides(self):
        """Return the strides of the tensor."""
        # Doc fix: the original docstring said "rank" but this returns strides.
        return get_tensor_desc_stride(self._capsule)

    @property
    def element_size_in_bytes(self):
        """Calculate the element size in bytes of the DLPack tensor."""
        return get_tensor_desc_element_size_in_bytes(self._capsule)

    @property
    def size_in_bytes(self):
        """Calculate the total size in bytes of the DLPack tensor."""
        # Number of elements is the product of all dimension extents.
        ndim = get_tensor_desc_ndim(self._capsule)
        shape = get_tensor_desc_shape(self._capsule)
        num_elements = 1
        for i in range(ndim):
            num_elements *= shape[i]
        return self.element_size_in_bytes * num_elements

    def __str__(self):
        """Return a compact representation, e.g. ``tensor<2x3xf32>_gpu``."""
        shape = "x".join(map(str, self.shape))
        dtype_code = get_tensor_desc_dtype_code(self._capsule)
        dtype_bits = get_tensor_desc_dtype_bits(self._capsule)
        # NOTE(review): any non-signed-int code (uint, float, bfloat, ...)
        # renders as f<bits>; confirm this display format is intended for
        # unsigned integer dtypes.
        dtype = (
            f"i{dtype_bits}"
            if dtype_code == _dpack.DLDataTypeCode.kDLInt
            else f"f{dtype_bits}"
        )
        device_type = "gpu" if self.is_in_device else "cpu"
        return f"tensor<{shape}x{dtype}>_{device_type}"

    def _check_is_managed_by_framework(self):
        """Return True when the tensor is framework-managed (GPU-resident).

        NOTE(review): the original docstring claimed this raises on
        framework-managed tensors; it only returns a bool.
        """
        return self.device_type == _dpack.DLDeviceType.kDLGPU
def from_tensor(tensor) -> TensorDescriptor:
    """Build a TensorDescriptor wrapping the given DLPack-capable tensor."""
    descriptor = TensorDescriptor(tensor)
    return descriptor
def to_tensor(tensor_descriptor: TensorDescriptor):
    """Unwrap a TensorDescriptor, returning the original tensor object."""
    original_tensor = tensor_descriptor.tensor
    return original_tensor
def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
    """Check whether the object is a TensorDescriptor or DLPack-convertible."""
    if isinstance(maybe_tensor_descriptor, TensorDescriptor):
        return True
    return TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)