cutlass/examples/python/CuTeDSL/notebooks/tensor.ipynb
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00


{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cutlass\n",
"import cutlass.cute as cute"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensor\n",
"\n",
"A tensor in CuTe is created through the composition of two key components:\n",
"\n",
"1. An **Engine** (E) - A random-access, pointer-like object that supports:\n",
" - Offset operation: `e + d → e` (offset engine by elements of a layout's codomain)\n",
" - Dereference operation: `*e → v` (dereference engine to produce value)\n",
"\n",
"2. A **Layout** (L) - Defines the mapping from coordinates to offsets\n",
"\n",
"A tensor is formally defined as the composition of an engine E with a layout L, expressed as `T = E ∘ L`. When evaluating a tensor at coordinate c, it:\n",
"\n",
"1. Maps the coordinate c to the codomain using the layout\n",
"2. Offsets the engine accordingly\n",
"3. Dereferences the result to obtain the tensor's value\n",
"\n",
"This can be expressed mathematically as:\n",
"\n",
"```\n",
"T(c) = (E ∘ L)(c) = *(E + L(c))\n",
"```\n",
"\n",
"## Example Usage\n",
"\n",
"Here's a simple example of creating a tensor from a pointer and the layout `(8,5):(5,1)`, then filling it with ones:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@cute.jit\n",
"def create_tensor_from_ptr(ptr: cute.Pointer):\n",
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
" tensor = cute.make_tensor(ptr, layout)\n",
" tensor.fill(1)\n",
" cute.print_tensor(tensor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This creates a tensor where:\n",
"- The engine is a pointer\n",
"- The layout with shape `(8, 5)` and stride `(5, 1)`\n",
"- The resulting tensor can be evaluated using coordinates defined by the layout\n",
"\n",
"We can test this by allocating buffer with torch and run test with pointer to torch tensor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from cutlass.torch import dtype as torch_dtype\n",
"import cutlass.cute.runtime as cute_rt\n",
"\n",
"a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n",
"ptr_a = cute_rt.make_ptr(cutlass.Float32, a.data_ptr())\n",
"\n",
"create_tensor_from_ptr(ptr_a)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DLPACK support \n",
"\n",
"CuTe DSL is designed to support dlpack protocol natively. This offers easy integration with frameworks \n",
"supporting DLPack, e.g. torch, numpy, jax, tensorflow, etc.\n",
"\n",
"For more information, please refer to DLPACK project: https://github.com/dmlc/dlpack\n",
"\n",
"Calling `from_dlpack` can convert any tensor or ndarray object supporting `__dlpack__` and `__dlpack_device__`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cutlass.cute.runtime import from_dlpack\n",
"\n",
"\n",
"@cute.jit\n",
"def print_tensor_dlpack(src: cute.Tensor):\n",
" print(src)\n",
" cute.print_tensor(src)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n",
"\n",
"print_tensor_dlpack(from_dlpack(a))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"a = np.random.randn(8, 8).astype(np.float32)\n",
"\n",
"print_tensor_dlpack(from_dlpack(a))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensor Evaluation Methods\n",
"\n",
"Tensors support two primary methods of evaluation:\n",
"\n",
"### 1. Full Evaluation\n",
"When applying the tensor evaluation with a complete coordinate c, it computes the offset, applies it to the engine, \n",
"and dereferences it to return the stored value. This is the straightforward case where you want to access \n",
"a specific element of the tensor.\n",
"\n",
"### 2. Partial Evaluation (Slicing)\n",
"When evaluating with an incomplete coordinate c = c' ⊕ c* (where c* represents the unspecified portion), \n",
"the result is a new tensor which is a slice of the original tensor with its engine offset to account for \n",
"the coordinates that were provided. This operation can be expressed as:\n",
"\n",
"```\n",
"T(c) = (E ∘ L)(c) = (E + L(c')) ∘ L(c*) = T'(c*)\n",
"```\n",
"\n",
"Slicing effectively reduces the dimensionality of the tensor, creating a sub-tensor that can be \n",
"further evaluated or manipulated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@cute.jit\n",
"def tensor_access_item(a: cute.Tensor):\n",
" # access data using linear index\n",
" cute.printf(\n",
" \"a[2] = {} (equivalent to a[{}])\",\n",
" a[2],\n",
" cute.make_identity_tensor(a.layout.shape)[2],\n",
" )\n",
" cute.printf(\n",
" \"a[9] = {} (equivalent to a[{}])\",\n",
" a[9],\n",
" cute.make_identity_tensor(a.layout.shape)[9],\n",
" )\n",
"\n",
" # access data using n-d coordinates, following two are equivalent\n",
" cute.printf(\"a[2,0] = {}\", a[2, 0])\n",
" cute.printf(\"a[2,4] = {}\", a[2, 4])\n",
" cute.printf(\"a[(2,4)] = {}\", a[2, 4])\n",
"\n",
" # assign value to tensor@(2,4)\n",
" a[2, 3] = 100.0\n",
" a[2, 4] = 101.0\n",
" cute.printf(\"a[2,3] = {}\", a[2, 3])\n",
" cute.printf(\"a[(2,4)] = {}\", a[(2, 4)])\n",
"\n",
"\n",
"# Create a tensor with sequential data using torch\n",
"data = torch.arange(0, 8 * 5, dtype=torch.float32).reshape(8, 5)\n",
"tensor_access_item(from_dlpack(data))\n",
"\n",
"print(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tensor as memory view\n",
"\n",
"In CUDA programming, different memory spaces have different characteristics in terms of access speed, scope, and lifetime:\n",
"\n",
"- **generic**: Default memory space that can refer to any other memory space.\n",
"- **global memory (gmem)**: Accessible by all threads across all blocks, but has higher latency.\n",
"- **shared memory (smem)**: Accessible by all threads within a block, with much lower latency than global memory.\n",
"- **register memory (rmem)**: Thread-private memory with the lowest latency, but limited capacity.\n",
"- **tensor memory (tmem)**: Specialized memory introduced in NVIDIA Blackwell architecture for tensor operations.\n",
"\n",
"When creating tensors in CuTe, you can specify the memory space to optimize performance based on your access patterns.\n",
"\n",
"For more information on CUDA memory spaces, see the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Coordinate Tensors\n",
"\n",
"### Definition and Properties\n",
"\n",
"A coordinate tensor $T: Z^n → Z^m$ is a mathematical structure that establishes a mapping between coordinate spaces. Unlike standard tensors that map coordinates to scalar values, coordinate tensors map coordinates to other coordinates, forming a fundamental building block for tensor operations and transformations.\n",
"\n",
"### Examples\n",
"\n",
"Consider a `(4,4)` coordinate tensor:\n",
"\n",
"**Row-Major Layout (C-style):**\n",
"\\begin{bmatrix} \n",
"(0,0) & (0,1) & (0,2) & (0,3) \\\\\n",
"(1,0) & (1,1) & (1,2) & (1,3) \\\\\n",
"(2,0) & (2,1) & (2,2) & (2,3) \\\\\n",
"(3,0) & (3,1) & (3,2) & (3,3)\n",
"\\end{bmatrix}\n",
"\n",
"**Column-Major Layout (Fortran-style):**\n",
"\\begin{bmatrix}\n",
"(0,0) & (1,0) & (2,0) & (3,0) \\\\\n",
"(0,1) & (1,1) & (2,1) & (3,1) \\\\\n",
"(0,2) & (1,2) & (2,2) & (3,2) \\\\\n",
"(0,3) & (1,3) & (2,3) & (3,3)\n",
"\\end{bmatrix}\n",
"\n",
"### Identity Tensor\n",
"\n",
"An identity tensor $I$ is a special case of a coordinate tensor that implements the identity mapping function:\n",
"\n",
"**Definition:**\n",
"For a given shape $S = (s_1, s_2, ..., s_n)$, the identity tensor $I$ satisfies: $I(c) = c, \\forall c \\in \\prod_{i=1}^n [0, s_i)$\n",
"\n",
"**Properties:**\n",
"1. **Bijective Mapping**: The identity tensor establishes a one-to-one correspondence between coordinates.\n",
"2. **Layout Invariance**: The logical structure remains constant regardless of the underlying memory layout.\n",
"3. **Coordinate Preservation**: For any coordinate c, I(c) = c.\n",
"\n",
"\n",
"CuTe establishes an isomorphism between 1-D indices and N-D coordinates through lexicographical ordering. For a coordinate c = (c₁, c₂, ..., cₙ) in an identity tensor with shape S = (s₁, s₂, ..., sₙ):\n",
"\n",
"**Linear Index Formula:**\n",
"$\\text{idx} = c_1 + \\sum_{i=2}^{n} \\left(c_i \\prod_{j=1}^{i-1} s_j\\right)$\n",
"\n",
"**Example:**\n",
"```python\n",
"# Create an identity tensor from a given shape\n",
"coord_tensor = make_identity_tensor(layout.shape())\n",
"\n",
"# Access coordinate using linear index\n",
"coord = coord_tensor[linear_idx] # Returns the N-D coordinate\n",
"```\n",
"\n",
"This bidirectional mapping enables efficient conversion from linear indices to N-dimensional coordinates, facilitating tensor operations and memory access patterns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@cute.jit\n",
"def print_tensor_coord(a: cute.Tensor):\n",
" coord_tensor = cute.make_identity_tensor(a.layout.shape)\n",
" print(coord_tensor)\n",
" cute.print_tensor(coord_tensor)\n",
"\n",
"\n",
"a = torch.randn(8, 4, dtype=torch_dtype(cutlass.Float32))\n",
"print_tensor_coord(from_dlpack(a))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}