cutlass/examples/python/CuTeDSL/notebooks/composed_layout.ipynb
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00

{
"cells": [
{
"cell_type": "markdown",
"id": "0c7cf795",
"metadata": {},
"source": [
"# Composed Layout in CuTe\n",
"\n",
"A **Composed Layout** in CuTe chains a user-visible layout, a constant offset, and an inner layout or transformation into a single coordinate mapping.\n",
"It lets you express mappings, such as swizzles or index indirection, that a plain shape/stride layout cannot represent on its own.\n",
"\n",
"## Components\n",
"\n",
"A Composed Layout consists of three key components:\n",
"\n",
"1. **Inner Layout/Transformation** (`inner`):\n",
"   - Can be a layout, a swizzle, or a custom transformation function\n",
"   - Applied last, to the offset-shifted result of `outer`\n",
"   - Allows coordinate manipulations that a plain layout cannot express\n",
"\n",
"2. **Offset** (`offset`):\n",
"   - An integer or an integer tuple (this notebook uses `0` and `(1, 0)`)\n",
"   - Adds a constant displacement to the result of `outer`, shifting where `inner` is evaluated\n",
"\n",
"3. **Outer Layout** (`outer`):\n",
"   - The layout visible to the user\n",
"   - Applied first, to the input coordinates\n",
"   - Determines the shape of the composed layout\n",
"\n",
"## Mathematical Representation\n",
"\n",
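"The composition of these components is defined as:\n",
"\n",
"$$\n",
"R(c) := (\\mathrm{inner} \\circ \\mathrm{offset} \\circ \\mathrm{outer})(c) = \\mathrm{inner}(\\mathrm{offset} + \\mathrm{outer}(c))\n",
"$$\n",
"\n",
"Where:\n",
"- $c$ is the input coordinate\n",
"- $\\circ$ denotes function composition, and $\\mathrm{offset}$ inside the composition stands for the translation $x \\mapsto \\mathrm{offset} + x$\n",
"- The stages are applied from right to left: `outer` first, then the offset, then `inner`\n",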
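"\n",
"The plain-Python sketch below (no CuTe involved) models each stage as an ordinary function to show the order of application. It assumes a column-major `outer` layout of shape `(4, 2)` with stride `(1, 4)`, an integer offset of `3`, and a hypothetical `inner` that doubles its argument; `outer`, `inner`, and `composed` are illustrative names, not CuTe APIs.\n",
"\n",
"```python\n",
"def outer(c):\n",
"    # Column-major (4, 2) layout with stride (1, 4): (i, j) -> i + 4 * j\n",
"    i, j = c\n",
"    return i * 1 + j * 4\n",
"\n",
"offset = 3\n",
"\n",
"def inner(x):\n",
"    # Hypothetical inner transformation: double the shifted index\n",
"    return 2 * x\n",
"\n",
"def composed(c):\n",
"    # R(c) = inner(offset + outer(c)): outer first, inner last\n",
"    return inner(offset + outer(c))\n",
"\n",
"print([composed((i, j)) for j in range(2) for i in range(4)])\n",
"# [6, 8, 10, 12, 14, 16, 18, 20]\n",
"```\n",
"\n",
"Swapping the order to `offset + inner(outer(c))` would give `3, 5, 7, ...` instead of `6, 8, 10, ...`, so the right-to-left order matters.\n",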
"\n",
"## Usage in Python\n",
"\n",
"To create a Composed Layout in Python, call `cute.make_composed_layout`; the arguments follow the order of the composition above, with `inner` first and `outer` last:\n",
"\n",
"```python\n",
"layout = cute.make_composed_layout(inner, offset, outer)\n",
"```\n",
"\n",
"## Key Benefits\n",
"\n",
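"1. **Flexibility**: Expresses mappings, such as swizzles and gathers, that cannot be written as a single layout\n",
"2. **Modularity**: Keeps the user-visible shape (`outer`), the displacement (`offset`), and the final transformation (`inner`) separate\n",
"3. **Performance**: Enables memory access patterns suited to GPUs, for example swizzled shared-memory layouts that reduce bank conflicts\n",
"4. **Compatibility**: One abstraction covers layouts, swizzles, and arbitrary functions as the inner stage"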
]
},
{
"cell_type": "markdown",
"id": "24448f7d",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": [
"## Custom Transformation Example\n",
"\n",
"This example demonstrates how to create a Composed Layout with a custom transformation function. We'll create a simple transformation that:\n",
"\n",
"1. Takes a 2D coordinate `(x, y)` as input\n",
"2. Increments the y-coordinate by 1 (the `inner` function)\n",
"3. Composes this function with the offset `(1, 0)` and an identity layout of shape `(8, 4)`\n",
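"\n",
"The mapping can be traced by hand in plain Python. This sketch is not CuTe code: it assumes that the identity layout converts a linear index to an `(x, y)` coordinate in column-major order and that the tuple offset is added element-wise. `outer`, `inner`, and `composed` are illustrative names.\n",
"\n",
"```python\n",
"def outer(k, shape=(8, 4)):\n",
"    # Identity layout (assumed column-major): linear index -> (x, y)\n",
"    return k % shape[0], k // shape[0]\n",
"\n",
"def inner(c):\n",
"    # Same custom function as in the example: bump y by 1\n",
"    x, y = c\n",
"    return x, y + 1\n",
"\n",
"offset = (1, 0)\n",
"\n",
"def composed(k):\n",
"    x, y = outer(k)\n",
"    # Element-wise tuple offset, then the custom inner function\n",
"    return inner((x + offset[0], y + offset[1]))\n",
"\n",
"print(composed(0))  # (1, 1)\n",
"print(composed(9))  # (2, 2)\n",
"```\n",
"\n",
"Under these assumptions the net effect is `(x, y) -> (x + 1, y + 1)`, so `layout(0)` in the cell below should evaluate to `(1, 1)`.\n",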
"\n",
"The example shows how to:\n",
"- Define a custom transformation function\n",
"- Create a composed layout with the transformation\n",
"- Apply the layout to coordinates\n",
"- Print the results for verification"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "184f30e6",
"metadata": {},
"outputs": [],
"source": [
"import cutlass\n",
"import cutlass.cute as cute\n",
"from cutlass.cute.runtime import from_dlpack, make_ptr\n",
"\n",
"\n",
"@cute.jit\n",
"def customized_layout():\n",
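"    # inner is applied last: (x, y) -> (x, y + 1)\n",
"    def inner(c):\n",
"        x, y = c\n",
"        return x, y + 1\n",
"\n",
"    # outer: identity layout of shape (8, 4); offset (1, 0) shifts x by 1\n",
"    layout = cute.make_composed_layout(\n",
"        inner, (1, 0), cute.make_identity_layout(shape=(8, 4))\n",
"    )\n",
"    print(layout)  # Python print: runs at compile time\n",
"    cute.printf(layout(0))  # cute.printf: runs when the compiled code executes\n",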
"\n",
"\n",
"customized_layout()"
]
},
{
"cell_type": "markdown",
"id": "c897187f",
"metadata": {},
"source": [
"## Gather/Scatter Operations with Composed Layout\n",
"\n",
"Gather and scatter are fundamental indirect data-access patterns in parallel computing and GPU programming. In CuTe, a gather can be expressed with a Composed Layout whose `inner` function reads from an index array.\n",
"\n",
"### Gather Operation\n",
"A gather operation collects elements from a source array using an index array (also called an indirection array). It's defined as:\n",
"```python\n",
"output[i] = source[index[i]]\n",
"```\n",
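"\n",
"For reference, here is the same definition in plain Python, with lists standing in for device arrays and made-up example values; the scatter counterpart mentioned in the note at the end of this section simply reverses the assignment.\n",
"\n",
"```python\n",
"source = [10, 11, 12, 13, 14, 15]\n",
"index = [5, 0, 3, 3]  # indices may repeat in a gather\n",
"\n",
"# Gather: output[i] = source[index[i]]\n",
"output = [source[index[i]] for i in range(len(index))]\n",
"print(output)  # [15, 10, 13, 13]\n",
"\n",
"# Scatter (reversed data flow): dest[scatter_index[i]] = values[i]\n",
"values = [1, 2, 3]\n",
"scatter_index = [4, 0, 2]  # unique, so no write conflicts\n",
"dest = [0] * 6\n",
"for i in range(len(values)):\n",
"    dest[scatter_index[i]] = values[i]\n",
"print(dest)  # [2, 0, 3, 0, 1, 0]\n",
"```\n",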
"\n",
"#### Components of the CuTe Implementation\n",
"1. **Offset Tensor**: Holds the gather indices (`offset_tensor`); not to be confused with the composed layout's `offset`, which is `0` here\n",
"2. **Data Pointer**: Points to the source data array (`data_ptr`)\n",
"3. **Shape**: The shape of the logical tensor seen by the user (`shape`)\n",
"\n",
"### How it Works\n",
"1. The inner transformation function reads from the offset tensor:\n",
"   ```python\n",
"   def inner(c):\n",
"       return offset_tensor[c]  # Returns the gather index\n",
"   ```\n",
"2. The composed layout maps input coordinates through the offset tensor:\n",
"   ```python\n",
"   gather_layout = cute.make_composed_layout(inner, 0, cute.make_layout(shape))\n",
"   ```\n",
"3. This creates an indirect access pattern:\n",
"   - Input coordinate `i` → `offset_tensor[i]` → `data_ptr[offset_tensor[i]]`\n",
"   - The layout itself performs the first step; the second step requires building a tensor from `data_ptr` and `gather_layout`, which the example below leaves as a TODO\n",
"\n",
"4. Notably, layout operations such as slicing and partitioning can still be applied to the `outer` layout\n",
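"\n",
"The steps above can be modeled in plain Python to see where each index comes from. This is a sketch, not CuTe code: it assumes `make_layout((16,))` behaves as the compact stride-1 layout `i -> i`, and it represents restricting `outer` to a tile as an extra offset, which is an illustrative assumption rather than a description of how CuTe implements slicing.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"random.seed(0)\n",
"n = 16\n",
"offset_tensor = [random.randrange(256) for _ in range(n)]  # gather indices\n",
"data = list(range(256))  # source values, like torch.arange(0, 256)\n",
"\n",
"def outer(i):\n",
"    # Compact 1D layout of shape (16,), stride 1: i -> i\n",
"    return i * 1\n",
"\n",
"def inner(c):\n",
"    # Indirection: read the gather index\n",
"    return offset_tensor[c]\n",
"\n",
"def gather_layout(i, offset=0):\n",
"    # R(i) = inner(offset + outer(i))\n",
"    return inner(offset + outer(i))\n",
"\n",
"for i in range(n):\n",
"    j = gather_layout(i)\n",
"    print(f'{i} -> {j} (value {data[j]})')\n",
"\n",
"# A 4-element tile starting at index 8, modeled as an extra offset:\n",
"# the indirection through inner is unchanged.\n",
"tile = [gather_layout(t, offset=8) for t in range(4)]\n",
"print(tile == offset_tensor[8:12])  # True\n",
"```\n",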
"\n",
"### Use Cases\n",
"- **Sparse Operations**: Accessing non-contiguous memory efficiently\n",
"- **Graph Processing**: Following edge connections in graph algorithms\n",
"- **Feature Embedding**: Looking up embeddings for discrete tokens\n",
"- **Irregular Data Access**: Any pattern requiring indirect memory access\n",
"\n",
"### Example Output Interpretation\n",
"The example code prints pairs of numbers `i -> j` where:\n",
"- `i` is the output index\n",
"- `j` is the gathered source index from `offset_tensor`\n",
"\n",
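"This demonstrates how the composed layout transforms coordinates for indirect memory access. Because `data_tensor` holds `0, 1, ..., 255`, each `j` is also the value a gather from `data_ptr` would read; because `offset_tensor` is random, the printed indices change from run to run.\n",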
"\n",
"Note: Scatter operations (writing to indirect locations) can be implemented similarly by reversing the direction of data flow, i.e. `output[index[i]] = source[i]`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d68f9476",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"@cute.jit\n",
"def gather_tensor(\n",
"    offset_tensor: cute.Tensor, data_ptr: cute.Pointer, shape: cute.Shape\n",
"):\n",
"    # inner looks up the gather index for a coordinate\n",
"    def inner(c):\n",
"        return offset_tensor[c]\n",
"\n",
"    # outer: compact layout of the user-visible shape; composed offset is 0\n",
"    gather_layout = cute.make_composed_layout(inner, 0, cute.make_layout(shape))\n",
"    # range_constexpr unrolls this loop at compile time\n",
"    for i in cutlass.range_constexpr(cute.size(shape)):\n",
"        cute.printf(\"%d -> %d\", i, gather_layout(i))\n",
"\n",
"    # TODO: support in future\n",
"    # gather_tensor = cute.make_tensor(data_ptr, gather_layout)\n",
"    # cute.printf(gather_tensor[0])\n",
"\n",
"\n",
"shape = (16,)\n",
"# Random gather indices into a 256-element source array\n",
"offset_tensor = torch.randint(0, 256, shape, dtype=torch.int32)\n",
"data_tensor = torch.arange(0, 256, dtype=torch.int32)\n",
"\n",
"\n",
"gather_tensor(\n",
"    from_dlpack(offset_tensor),\n",
"    make_ptr(cutlass.Int32, data_tensor.data_ptr(), cute.AddressSpace.generic),\n",
"    shape,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv3_12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}