v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@ -0,0 +1,200 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
import cutlass
import torch
import numpy as np
from cutlass.cute.runtime import from_dlpack
"""
A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL.
This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL.
It shows various ways to allocate different data structures in shared memory:
1. Struct allocation with natural and strict alignment
2. Raw memory block allocation with custom alignment
3. Array allocation with automatic alignment
4. Tensor allocation with layout specification
The example includes:
- Shared storage struct with mixed alignment requirements
- Memory allocation patterns for different data types
- Tensor operations on allocated memory
To run this example:
.. code-block:: bash
python examples/ampere/smem_allocator.py
The example will allocate shared memory, perform tensor operations, and verify the results.
"""
@cute.struct
class complex:
    """A two-field complex-number struct (32-bit real/imag components).

    Used below to demonstrate nested-struct allocation in shared memory.
    """

    real: cutlass.Float32  # real component
    imag: cutlass.Float32  # imaginary component
# SharedStorage size is 512 bytes, alignment is 128 (driven by the strict
# 128-byte alignment requested for field `x` below).
@cute.struct
class SharedStorage:
    # struct elements with natural alignment
    a: cute.struct.MemRange[cutlass.Float32, 32]  # array of 32 f32 values
    b: cutlass.Int64  # scalar
    c: complex  # nested struct
    # struct elements with strict (explicitly requested) alignment
    x: cute.struct.Align[
        cute.struct.MemRange[cutlass.Float32, 32],
        128,  # force this MemRange onto a 128-byte boundary
    ]
    y: cute.struct.Align[cutlass.Int32, 8]  # 4-byte scalar aligned to 8 bytes
    z: cute.struct.Align[complex, 16]  # nested struct aligned to 16 bytes
@cute.kernel
def kernel(
    const_a: cutlass.Constexpr,
    dst_a: cute.Tensor,
    const_b: cutlass.Constexpr,
    dst_b: cute.Tensor,
    const_c: cutlass.Constexpr,
    dst_c: cute.Tensor,
):
    """Device kernel demonstrating the SmemAllocator allocation styles.

    Each ``const_*`` is a compile-time fill value; each ``dst_*`` is a
    global-memory tensor receiving a copy of the corresponding shared-memory
    data so the host can verify the results. Allocation order below matters:
    the allocator hands out addresses sequentially from its base pointer.
    """
    # Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be
    # reserved for the developer to utilize.
    # Note: alignment of the initial allocator base ptr is 1024
    allocator = cutlass.utils.SmemAllocator()
    # base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory)

    # -- Allocate a struct --
    # Note: when an alignment is specified, max(alignment, alignof(struct)) is applied.
    # Reserves the struct's section in smem; struct elements are accessed via the returned handle.
    struct_in_smem = allocator.allocate(SharedStorage)
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct)

    # -- Allocate a block of memory --
    # Reserves a 64-byte section in smem, aligned to 128 bytes; returns the section base ptr.
    section_in_smem = allocator.allocate(64, byte_alignment=128)
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section)

    # -- Allocate an array --
    # Reserves an int64 array of 14 elements in smem; returns the array base ptr.
    array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14)
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array)

    # -- Allocate a tensor --
    # Note: use cute.ComposedLayout or cute.Layout to specify the tensor layout.
    # Note: iterator swizzle with swizzle layout is currently not supported.
    layout = cute.make_layout((16, 2))
    tensor_in_smem = allocator.allocate_tensor(
        element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None
    )
    # base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor)

    # Expected prints (alignment reflects each field's offset inside the struct):
    #   struct_in_smem.a.data_ptr() -> pointer to the f32 MemRange
    #   struct_in_smem.b            -> the Int64 scalar field
    #   struct_in_smem.c.real       -> the Float32 field of the nested struct
    # NOTE(review): the original comments labeled these as f16 with shape
    # (16,4) -- the struct declares Float32 fields and the tensor layout is
    # (16,2); confirm against actual output.
    print(struct_in_smem.a.data_ptr())
    print(struct_in_smem.b)
    print(struct_in_smem.c.real)
    # ptr<i8, smem, align<512>>
    print(section_in_smem)
    # ptr<i64, smem, align<64>>
    print(array_in_smem)
    # the (16,2) Float32 tensor allocated above
    print(tensor_in_smem)

    # Fill the MemRange tensor in the struct and copy it to dst_a.
    a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4)))
    a_tensor.fill(const_a)
    cute.printf("cute.struct.MemRange: {}", a_tensor)
    dst_a.store(a_tensor.load())

    # Reinterpret the raw smem block as Float32, fill it, and copy to dst_b.
    layout = cute.make_layout((8, 2))
    sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32)
    sec_tensor = cute.make_tensor(sec_ptr, layout)
    sec_tensor.fill(const_b)
    cute.printf("block of memory: {}", sec_tensor)
    dst_b.store(sec_tensor.load())

    # Fill the allocated smem tensor and copy it to dst_c.
    tensor_in_smem.fill(const_c)
    cute.printf("tensor in smem: {}", tensor_in_smem)
    dst_c.store(tensor_in_smem.load())
@cute.jit
def run_allocation_kernel(
    const_a: cutlass.Constexpr,
    dst_a: cute.Tensor,
    const_b: cutlass.Constexpr,
    dst_b: cute.Tensor,
    const_c: cutlass.Constexpr,
    dst_c: cute.Tensor,
):
    """Launch ``kernel`` on a single thread with enough dynamic shared memory
    for ``SharedStorage`` plus the ad-hoc allocations made inside the kernel.
    """
    # Extra room for the non-struct allocations:
    # 64 (raw section) + 112 (int64 array) + 128 (tensor) fits in 384 bytes.
    extra_smem_bytes = 384
    # Total dynamic shared memory: 512 (SharedStorage) + 384 = 896 bytes.
    total_smem_bytes = SharedStorage.size_in_bytes() + extra_smem_bytes
    kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch(
        grid=(1, 1, 1),
        block=(1, 1, 1),
        smem=total_smem_bytes,
    )
def verify_allocation_kernel(const_a, const_b, const_c):
    """Run the allocation kernel and verify the copied-back results.

    Allocates three zero-initialized CUDA tensors matching the layouts used in
    the kernel, launches ``run_allocation_kernel``, then checks that each
    destination was filled with the corresponding constant.

    :param const_a: fill value for the struct MemRange tensor (8, 4)
    :param const_b: fill value for the raw smem section tensor (8, 2)
    :param const_c: fill value for the allocated smem tensor (16, 2)
    :raises AssertionError: if any destination row does not match its constant
    """
    dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda")
    dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda")
    dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda")
    run_allocation_kernel(
        const_a,
        from_dlpack(dst_a),
        const_b,
        from_dlpack(dst_b),
        const_c,
        from_dlpack(dst_c),
    )
    # NOTE: only the first row of each destination is checked, matching the
    # original example's verification scope.
    np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0])
    np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0])
    np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0])


# Backward-compatible alias preserving the original (misspelled) public name.
veify_allocation_kernel = verify_allocation_kernel
if __name__ == "__main__":
    # Prepare the CUDA context explicitly (required before launching kernels).
    cutlass.cuda.initialize_cuda_context()
    # Example fill values for the three shared-memory allocations above.
    const_a = 0.5
    const_b = 1.0
    const_c = 2.0
    veify_allocation_kernel(const_a, const_b, const_c)

View File

@ -0,0 +1,51 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.15)
project(tensor)
# Find Python (interpreter + development headers) to build the extension module.
find_package(Python COMPONENTS Interpreter Development REQUIRED)
# Ask the interpreter for its site-packages directory so we can locate the
# CMake config shipped inside the installed nanobind Python package.
# NOTE(review): site.getsitepackages()[0] may not be the active environment's
# package dir under some virtualenv layouts -- confirm if nanobind isn't found.
execute_process(
    COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])"
    OUTPUT_VARIABLE Python_SITE_PACKAGES
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}")
# Add nanobind's bundled CMake config path to CMAKE_PREFIX_PATH
list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake)
# Find nanobind
find_package(nanobind REQUIRED)
# Build the `tensor` extension module from tensor.cpp
nanobind_add_module(tensor tensor.cpp)

View File

@ -0,0 +1,305 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
This example demonstrates a basic approach to building customized interfaces as C-structures between user code
and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions
and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions.
The C-structure is defined as:
.. code-block:: c
struct Tensor {
void *ptr; // Pointer to tensor data
int32_t shape[3]; // Tensor dimensions
int32_t strides[3]; // Memory strides for each dimension
};
The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer,
shape, and strides, enabling efficient data passing between different language boundaries.
.. note::
Future development may include automated code generation flows.
"""
import cutlass
import cutlass.cute as cute
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
import cutlass._mlir.extras.types as T
class ExampleTensorValue(ir.Value):
    """A wrapper around the MLIR value holding the tensor C-struct.

    Extends ``ir.Value`` with convenient accessors for the struct's data
    pointer, shape, and strides via LLVM-dialect ``extractvalue`` operations.
    """

    def __init__(self, v):
        """Wrap an existing MLIR value.

        :param v: The underlying MLIR value to wrap
        :type v: ir.Value
        """
        super().__init__(v)

    @property
    def data_ptr(self):
        """Data pointer of the tensor (first struct field) as a CuTe pointer.

        Fix: the original declared keyword-only ``loc``/``ip`` parameters on
        this property, but Python properties cannot receive arguments, so they
        were always ``None``; they have been removed.

        :return: a CuTe pointer built from the extracted LLVM pointer field
        """
        # Use llvm.extractvalue to pull the pointer field out of the struct.
        ptr_val = llvm.extractvalue(
            llvm.PointerType.get(),
            self,
            [0],  # field 0: void *ptr
        )
        # NOTE(review): the element type is hard-coded to Float32; the C
        # struct itself carries no dtype information.
        return cute.make_ptr(cutlass.Float32, ptr_val)

    @property
    def shape(self):
        """Shape of the tensor (second struct field).

        :return: a tuple of i32 MLIR values, one per dimension
        :rtype: tuple[ir.Value, ...]
        """
        i32_type = ir.IntegerType.get_signless(32)
        shape_val = llvm.extractvalue(
            llvm.StructType.get_literal([i32_type] * 3),
            self,
            [1],  # field 1: int32_t shape[3]
        )
        # Unpack each dimension from the inner shape struct.
        return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3))

    @property
    def stride(self):
        """Strides of the tensor (third struct field).

        :return: a tuple of i32 MLIR values, one stride per dimension
        :rtype: tuple[ir.Value, ...]
        """
        i32_type = ir.IntegerType.get_signless(32)
        strides_val = llvm.extractvalue(
            llvm.StructType.get_literal([i32_type] * 3),
            self,
            [2],  # field 2: int32_t strides[3]
        )
        # Unpack each stride from the inner strides struct.
        return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3))
class ExampleTensor:
    """Python-side view of the C ``Tensor`` struct (raw pointer + rank).

    Instances carry the address of a C structure together with the tensor
    rank, and implement the protocol hooks (``__get_mlir_types__``,
    ``__new_from_mlir_values__``, ``__c_pointers__``) that allow CuTe JIT
    compiled functions to accept them as arguments.
    """

    def __init__(self, c_struct_p, rank):
        """Record the C-struct address and the tensor rank.

        :param c_struct_p: address of the backing C struct
        :type c_struct_p: int
        :param rank: number of tensor dimensions
        :type rank: int
        """
        self._c_struct_p = c_struct_p
        self._rank = rank

    def __get_mlir_types__(self):
        """Describe this argument to the JIT as a single LLVM struct type.

        The struct mirrors the C layout::

            struct Tensor {
                void *ptr;
                int32_t shape[3];
                int32_t strides[3];
            };

        :return: a one-element list containing the LLVM struct type
        :rtype: list[llvm.StructType]
        """
        dims = self._rank
        # void* for the data pointer, plus two literal structs of i32 fields
        # standing in for the fixed-size shape/strides arrays.
        void_ptr = llvm.PointerType.get()
        i32 = ir.IntegerType.get_signless(32)
        dims_struct = llvm.StructType.get_literal([i32] * dims)
        strides_struct = llvm.StructType.get_literal([i32] * dims)
        return [llvm.StructType.get_literal([void_ptr, dims_struct, strides_struct])]

    def __new_from_mlir_values__(self, values):
        """Wrap the first MLIR value as an :class:`ExampleTensorValue`.

        :param values: MLIR values produced for this argument
        :type values: list[ir.Value]
        :return: a new wrapper over ``values[0]``
        :rtype: ExampleTensorValue
        """
        first, *_ = values
        return ExampleTensorValue(first)

    def __c_pointers__(self):
        """Return the pointer list handed to the compiled function.

        :return: a single-element list with the C-struct address
        :rtype: list[int]
        """
        return [self._c_struct_p]
@cute.jit
def foo(tensor):
    """Example JIT function that prints tensor information.

    Reads the data pointer, shape, and strides off the wrapped C struct,
    builds a CuTe tensor view from them, and prints the tensor contents.

    :param tensor: a tensor wrapper exposing ``data_ptr``/``shape``/``stride``
    :type tensor: ExampleTensor
    """
    cute.printf("data_ptr: {}", tensor.data_ptr)
    cute.printf("shape: {}", tensor.shape)
    cute.printf("stride: {}", tensor.stride)
    # Assemble a CuTe tensor view over the external memory described by the struct.
    mA = cute.make_tensor(
        tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride)
    )
    cute.print_tensor(mA)
import sys
import os
import subprocess
import shutil
import tempfile
import torch
def run_test(tmpdir=None):
    """Build the nanobind test module, then compile and run ``foo`` against it.

    Configures and builds the C extension with CMake in ``tmpdir``, imports it,
    creates a mock tensor C-struct, and runs the JIT-compiled ``foo`` on it.

    Fix: the original wrapped everything in ``except Exception as e: print(e)``,
    which silently swallowed build and runtime failures and made the test look
    successful; errors now propagate to the caller (cleanup still runs).

    :param tmpdir: optional build directory; when omitted, a temporary one is
        created and removed afterwards (a user-supplied directory is kept)
    :raises subprocess.CalledProcessError: if the CMake configure/build fails
    """
    # Skip cleanup if the user provides tmpdir
    cleanup = tmpdir is None
    # Initialize temporary build directory
    tmpdir = tmpdir or tempfile.mkdtemp()
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
        subprocess.run(["cmake", "--build", tmpdir], check=True)
        sys.path.append(tmpdir)
        from tensor import make_tensor, pycapsule_get_pointer

        # Mock test tensor and corresponding C structure for this example.
        # In production, this may come from an external library.
        x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
        c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
        c_struct_p = pycapsule_get_pointer(c_struct)

        # Initialize the tensor wrapper and compile the test function.
        tensor = ExampleTensor(c_struct_p, len(x.shape))
        compiled_func = cute.compile(foo, tensor)

        # Benchmark pointer access: __c_pointers__() is on the critical path
        # of every call into a JIT compiled function.
        # Fix: use perf_counter (monotonic, high resolution) instead of time().
        from time import perf_counter

        start = perf_counter()
        for _ in range(1000):
            tensor.__c_pointers__()
        end = perf_counter()
        # total seconds * 1000 == microseconds per call over 1000 iterations
        print(f"__c_pointers__: {(end - start) * 1000} us")

        # Execute the compiled function.
        compiled_func(tensor)
    finally:
        if cleanup:
            # Clean up the temporary build directory.
            shutil.rmtree(tmpdir)
if __name__ == "__main__":
    import argparse

    # Optional build directory for the C extension module; when supplied by
    # the user it is kept (not deleted) after the run.
    cli = argparse.ArgumentParser(
        description="Set temporary directory for building C modules"
    )
    cli.add_argument(
        "--tmp-dir", type=str, help="Temporary directory path for building C modules"
    )
    opts = cli.parse_args()
    run_test(opts.tmp_dir)

View File

@ -0,0 +1,82 @@
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
#include <cstdint>
#include <stdexcept>

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
// Mock tensor struct for testing only.
// The field layout must match the LLVM struct type built on the Python side:
//   { void *ptr; { int32_t shape[3]; }; { int32_t strides[3]; } }
struct MockTensor {
  void *ptr; // raw tensor data pointer
  struct {
    int32_t shape[3]; // tensor dimensions
  } shape;
  struct {
    int32_t strides[3]; // element strides, one per dimension
  } strides;
};
NB_MODULE(tensor, m) {
  // Create a MockTensor for testing; ownership is transferred to a PyCapsule
  // whose destructor frees the allocation.
  m.def("make_tensor", [](int64_t ptr, std::vector<int32_t> shape,
                          std::vector<int32_t> strides) {
    // Fix: these size checks previously used assert(), which is compiled out
    // under NDEBUG (typical release extension builds) and <cassert> was never
    // included; throwing keeps validation active in every build type.
    // Validating before `new` also avoids leaking the MockTensor on error.
    if (shape.size() != 3) {
      throw std::invalid_argument("shape must have 3 elements");
    }
    if (strides.size() != 3) {
      throw std::invalid_argument("strides must have 3 elements");
    }
    auto *tensor = new MockTensor();
    tensor->ptr = reinterpret_cast<void *>(ptr);
    for (size_t i = 0; i < shape.size(); i++) {
      tensor->shape.shape[i] = shape[i];
      tensor->strides.strides[i] = strides[i];
    }
    return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) {
      // Capsule destructor: free the MockTensor when the capsule is collected.
      auto n = PyCapsule_GetName(capsule);
      if (void *p = PyCapsule_GetPointer(capsule, n)) {
        delete reinterpret_cast<MockTensor *>(p);
      }
    }));
  });
  m.def(
      "pycapsule_get_pointer",
      [](nb::object &capsule) {
        // Extract the raw struct address so Python can pass it to the JIT.
        void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor");
        if (!ptr) {
          throw std::runtime_error("Invalid tensor capsule");
        }
        return reinterpret_cast<uintptr_t>(ptr);
      },
      "Get pointer from PyCapsule");
}

File diff suppressed because it is too large Load Diff

View File

@ -83,11 +83,6 @@
"\n",
" # Print hello world from host code\n",
" cute.printf(\"hello world\")\n",
" \n",
" # Initialize CUDA context for launching a kernel with error checking\n",
" # We make context initialization explicit to allow users to control the context creation \n",
" # and avoid potential issues with multiple contexts\n",
" cutlass.cuda.initialize_cuda_context()\n",
"\n",
" # Launch kernel\n",
" kernel().launch(\n",
@ -129,6 +124,11 @@
}
],
"source": [
"# Initialize CUDA context for launching a kernel with error checking\n",
"# We make context initialization explicit to allow users to control the context creation \n",
"# and avoid potential issues with multiple contexts\n",
"cutlass.cuda.initialize_cuda_context()\n",
"\n",
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
"print(\"Running hello_world()...\")\n",
"hello_world()\n",
@ -136,6 +136,7 @@
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
"print(\"Compiling...\")\n",
"hello_world_compiled = cute.compile(hello_world)\n",
"\n",
"# Run the pre-compiled version\n",
"print(\"Running compiled version...\")\n",
"hello_world_compiled()"