306 lines
10 KiB
Python
306 lines
10 KiB
Python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
|
|
|
|
This example demonstrates a basic approach to building customized interfaces as C-structures between user code
|
|
and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions
|
|
and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions.
|
|
|
|
The C-structure is defined as:
|
|
|
|
.. code-block:: c
|
|
|
|
struct Tensor {
|
|
void *ptr; // Pointer to tensor data
|
|
int32_t shape[3]; // Tensor dimensions
|
|
int32_t strides[3]; // Memory strides for each dimension
|
|
};
|
|
|
|
The example defines the ExampleTensor and ExampleTensorValue classes, which wrap this C struct as a view of a tensor
(its data pointer, shape, and strides), so that tensor metadata can be passed efficiently across language boundaries.
|
|
|
|
.. note::
|
|
Future development may include automated code generation flows.
|
|
"""
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
|
|
from cutlass._mlir import ir
|
|
from cutlass._mlir.dialects import llvm
|
|
import cutlass._mlir.extras.types as T
|
|
|
|
|
|
class ExampleTensorValue(ir.Value):
    """MLIR value wrapper exposing the fields of the example ``Tensor`` struct.

    The wrapped value is read as an LLVM struct whose field 0 is the data
    pointer, field 1 the shape and field 2 the strides. Every property emits
    ``llvm.extractvalue`` operations to pull the requested field out of the
    struct value.

    :type: ir.Value
    """

    def __init__(self, v):
        """Wrap an existing MLIR value.

        :param v: The underlying MLIR value to wrap
        :type v: ir.Value
        """
        super().__init__(v)

    def _extract_i32_fields(self, field_index):
        """Read one of the three-element ``i32`` sub-structs of the wrapped value.

        :param field_index: Index of the sub-struct inside the outer struct
        :type field_index: int
        :return: The three extracted ``i32`` values, in element order
        :rtype: tuple[ir.Value, ...]
        """
        i32 = ir.IntegerType.get_signless(32)
        # The sub-struct is modelled as a literal struct of three i32 fields.
        packed_type = llvm.StructType.get_literal([i32] * 3)
        packed = llvm.extractvalue(packed_type, self, [field_index])

        elements = []
        for dim in range(3):
            elements.append(llvm.extractvalue(i32, packed, [dim]))
        return tuple(elements)

    @property
    def data_ptr(self, *, loc=None, ip=None):
        """Get the data pointer from the tensor value.

        Extracts field 0 of the LLVM struct value as an LLVM pointer and hands
        it to ``cute.make_ptr`` with ``cutlass.Float32`` as the element type.

        Because this is a property, plain attribute access always uses the
        default ``None`` for both keyword arguments.

        :param loc: Optional location information for MLIR operations
        :type loc: Optional[ir.Location]
        :param ip: Optional insertion point for MLIR operations
        :type ip: Optional[ir.InsertionPoint]
        :return: The result of ``cute.make_ptr`` for the extracted pointer
        """
        raw_pointer = llvm.extractvalue(
            llvm.PointerType.get(), self, [0], loc=loc, ip=ip
        )
        return cute.make_ptr(cutlass.Float32, raw_pointer)

    @property
    def shape(self):
        """Get the shape of the tensor.

        Extracts field 1 of the LLVM struct value and then each of its three
        ``i32`` entries.

        :return: The three shape entries as MLIR values
        :rtype: tuple[ir.Value, ...]
        """
        return self._extract_i32_fields(1)

    @property
    def stride(self):
        """Get the strides of the tensor.

        Extracts field 2 of the LLVM struct value and then each of its three
        ``i32`` entries.

        :return: The three stride entries as MLIR values
        :rtype: tuple[ir.Value, ...]
        """
        return self._extract_i32_fields(2)
|
|
|
|
|
|
class ExampleTensor:
    """Host-side handle for a C ``Tensor`` struct passed to a JIT function.

    The object stores the address of the C struct and the tensor rank, and
    implements the ``__get_mlir_types__`` / ``__new_from_mlir_values__`` /
    ``__c_pointers__`` hooks used in this file when it is given to
    ``cute.compile`` and to the compiled function.

    :ivar _c_struct_p: The C struct pointer for the tensor
    :ivar _rank: The number of dimensions in the tensor
    """

    def __init__(self, c_struct_p, rank):
        """Store the struct pointer and the rank.

        :param c_struct_p: The C struct pointer for the tensor
        :type c_struct_p: int
        :param rank: The number of dimensions in the tensor
        :type rank: int
        """
        self._rank = rank
        self._c_struct_p = c_struct_p

    def __get_mlir_types__(self):
        """Describe this tensor to MLIR as a single LLVM struct type.

        The struct mirrors the C layout below, with ``shape`` and ``strides``
        each modelled as a literal struct of ``rank`` signless 32-bit integers:

        .. code-block:: c

            struct Tensor {
                void *ptr;
                int32_t shape[3];
                int32_t strides[3];
            };

        NOTE(review): ``ExampleTensorValue`` in this file always reads exactly
        three shape/stride entries, so a rank other than 3 is presumably not
        supported end to end -- confirm before relying on it.

        :return: A one-element list containing the MLIR struct type
        :rtype: list[llvm.StructType]
        """
        pointer_field = llvm.PointerType.get()

        # One i32 per dimension for both the shape and the strides sub-structs.
        extent_fields = [ir.IntegerType.get_signless(32)] * self._rank
        shape_field = llvm.StructType.get_literal(extent_fields)
        strides_field = llvm.StructType.get_literal(extent_fields)

        return [
            llvm.StructType.get_literal([pointer_field, shape_field, strides_field])
        ]

    def __new_from_mlir_values__(self, values):
        """Wrap the first MLIR value in an ``ExampleTensorValue``.

        :param values: A list of MLIR values
        :type values: list[ir.Value]
        :return: A new ExampleTensorValue built from ``values[0]``
        :rtype: ExampleTensorValue
        """
        (struct_value, *_) = values[:1]
        return ExampleTensorValue(struct_value)

    def __c_pointers__(self):
        """Return the C pointers backing this tensor.

        :return: A one-element list containing the C struct pointer
        :rtype: list[int]
        """
        pointers = [self._c_struct_p]
        return pointers
|
|
|
|
|
|
@cute.jit
def foo(tensor):
    """Example JIT function that prints tensor information.

    Prints the ``data_ptr``, ``shape`` and ``stride`` attributes of the
    argument, then builds a CuTe tensor from them and prints it.

    :param tensor: Object exposing ``data_ptr``, ``shape`` and ``stride``
        (``run_test`` in this file passes an ``ExampleTensor``)
    """
    cute.printf("data_ptr: {}", tensor.data_ptr)
    cute.printf("shape: {}", tensor.shape)
    cute.printf("stride: {}", tensor.stride)

    # Read the pointer first, then the layout operands, matching the order in
    # which the tensor is assembled.
    base_pointer = tensor.data_ptr
    layout = cute.make_layout(tensor.shape, stride=tensor.stride)
    mA = cute.make_tensor(base_pointer, layout)
    cute.print_tensor(mA)
|
|
|
|
|
|
import sys
|
|
import os
|
|
import subprocess
|
|
import shutil
|
|
import tempfile
|
|
import torch
|
|
|
|
|
|
def run_test(tmpdir=None):
    """Build the C helper module, then compile and run ``foo`` on a sample tensor.

    Steps:

    1. Configure and build the CMake project located next to this file into
       ``tmpdir``.
    2. Import ``make_tensor`` and ``pycapsule_get_pointer`` from the ``tensor``
       module expected in the build directory.
    3. Create a ``2 x 8 x 4`` float32 torch tensor, build the matching C
       struct, wrap its pointer in ``ExampleTensor`` and compile ``foo``.
    4. Time 1000 calls to ``__c_pointers__`` and run the compiled function.

    Any failure propagates to the caller instead of being printed and
    swallowed, so a broken build or a JIT error is not reported as a
    successful run. The temporary directory is still removed on failure.

    :param tmpdir: Build directory to use. When omitted (or empty) a temporary
        directory is created and deleted afterwards; a caller-supplied
        directory is left in place.
    :type tmpdir: Optional[str]
    :raises subprocess.CalledProcessError: If a cmake invocation exits non-zero
    """
    # Clean up exactly when the directory is created here. Testing ``not tmpdir``
    # (rather than ``is None``) keeps an empty string from leaking the
    # directory that ``mkdtemp`` creates for it below.
    cleanup = not tmpdir
    tmpdir = tmpdir or tempfile.mkdtemp()

    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))

        subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
        subprocess.run(["cmake", "--build", tmpdir], check=True)

        # Make the freshly built extension importable.
        sys.path.append(tmpdir)

        from tensor import make_tensor, pycapsule_get_pointer

        # Mock test tensor and corresponding C structure for this example
        # In production, this may come from external library
        x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
        c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
        c_struct_p = pycapsule_get_pointer(c_struct)

        # Initialize tensor wrapper and compile test function
        tensor = ExampleTensor(c_struct_p, len(x.shape))
        compiled_func = cute.compile(foo, tensor)

        # Benchmark pointer access performance: getting the C pointers is on
        # the critical path of every call into the JIT compiled function.
        from time import perf_counter

        iterations = 1000
        start = perf_counter()
        for _ in range(iterations):
            tensor.__c_pointers__()
        elapsed = perf_counter() - start
        # seconds -> microseconds, averaged over the number of calls
        print(f"__c_pointers__: {elapsed * 1e6 / iterations} us")

        # Execute compiled function
        compiled_func(tensor)
    finally:
        if cleanup:
            # Clean up the temporary directory
            shutil.rmtree(tmpdir)
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: optionally choose where the C module is built.
    cli = argparse.ArgumentParser(
        description="Set temporary directory for building C modules"
    )
    cli.add_argument(
        "--tmp-dir",
        type=str,
        help="Temporary directory path for building C modules",
    )
    options = cli.parse_args()

    # ``tmp_dir`` is None when the flag is not given.
    run_test(options.tmp_dir)
|