Files
cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py
2025-06-06 02:39:20 -04:00

306 lines
10 KiB
Python

# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
This example demonstrates a basic approach to building customized interfaces as C-structures between user code
and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions
and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions.
The C-structure is defined as:

.. code-block:: c

    struct Tensor {
        void *ptr;           // Pointer to tensor data
        int32_t shape[3];    // Tensor dimensions
        int32_t strides[3];  // Memory strides for each dimension
    };

The example defines the ``ExampleTensor`` and ``ExampleTensorValue`` classes, which wrap this C struct as a view
of a tensor (its data pointer, shape, and strides), enabling efficient data passing across language boundaries.

.. note::

    Future development may include automated code generation flows.
"""
import cutlass
import cutlass.cute as cute
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
import cutlass._mlir.extras.types as T
class ExampleTensorValue(ir.Value):
    """MLIR value wrapper exposing the fields of the example tensor struct.

    The wrapped value is read as an LLVM struct whose fields are, in order:
    a data pointer (index 0), a struct of three ``i32`` holding the shape
    (index 1) and a struct of three ``i32`` holding the strides (index 2).
    Every accessor emits ``llvm.extractvalue`` operations each time it is read.

    .. note::
        The accessors always read exactly three shape and three stride
        entries, while ``ExampleTensor.__get_mlir_types__`` sizes those fields
        by the tensor rank, so only rank-3 tensors match these accessors.

    :type: ir.Value
    """

    def __init__(self, v):
        """Wrap an existing MLIR value.

        :param v: The underlying MLIR value to wrap
        :type v: ir.Value
        """
        super().__init__(v)

    def _extract_i32_triple(self, field_index):
        """Read one of the three-element ``i32`` sub-structs of the tensor struct.

        :param field_index: Index of the sub-struct inside the tensor struct
        :type field_index: int
        :return: The three extracted ``i32`` values, in index order
        :rtype: tuple[ir.Value, ...]
        """
        i32 = ir.IntegerType.get_signless(32)
        triple_type = llvm.StructType.get_literal([i32] * 3)
        # First pull out the whole sub-struct, then each of its elements.
        packed = llvm.extractvalue(triple_type, self, [field_index])
        elements = []
        for position in range(3):
            elements.append(llvm.extractvalue(i32, packed, [position]))
        return tuple(elements)

    @property
    def data_ptr(self, *, loc=None, ip=None):
        """Get the data pointer from the tensor value.

        Extracts field 0 of the struct as an LLVM pointer and wraps it with
        ``cute.make_ptr`` using ``cutlass.Float32`` as the element type (the
        element type is hard-coded in this example).

        Because this is a property, plain attribute access never supplies
        ``loc`` or ``ip``; both then keep their ``None`` default.

        :param loc: Optional location information for MLIR operations
        :type loc: Optional[ir.Location]
        :param ip: Optional insertion point for MLIR operations
        :type ip: Optional[ir.InsertionPoint]
        :return: The result of ``cute.make_ptr`` for the extracted pointer
        """
        pointer_type = llvm.PointerType.get()
        field_path = [0]  # the data pointer is the first struct field
        raw_pointer = llvm.extractvalue(
            pointer_type, self, field_path, loc=loc, ip=ip
        )
        return cute.make_ptr(cutlass.Float32, raw_pointer)

    @property
    def shape(self):
        """Get the shape of the tensor.

        Reads the second struct field (index 1) and unpacks its three entries.

        :return: The three tensor extents as ``i32`` MLIR values
        :rtype: tuple[ir.Value, ...]
        """
        return self._extract_i32_triple(1)

    @property
    def stride(self):
        """Get the strides of the tensor.

        Reads the third struct field (index 2) and unpacks its three entries.

        :return: The three tensor strides as ``i32`` MLIR values
        :rtype: tuple[ir.Value, ...]
        """
        return self._extract_i32_triple(2)
class ExampleTensor:
    """Host-side handle for a tensor described by a C struct.

    Instances carry the address of the C struct and the tensor rank, and
    implement the ``__get_mlir_types__`` / ``__new_from_mlir_values__`` /
    ``__c_pointers__`` hooks so they can be passed to CUTE JIT compiled
    functions.

    :ivar _c_struct_p: The C struct pointer for the tensor
    :ivar _rank: The number of dimensions in the tensor
    """

    def __init__(self, c_struct_p, rank):
        """Store the struct pointer and the rank.

        :param c_struct_p: The C struct pointer for the tensor
        :type c_struct_p: int
        :param rank: The number of dimensions in the tensor
        :type rank: int
        """
        self._c_struct_p, self._rank = c_struct_p, rank

    def __get_mlir_types__(self):
        """Describe this tensor as a single LLVM literal struct type.

        The struct mirrors a C structure of the form (``rank`` being the value
        given at construction):

        .. code-block:: c

            struct Tensor {
                void *ptr;
                int32_t shape[rank];
                int32_t strides[rank];
            };

        The shape and strides members are modelled as literal structs of
        ``rank`` ``i32`` fields each.

        :return: A one-element list containing the struct type
        :rtype: list[ir.Type]
        """
        rank = self._rank
        pointer = llvm.PointerType.get()
        i32 = ir.IntegerType.get_signless(32)
        # One literal struct per member: first the shape, then the strides.
        extents, steps = (
            llvm.StructType.get_literal([i32] * rank) for _ in range(2)
        )
        return [llvm.StructType.get_literal([pointer, extents, steps])]

    def __new_from_mlir_values__(self, values):
        """Wrap the first MLIR value as an ``ExampleTensorValue``.

        :param values: A list of MLIR values; only ``values[0]`` is used
        :type values: list[ir.Value]
        :return: A new ExampleTensorValue instance
        :rtype: ExampleTensorValue
        """
        struct_value = values[0]
        return ExampleTensorValue(struct_value)

    def __c_pointers__(self):
        """Get the C pointers for this tensor.

        :return: A one-element list containing the C struct pointer
        :rtype: list[int]
        """
        return [self._c_struct_p]
@cute.jit
def foo(tensor):
    """Example JIT function that prints tensor information.

    Prints the data pointer, shape and stride of ``tensor``, then builds a
    CuTe tensor over the same pointer and layout and prints its contents.

    :param tensor: Object exposing ``data_ptr``, ``shape`` and ``stride``
        (presumably the ExampleTensorValue created for an ExampleTensor
        argument -- confirm against the JIT argument protocol)
    """
    cute.printf("data_ptr: {}", tensor.data_ptr)
    cute.printf("shape: {}", tensor.shape)
    cute.printf("stride: {}", tensor.stride)

    # The accessors are read again, in the same order as a nested call would
    # evaluate them: pointer first, then shape and stride for the layout.
    base_ptr = tensor.data_ptr
    layout = cute.make_layout(tensor.shape, stride=tensor.stride)
    view = cute.make_tensor(base_ptr, layout)
    cute.print_tensor(view)
import sys
import os
import subprocess
import shutil
import tempfile
import torch
def run_test(tmpdir=None):
    """Build the C helper module, then compile and run ``foo`` on a sample tensor.

    The CMake project located next to this file is configured and built into
    ``tmpdir``; the resulting ``tensor`` module is imported from there and used
    to create the C struct describing a small float32 torch tensor.

    Any failure (CMake, import, compilation or execution) propagates to the
    caller instead of being printed and ignored, so a failing run is visible
    through the exception / process exit status.

    :param tmpdir: Build directory to use. When ``None`` or empty, a fresh
        temporary directory is created and removed again afterwards; a
        caller-provided directory is never removed.
    :type tmpdir: Optional[str]
    :raises subprocess.CalledProcessError: If a CMake step exits non-zero
    """
    # Only clean up a directory that this function created itself. Using the
    # same truthiness test as the fallback below avoids leaking the temporary
    # directory when an empty string is passed.
    cleanup = not tmpdir
    tmpdir = tmpdir or tempfile.mkdtemp()

    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))

        # Configure and build the C extension; check=True raises on failure.
        subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
        subprocess.run(["cmake", "--build", tmpdir], check=True)

        # The freshly built module lives in the build directory.
        sys.path.append(tmpdir)

        from tensor import make_tensor, pycapsule_get_pointer

        # Mock test tensor and corresponding C structure for this example.
        # In production, this may come from an external library.
        x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
        c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
        c_struct_p = pycapsule_get_pointer(c_struct)

        # Initialize tensor wrapper and compile test function
        tensor = ExampleTensor(c_struct_p, len(x.shape))
        compiled_func = cute.compile(foo, tensor)

        # Benchmark __c_pointers__, which is on the critical path of every
        # call into a JIT compiled function. perf_counter is the monotonic,
        # high-resolution clock intended for measuring short durations.
        from time import perf_counter

        start = perf_counter()
        for _ in range(1000):
            tensor.__c_pointers__()
        end = perf_counter()

        # Total seconds over 1000 calls, times 1000, is microseconds per call.
        print(f"__c_pointers__: {(end - start) * 1000} us")

        # Execute compiled function
        compiled_func(tensor)
    finally:
        if cleanup:
            # Clean up the temporary directory
            shutil.rmtree(tmpdir)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: optionally reuse a caller-chosen build
    # directory instead of a throw-away temporary one.
    cli = argparse.ArgumentParser(
        description="Set temporary directory for building C modules"
    )
    cli.add_argument(
        "--tmp-dir",
        type=str,
        help="Temporary directory path for building C modules",
    )
    build_dir = cli.parse_args().tmp_dir
    run_test(build_dir)