* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"editable": true,
|
||
"slideshow": {
|
||
"slide_type": ""
|
||
},
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from functools import partial\n",
|
||
"from typing import List\n",
|
||
"\n",
|
||
"import cutlass\n",
|
||
"import cutlass.cute as cute\n",
|
||
"from cutlass.cute.runtime import from_dlpack"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Kernel Tutorial: Building an Efficient Elementwise Add Kernel with CuTe DSL\n",
|
||
"\n",
|
||
"This tutorial demonstrates how to implement and optimize a GPU elementwise addition kernel using the CuTe DSL. \n",
|
||
"\n",
|
||
"## Learning Objectives\n",
|
||
"\n",
|
||
"In this tutorial, you will learn building an efficient elementwise kernel in CuTe DSL step by step:\n",
|
||
"- How to implement basic GPU kernels using CuTe DSL in basic CUDA techniques\n",
|
||
"- How to benchmark performance of the kernel\n",
|
||
"- How to tile and partition tensor and map to basic CuTe Layout\n",
|
||
"- What it Thread & Value Layout and mapping from thread & value index to logical coordinate\n",
|
||
"- How to implement advanced kernel with TV layout and tune performance to achieve peak performance\n",
|
||
"\n",
|
||
"## Understanding Elementwise Addition\n",
|
||
"\n",
|
||
"Elementwise addition is a fundamental operation in linear algebra and deep learning. Given two tensors of the same shape, the operation performs element-wise addition to produce a result tensor of the same shape.\n",
|
||
"\n",
|
||
"For two 2D tensors $A$ and $B$ of shape $(M, N)$, the elementwise addition operation $C = A + B$ is defined as:\n",
|
||
"\n",
|
||
"$\n",
|
||
" C_{i,j} = A_{i,j} + B_{i,j}\n",
|
||
"$\n",
|
||
"\n",
|
||
"where:\n",
|
||
"- $i \\in [0, M-1]$ represents the row index\n",
|
||
"- $j \\in [0, N-1]$ represents the column index\n",
|
||
"- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ in tensors $A$, $B$, and $C$ respectively\n",
|
||
"\n",
|
||
"This operation has several important characteristics:\n",
|
||
"1. **Parallelizable**: Each element can be computed independently\n",
|
||
"2. **Memory-bound**: Performance limited by memory bandwidth rather than compute\n",
|
||
"3. **Coalescing-sensitive**: Efficiency depends on memory access patterns\n",
|
||
"4. **Vectorization-friendly**: Multiple elements can be processed together\n",
|
||
"\n",
|
||
"## Naive Elementwise Add Kernel\n",
|
||
"\n",
|
||
"Let's start with a naive implementation to establish a baseline before exploring optimizations."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Basic Kernel Implementation\n",
|
||
"# ---------------------\n",
|
||
"# This is our first implementation of the elementwise add kernel.\n",
|
||
"# It follows a simple 1:1 mapping between threads and tensor elements.\n",
|
||
"\n",
|
||
"\n",
|
||
"@cute.kernel\n",
|
||
"def naive_elementwise_add_kernel(\n",
|
||
" gA: cute.Tensor, # Input tensor A\n",
|
||
" gB: cute.Tensor, # Input tensor B\n",
|
||
" gC: cute.Tensor, # Output tensor C = A + B\n",
|
||
"):\n",
|
||
" # Step 1: Get thread indices\n",
|
||
" # ------------------------\n",
|
||
" # CUDA threads are organized in a 3D grid of thread blocks\n",
|
||
" # Here we only use the x-dimension for simplicity\n",
|
||
" tidx, _, _ = cute.arch.thread_idx() # Thread index within block (0 to bdim-1)\n",
|
||
" bidx, _, _ = cute.arch.block_idx() # Block index in grid (0 to grid_dim-1)\n",
|
||
" bdim, _, _ = cute.arch.block_dim() # Number of threads per block\n",
|
||
"\n",
|
||
" # Calculate global thread index\n",
|
||
" # This gives each thread a unique ID across all blocks\n",
|
||
" thread_idx = bidx * bdim + tidx # Global thread ID\n",
|
||
"\n",
|
||
" # Step 2: Map thread index to tensor coordinates\n",
|
||
" # -------------------------------------------\n",
|
||
" # Each thread will process one element of the input tensors\n",
|
||
" m, n = gA.shape # Get tensor dimensions (M rows × N columns)\n",
|
||
"\n",
|
||
" # Convert linear thread index to 2D coordinates:\n",
|
||
" # - ni: column index (0 to n-1)\n",
|
||
" # - mi: row index (0 to m-1)\n",
|
||
" ni = thread_idx % n # Column index (faster varying dimension)\n",
|
||
" mi = thread_idx // n # Row index (slower varying dimension)\n",
|
||
"\n",
|
||
" # Step 3: Load and process data\n",
|
||
" # ---------------------------\n",
|
||
" # Load values from input tensors\n",
|
||
" # The tensor layout automatically handles the conversion from\n",
|
||
" # logical indices (mi, ni) to physical memory addresses\n",
|
||
" a_val = gA[mi, ni] # Load element from tensor A\n",
|
||
" b_val = gB[mi, ni] # Load element from tensor B\n",
|
||
"\n",
|
||
" # Step 4: Store result\n",
|
||
" # ------------------\n",
|
||
" # Write the sum back to the output tensor\n",
|
||
" gC[mi, ni] = a_val + b_val"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Structure of the Kernel\n",
|
||
"\n",
|
||
"The naive kernel implementation follows a straightforward but effective structure for parallel processing on the GPU. Here's a detailed breakdown of how it works:\n",
|
||
"\n",
|
||
"1. **Thread Organization and Indexing**\n",
|
||
" - Each CUDA thread is uniquely identified using a combination of:\n",
|
||
" * `thread_idx` (tidx): Thread index within a block (0 to bdim-1)\n",
|
||
" * `block_idx` (bidx): Block index in the grid\n",
|
||
" * `block_dim` (bdim): Number of threads per block\n",
|
||
" - Global thread ID is calculated as: `thread_idx = bidx * bdim + tidx`\n",
|
||
"\n",
|
||
"2. **Coordinate Mapping**\n",
|
||
" - The kernel maps each thread's global ID to 2D tensor coordinates:\n",
|
||
" * `ni = thread_idx % n` (column index - faster varying)\n",
|
||
" * `mi = thread_idx // n` (row index - slower varying)\n",
|
||
" - This mapping ensures coalesced memory access by having adjacent threads access adjacent memory locations\n",
|
||
"\n",
|
||
"3. **Memory Access Pattern**\n",
|
||
" - Each thread:\n",
|
||
" * Loads one element from tensor A: `a_val = gA[mi, ni]`\n",
|
||
" * Loads one element from tensor B: `b_val = gB[mi, ni]`\n",
|
||
" * Performs addition: `a_val + b_val`\n",
|
||
" * Stores result to tensor C: `gC[mi, ni] = result`\n",
|
||
" - Memory Considerations\n",
|
||
" * Uses 1:1 thread-to-element mapping\n",
|
||
" * Memory accesses are coalesced when threads in a warp access consecutive elements\n",
|
||
" * No explicit use of shared memory or register blocking\n",
|
||
" * Limited ability to hide memory latency due to single element processing\n",
|
||
"\n",
|
||
"This naive implementation provides a baseline for understanding more optimized versions that follow, which introduce:\n",
|
||
"- Vectorized memory access\n",
|
||
"- Thread and value (TV) layouts\n",
|
||
"- Advanced tiling strategies\n",
|
||
"- Custom binary operations\n",
|
||
"\n",
|
||
"For more details about coalesced memory access, please read: https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#coalesced-access-to-global-memory\n",
|
||
"\n",
|
||
"\n",
|
||
"### Kernel Launch Configuration and Testing\n",
|
||
"\n",
|
||
"This section demonstrates how to:\n",
|
||
"1. Configure and launch the kernel with `cute.jit` function\n",
|
||
"2. Set up test data with `torch`\n",
|
||
"3. Verify correctness\n",
|
||
"\n",
|
||
"**Launch Configuration**\n",
|
||
" - Uses 256 threads per block (common choice for good occupancy)\n",
|
||
" - Grid size calculated based on total elements: `(m * n) // threads_per_block`\n",
|
||
" - Single dimension block and grid configuration for simplicity\n",
|
||
"\n",
|
||
"#### Host JIT function to launch kernel"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.jit # Just-in-time compilation decorator\n",
|
||
"def naive_elementwise_add(\n",
|
||
" mA: cute.Tensor, # Input tensor A\n",
|
||
" mB: cute.Tensor, # Input tensor B\n",
|
||
" mC: cute.Tensor, # Output tensor C\n",
|
||
"):\n",
|
||
" # Configure kernel launch parameters\n",
|
||
" # --------------------------------\n",
|
||
" # Choose number of threads per block\n",
|
||
" # 256 is a common choice as it:\n",
|
||
" # - Allows good occupancy on most GPUs\n",
|
||
" # - Is a multiple of 32 (warp size)\n",
|
||
" # - Provides enough threads for latency hiding\n",
|
||
" num_threads_per_block = 256\n",
|
||
"\n",
|
||
" # Get input dimensions\n",
|
||
" m, n = mA.shape # Matrix dimensions (M rows × N columns)\n",
|
||
"\n",
|
||
" # Create kernel instance\n",
|
||
" kernel = naive_elementwise_add_kernel(mA, mB, mC)\n",
|
||
"\n",
|
||
" # Launch kernel with calculated grid dimensions\n",
|
||
" # -------------------------------------------\n",
|
||
" # Grid size calculation:\n",
|
||
" # - Total elements: m * n\n",
|
||
" # - Blocks needed: ceil(total_elements / threads_per_block)\n",
|
||
" # - Using integer division here assumes m * n is multiple of threads_per_block\n",
|
||
" kernel.launch(\n",
|
||
" grid=((m * n) // num_threads_per_block, 1, 1), # Number of blocks in x,y,z\n",
|
||
" block=(num_threads_per_block, 1, 1), # Threads per block in x,y,z\n",
|
||
" )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Setup test data with torch"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Test Setup\n",
|
||
"# ----------\n",
|
||
"# Define test dimensions\n",
|
||
"M, N = 16384, 8192 # Using large matrices to measure performance\n",
|
||
"\n",
|
||
"# Create test data on GPU\n",
|
||
"# ----------------------\n",
|
||
"# Using float16 (half precision) for:\n",
|
||
"# - Reduced memory bandwidth requirements\n",
|
||
"# - Better performance on modern GPUs\n",
|
||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16) # Random input A\n",
|
||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16) # Random input B\n",
|
||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16) # Output buffer\n",
|
||
"\n",
|
||
"# Calculate total elements for bandwidth calculations\n",
|
||
"num_elements = sum([a.numel(), b.numel(), c.numel()])\n",
|
||
"\n",
|
||
"# Convert PyTorch tensors to CuTe tensors\n",
|
||
"# -------------------------------------\n",
|
||
"# from_dlpack creates CuTe tensor views of PyTorch tensors\n",
|
||
"# assumed_align=16 ensures proper memory alignment for vectorized access\n",
|
||
"a_ = from_dlpack(a, assumed_align=16) # CuTe tensor A\n",
|
||
"b_ = from_dlpack(b, assumed_align=16) # CuTe tensor B\n",
|
||
"c_ = from_dlpack(c, assumed_align=16) # CuTe tensor C"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Compile and run"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Compile the kernel for the specific input types\n",
|
||
"naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n",
|
||
"\n",
|
||
"# Run the kernel\n",
|
||
"naive_elementwise_add_(a_, b_, c_)\n",
|
||
"\n",
|
||
"# Verify Results\n",
|
||
"# -------------\n",
|
||
"# Compare our kernel output with PyTorch's native implementation\n",
|
||
"torch.testing.assert_close(c, a + b) # Raises error if results don't match"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Performance Analysis and Benchmarking\n",
|
||
"\n",
|
||
"To understand and improve our kernel's performance, we need to measure its execution time and memory throughput. Let's analyze several key metrics:\n",
|
||
"\n",
|
||
"* **Execution Time**\n",
|
||
" - Measures raw kernel performance in microseconds\n",
|
||
" - Lower is better\n",
|
||
" - Affected by GPU clock speed, memory bandwidth, and kernel efficiency\n",
|
||
"* **Memory Throughput**\n",
|
||
" - Measures how fast we can copy data (GB/s)\n",
|
||
" - Higher is better\n",
|
||
" - Theoretical peak varies by GPU model\n",
|
||
" - For elementwise add:\n",
|
||
" * Read: 2 elements (A and B)\n",
|
||
" * Write: 1 element (C)\n",
|
||
" * Total bytes = (2 reads + 1 write) × elements × sizeof(dtype)\n",
|
||
"\n",
|
||
"Below is our benchmarking utility that measures these metrics:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def benchmark(callable, a_, b_, c_):\n",
|
||
" avg_time_us = cute.testing.benchmark(\n",
|
||
" callable,\n",
|
||
" kernel_arguments=cute.testing.JitArguments(a_, b_, c_),\n",
|
||
" warmup_iterations=5,\n",
|
||
" iterations=100,\n",
|
||
" )\n",
|
||
"\n",
|
||
" # Calculate metrics\n",
|
||
" # ----------------\n",
|
||
" dtype = a_.element_type\n",
|
||
"\n",
|
||
" # Calculate total bytes transferred:\n",
|
||
" # - 2 reads (A and B) + 1 write (C)\n",
|
||
" # - Each element is dtype.width bits\n",
|
||
" bytes_per_element = dtype.width // 8\n",
|
||
" total_bytes = num_elements * bytes_per_element\n",
|
||
"\n",
|
||
" # Calculate achieved bandwidth\n",
|
||
" achieved_bandwidth = total_bytes / (avg_time_us * 1000) # GB/s\n",
|
||
"\n",
|
||
" # Print results\n",
|
||
" # ------------\n",
|
||
" print(f\"Performance Metrics:\")\n",
|
||
" print(f\"-------------------\")\n",
|
||
" print(f\"Kernel execution time: {avg_time_us:.4f} us\")\n",
|
||
" print(f\"Memory throughput: {achieved_bandwidth:.2f} GB/s\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"benchmark(naive_elementwise_add_, a_, b_, c_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Theoretical Analysis\n",
|
||
"\n",
|
||
"This section analyze the performance characteristics and optimization opportunities of our elementwise addition kernel through several theoretical frameworks.\n",
|
||
"\n",
|
||
"#### Little's Law\n",
|
||
"\n",
|
||
"Little's Law provides crucial insights into relationship between latency, bandwidth and inflight operations:\n",
|
||
"\n",
|
||
"$ L = \\lambda \\times W $\n",
|
||
"\n",
|
||
"Where:\n",
|
||
"- $L$: Number of in-flight memory operations needed\n",
|
||
"- $\\lambda$: Target memory bandwidth (bytes/cycle)\n",
|
||
"- $W$: Memory system latency (cycles)\n",
|
||
"\n",
|
||
"According to *Little's Law*, naive implementation has\n",
|
||
" - 1 element (4 bytes load + 2 bytes store) per thread\n",
|
||
" - 256 threads/block × N blocks\n",
|
||
" - Limited in-flight operations\n",
|
||
"\n",
|
||
"In some GPUs, it's insufficient parallelism to saturate memory bandwidth.\n",
|
||
"\n",
|
||
"#### Optimization Strategies\n",
|
||
"\n",
|
||
"Based on this analysis, one commonly used technique is **Vectorization**. Instead of 1 element \n",
|
||
"per load per thread, vectorization allows multiple element per load\n",
|
||
" - Reduces instruction count\n",
|
||
" - Improves memory coalescing\n",
|
||
" - Increases operations in flight"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Vectorized Load and Store\n",
|
||
"\n",
|
||
"To improve performance according to Little's Law, we need to increase the number\n",
|
||
"of in-flight requests. We can do this by increasing the number of bytes handled\n",
|
||
"in each load & store operation per thread through vectorized memory access.\n",
|
||
"\n",
|
||
"Since Ampere GPUs support up to 128-bit per load/store and each element is 32-bit,\n",
|
||
"we can load 4 elements per vectorized operation on contiguous rows.\n",
|
||
"CuTe tiling operations make this vectorization straightforward.\n",
|
||
"\n",
|
||
"Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
|
||
"``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
|
||
"as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).\n",
|
||
"Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
|
||
"\n",
|
||
"```python\n",
|
||
"mA : cute.Tensor # (2048,2048):(2048,1)\n",
|
||
"gA = cute.zipped_divide(a, tiler=(1, 4)) # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))\n",
|
||
"```\n",
|
||
"\n",
|
||
"$\n",
|
||
" \\begin{array}{ccccc}\n",
|
||
" & ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\\\\n",
|
||
" & \\underbrace{\\phantom{(1,4)}}_{tiler} & & \\underbrace{\\phantom{(2048,512)}}_{threads} & \\\\\n",
|
||
" & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
|
||
" \\end{array}\n",
|
||
"$"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.kernel\n",
|
||
"def vectorized_elementwise_add_kernel(\n",
|
||
" gA: cute.Tensor,\n",
|
||
" gB: cute.Tensor,\n",
|
||
" gC: cute.Tensor,\n",
|
||
"):\n",
|
||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||
" bdim, _, _ = cute.arch.block_dim()\n",
|
||
"\n",
|
||
" thread_idx = bidx * bdim + tidx\n",
|
||
"\n",
|
||
" # Map thread index to logical index of input tensor in unit of vector\n",
|
||
" m, n = gA.shape[1] # thread-domain\n",
|
||
" ni = thread_idx % n\n",
|
||
" mi = thread_idx // n\n",
|
||
"\n",
|
||
" # Map logical index to physical address via tensor layout\n",
|
||
" a_val = gA[(None, (mi, ni))].load()\n",
|
||
" b_val = gB[(None, (mi, ni))].load()\n",
|
||
" print(f\"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}\")\n",
|
||
" print(f\"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}\")\n",
|
||
"\n",
|
||
" # Perform element-wise addition\n",
|
||
" gC[(None, (mi, ni))] = a_val + b_val"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
|
||
"with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
|
||
"we can extract a `(1,4)` sub-tensor from `gA`, `gB` and `gC` like \n",
|
||
"\n",
|
||
"$ gA[(None, (mi, ni))]: $\n",
|
||
"\n",
|
||
"$\n",
|
||
" \\begin{array}{ccccc}\n",
|
||
" Layout: & ( & (1,4) & , & (2048,512) & ) & : & ((0,1),(2048,4)) & \\xrightarrow{\\text{slice}} & ((1,4)):((0,1)) \\\\\n",
|
||
" & & \\underbrace{\\phantom{(1,4)}} & & \\underbrace{\\phantom{(2048,512)}} & & \\\\\n",
|
||
" Coord: & ( & None & , & (mi, ni) & ) & &\n",
|
||
" \\end{array}\n",
|
||
"$\n",
|
||
"\n",
|
||
"Then tensor data can be loaded into vector via the `gA[(None, (mi, ni))].load()` method. It is equivalent to\n",
|
||
"\n",
|
||
"```python\n",
|
||
"v0 = gA[(0, (mi, ni))] # => mA[(mi, ni * 4 + 0)]\n",
|
||
"v1 = gA[(1, (mi, ni))] # => mA[(mi, ni * 4 + 1)]\n",
|
||
"v2 = gA[(2, (mi, ni))] # => mA[(mi, ni * 4 + 2)]\n",
|
||
"v3 = gA[(3, (mi, ni))] # => mA[(mi, ni * 4 + 3)]\n",
|
||
"```\n",
|
||
"\n",
|
||
"### Assumed Alignment\n",
|
||
"\n",
|
||
"In order to guide compile to use vectorized load/store, we must tell compiler to assume alignment of incoming pointer. \n",
|
||
"It's on users side to guarantee actual pointer at runtime meet the alignment restriction.\n",
|
||
"\n",
|
||
"```python\n",
|
||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||
"\n",
|
||
"# Compile kernel with alignment assumption\n",
|
||
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
|
||
"```\n",
|
||
"\n",
|
||
"It's worth to note that partitioned or tiled tensor could have different alignment of its base pointer because of offset\n",
|
||
"during sub-slice."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.jit\n",
|
||
"def vectorized_elementwise_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):\n",
|
||
" threads_per_block = 256\n",
|
||
"\n",
|
||
" gA = cute.zipped_divide(mA, (1, 4))\n",
|
||
" gB = cute.zipped_divide(mB, (1, 4))\n",
|
||
" gC = cute.zipped_divide(mC, (1, 4))\n",
|
||
"\n",
|
||
" print(\"[DSL INFO] Tiled Tensors:\")\n",
|
||
" print(f\"[DSL INFO] gA = {gA}\")\n",
|
||
" print(f\"[DSL INFO] gB = {gB}\")\n",
|
||
" print(f\"[DSL INFO] gC = {gC}\")\n",
|
||
"\n",
|
||
" vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
|
||
" grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n",
|
||
" block=(threads_per_block, 1, 1),\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"\n",
|
||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||
"\n",
|
||
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
|
||
"compiled_func(a_, b_, c_)\n",
|
||
"\n",
|
||
"# verify correctness\n",
|
||
"torch.testing.assert_close(c, a + b)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"benchmark(compiled_func, a_, b_, c_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## TV Layout\n",
|
||
"\n",
|
||
"Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
|
||
"to physical addresses in two steps:\n",
|
||
"\n",
|
||
"Step 1: Map thread index to logical coordinates in `(M, N)`\n",
|
||
"\n",
|
||
"* `mi = thread_idx // n`\n",
|
||
"* `ni = thread_idx % n`\n",
|
||
"\n",
|
||
"In native version, each thread process 1 element, in this case, `mi` and `ni` is logical\n",
|
||
"coordinate into data tensor `mA`, `mB` and `mC`.\n",
|
||
"\n",
|
||
"Int vectorized version, each thread process multiple values of input and output tensor.\n",
|
||
"logical coordinate should be computed with both thread and value index.\n",
|
||
"\n",
|
||
"* `thread_idx // n`\n",
|
||
"* `(thread_idx % n) * 4 + value_idx`\n",
|
||
"\n",
|
||
"\n",
|
||
"Step 2: Map logical coordinates in `(M, N)` to physical addresses using the tensor layout\n",
|
||
"\n",
|
||
"* Vectorized Load\n",
|
||
"\n",
|
||
"```python\n",
|
||
" frgA = gA[(None, (mi, ni))].load()\n",
|
||
"```\n",
|
||
"\n",
|
||
"* Elementwise Load (less efficient)\n",
|
||
"\n",
|
||
"```python\n",
|
||
" frgA0 = mA[(mi, ni * 4 + 0)]\n",
|
||
" frgA1 = mA[(mi, ni * 4 + 1)]\n",
|
||
" frgA2 = mA[(mi, ni * 4 + 2)]\n",
|
||
" frgA3 = mA[(mi, ni * 4 + 3)]\n",
|
||
"\n",
|
||
" # Or use divided layout\n",
|
||
"\n",
|
||
" frgA0 = gA[(0, (mi, ni))]\n",
|
||
" frgA1 = gA[(1, (mi, ni))]\n",
|
||
" frgA2 = gA[(2, (mi, ni))]\n",
|
||
" frgA3 = gA[(3, (mi, ni))]\n",
|
||
"```\n",
|
||
"\n",
|
||
"CuTe introduces TV layout to represent this mapping from thread index and value index\n",
|
||
"(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n",
|
||
"By configuring different TV layouts, we can experiment with different memory access\n",
|
||
"patterns with minimal code changes.\n",
|
||
"\n",
|
||
"**Definition:** *TV Layout* is rank-2 layout which maps `(thread_index, value_index)` \n",
|
||
"to logical coordinate of tensor. \n",
|
||
"\n",
|
||
"We always have *TV Layout* with canonical form as `(thread_domain, value_domain):(..., ...)`.\n",
|
||
"\n",
|
||
"With *TV Layout*, each thread can find logical coordinates or indices of data partitioned\n",
|
||
"to current thread.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Elementwise with TV Layout\n",
|
||
"\n",
|
||
"In this example, we rewrite elementwise kernel with two levels of tiling: \n",
|
||
"* the thread-block level \n",
|
||
"* the thread level with TV Layout and tiling\n",
|
||
"\n",
|
||
"For thread-block level tiling, each input & output tensor is first divided\n",
|
||
"into a group of ``(TileM, TileN)`` sub-tensors at the host side. Please be noticed that\n",
|
||
"in this case, we still use `zipped_divide` but for tiling at thread-block level.\n",
|
||
"\n",
|
||
"Inside the GPU kernel, we slice tiled tensor with the thread-block index at the 2nd mode \n",
|
||
"as ``gA[((None, None), bidx)]``, which returns a thread-block local view of\n",
|
||
"a single ``(TileM, TileN)`` sub-tensor. This sub-tensor maps logical coordinates\n",
|
||
"inside ``(TileM, TileN)`` to physical address of elements.\n",
|
||
"\n",
|
||
"At thread level tiling, we compose the above sub-tensor (logical coordinates to physical addresses) \n",
|
||
"with the TV layout (thread & value indices to logical coordinates). This gives us a tiled sub-tensor \n",
|
||
"that maps from thread & value indices directly to physical addresses.\n",
|
||
"\n",
|
||
"We then slice it with the thread index as ``tidfrgA[(tidx, None)]`` to get \n",
|
||
"a thread-local view of the data each thread accesses. Note that the thread index\n",
|
||
"is now in the 1st mode, as TV layout is normally have form ``(thread_domain, value_domain):(...)``.\n",
|
||
"\n",
|
||
"### Kernel Code"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.kernel\n",
|
||
"def elementwise_add_kernel(\n",
|
||
" gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor, tv_layout: cute.Layout\n",
|
||
"):\n",
|
||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||
"\n",
|
||
" # --------------------------------\n",
|
||
" # slice for thread-block level view\n",
|
||
" # --------------------------------\n",
|
||
" blk_coord = ((None, None), bidx)\n",
|
||
"\n",
|
||
" # logical coord -> address\n",
|
||
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
|
||
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
|
||
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
|
||
"\n",
|
||
" # --------------------------------\n",
|
||
" # compose for thread-index & value-index to physical mapping\n",
|
||
" # --------------------------------\n",
|
||
" # blockA: (TileM, TileN) -> physical address\n",
|
||
" # tv_layout: (tid, vid) -> (TileM, TileN)\n",
|
||
" # tidfrgA = blkA o tv_layout\n",
|
||
" # tidfrgA: (tid, vid) -> physical address\n",
|
||
" tidfrgA = cute.composition(blkA, tv_layout)\n",
|
||
" tidfrgB = cute.composition(blkB, tv_layout)\n",
|
||
" tidfrgC = cute.composition(blkC, tv_layout)\n",
|
||
"\n",
|
||
" print(\"Composed with TV layout:\")\n",
|
||
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
|
||
"\n",
|
||
" # --------------------------------\n",
|
||
" # slice for thread-level view\n",
|
||
" # --------------------------------\n",
|
||
" # `None` represent slice of the entire per-thread data\n",
|
||
" thr_coord = (tidx, None)\n",
|
||
" # thr_coord = (tidx, cute.repeat_like(None, gA.shape[1]))\n",
|
||
"\n",
|
||
" # slice for threads: vid -> address\n",
|
||
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
|
||
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
|
||
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
|
||
"\n",
|
||
" thrC[None] = thrA.load() + thrB.load()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Host Code\n",
|
||
"\n",
|
||
"The host code below shows the construction of the TV layout. By composing\n",
|
||
"a thread layout of ``(4,64):(64,1)`` (64 threads read contiguous elements on the row dimension,\n",
|
||
"then 64-thread-groups(2 warps) read different rows) with a value layout of ``(16,8):(8,1)`` (each thread reads\n",
|
||
"8 contiguous 16b elements on the row dimension across 4 contiguous rows).\n",
|
||
"\n",
|
||
"In order to generalize, we started with byte-layout to describe layout for elements in bytes. This is\n",
|
||
"to ensure use of 128bit vectorized load store. Then we leverage ``recast_layout`` to convert into\n",
|
||
"element-layout.\n",
|
||
"\n",
|
||
"```python\n",
|
||
" # src type bits: 8\n",
|
||
" # dst type bits: bits of element type\n",
|
||
" val_layout = cute.recast_layout(dtype.width, 8, bit_val_layout)\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.jit\n",
|
||
"def elementwise_add(\n",
|
||
" mA: cute.Tensor,\n",
|
||
" mB: cute.Tensor,\n",
|
||
" mC: cute.Tensor,\n",
|
||
"):\n",
|
||
" # mA layout: (M, N):(N, 1)\n",
|
||
" # TV layout map thread & value index to (64, 512) logical tile\n",
|
||
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
|
||
" # mode-1 for coalesced load-store\n",
|
||
" # - each thread load contiguous 16 bytes each row and load 16 rows\n",
|
||
" coalesced_ldst_bytes = 16\n",
|
||
"\n",
|
||
" # Compile time validation: expect same element type for all input tensors\n",
|
||
" assert all(t.element_type == mA.element_type for t in [mA, mB, mC])\n",
|
||
" dtype = mA.element_type\n",
|
||
"\n",
|
||
" thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))\n",
|
||
" val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))\n",
|
||
" val_layout = cute.recast_layout(dtype.width, 8, val_layout)\n",
|
||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||
"\n",
|
||
" print(f\"[DSL INFO] Tiler: {tiler_mn}\")\n",
|
||
" print(f\"[DSL INFO] TV Layout: {tv_layout}\")\n",
|
||
"\n",
|
||
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
"\n",
|
||
" print(\"Tiled Input Tensors:\")\n",
|
||
" print(\"[DSL INFO] Tiled Tensors:\")\n",
|
||
" print(f\"[DSL INFO] gA = {gA.type}\")\n",
|
||
" print(f\"[DSL INFO] gB = {gB.type}\")\n",
|
||
" print(f\"[DSL INFO] gC = {gC.type}\")\n",
|
||
"\n",
|
||
" # Launch the kernel asynchronously\n",
|
||
" # Async token(s) can also be specified as dependencies\n",
|
||
" elementwise_add_kernel(gA, gB, gC, tv_layout).launch(\n",
|
||
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
|
||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"\n",
|
||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||
"\n",
|
||
"elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n",
|
||
"elementwise_add_(a_, b_, c_)\n",
|
||
"\n",
|
||
"# verify correctness\n",
|
||
"torch.testing.assert_close(c, a + b)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Explanation of Layouts\n",
|
||
"\n",
|
||
"Let's take a closer look using zipped divided input tensor `gA` as an example.\n",
|
||
"We also choose a smaller M/N, `(256,512)`, to make it easier to explain and visualize.\n",
|
||
"\n",
|
||
"```\n",
|
||
"Tiled to Thread Block:\n",
|
||
"\n",
|
||
" ((16,256),(16,2)) : ((512,1),(8192,256))\n",
|
||
" ~~~~~~~~ ~~~~~~ ~~~~~\n",
|
||
" | | |\n",
|
||
" | | |\n",
|
||
" | `-----------------------> Number of Thread Blocks\n",
|
||
" | |\n",
|
||
" | |\n",
|
||
" `-------------------'\n",
|
||
" |\n",
|
||
" V\n",
|
||
" Thread Block\n",
|
||
" Tile\n",
|
||
"\n",
|
||
"Sliced to Thread-Block local sub-tensor (a (16, 256) tile): gA[((None, None), bidx)]\n",
|
||
"\n",
|
||
" (16,256) : (512,1)\n",
|
||
" ~~~~~~ ~~~~~~\n",
|
||
" | | Tiled/Composed with TV Layout\n",
|
||
" | |\n",
|
||
" | | o ((32,4),(8,4)):((128,4),(16,1))\n",
|
||
" V V\n",
|
||
"~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~\n",
|
||
"((32,4),(8,4)) : ((8,2048),(1,512))\n",
|
||
" | |\n",
|
||
" | `--------> per thread fragment\n",
|
||
" |\n",
|
||
"Thread Block\n",
|
||
" Shape\n",
|
||
"\n",
|
||
"Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Visualization of TV Layout\n",
|
||
"\n",
|
||
"To visualize TV Layout, we can first install *`cute-viz`*\n",
|
||
"\n",
|
||
"```\n",
|
||
"pip install -U git+https://github.com/NTT123/cute-viz.git\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"try:\n",
|
||
" from cute_viz import display_tv_layout\n",
|
||
"\n",
|
||
" @cute.jit\n",
|
||
" def visualize():\n",
|
||
" # Create and render a layout to file\n",
|
||
" # layout = cute.make_layout( ((16,16),(256,2)), stride=((512,8192),(1,256)))\n",
|
||
" # display_layout(layout)\n",
|
||
"\n",
|
||
" tv_layout = cute.make_layout(((32, 4), (8, 4)), stride=((128, 4), (16, 1)))\n",
|
||
" display_tv_layout(tv_layout, (16, 256))\n",
|
||
"\n",
|
||
" thr_block_layout = cute.make_layout((16, 256), stride=(512, 1))\n",
|
||
" print(cute.composition(thr_block_layout, tv_layout))\n",
|
||
"\n",
|
||
" visualize()\n",
|
||
"except ImportError:\n",
|
||
" pass"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"***Why modes of thread domain of TV Layout looks swapped especially when tensor is row major?***\n",
|
||
"\n",
|
||
"We may notice that *TV Layout* in above example is `((32,4),(8,4)):((128,4),(16,1))`. \n",
|
||
"However, on visualization, thread indices are arrange as shape `(4,32)` rather than \n",
|
||
"`(32,4)` of *TV Layout*.\n",
|
||
"\n",
|
||
"This is a commonly asked question by developers from both internal teams and community.\n",
|
||
"\n",
|
||
"It's important to keep in mind that *TV Layout* maps `(thread_index, value_index)` to \n",
|
||
"`(row_index, column_index)` of logical domain `(TileM, TileN)`. However, visualization \n",
|
||
"shows **inverse** mapping of logical domain `(TileM, TileN)` to `(thread_domain, value_domain)`,\n",
|
||
"because this is more intuitive for human developer.\n",
|
||
"\n",
|
||
"That's why the shape of domain of *TV Layout* doesn't necessarily match logical view."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"benchmark(elementwise_add_, a_, b_, c_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Remap/Transpose thread block index\n",
|
||
"\n",
|
||
"As tensors are row major in this example, we may want thread blocks to load contiguous memory as much as possible.\n",
|
||
"\n",
|
||
"We can apply a simple thread-block remapping to transpose the mapping of thread block indices in row first order. \n",
|
||
"`cute.composition(gA, (None, remap_block))` only apply transpose of 2nd mode of tiled layout but keep \n",
|
||
"the 1st mode un-touched.\n",
|
||
"\n",
|
||
"```python\n",
|
||
" remap_block = cute.make_ordered_layout(\n",
|
||
" cute.select(gA.shape[1], mode=[1, 0]), order=(1, 0)\n",
|
||
" )\n",
|
||
" gA = cute.composition(gA, (None, remap_block))\n",
|
||
" gB = cute.composition(gB, (None, remap_block))\n",
|
||
" gC = cute.composition(gC, (None, remap_block))\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.jit\n",
|
||
"def elementwise_add(\n",
|
||
" mA: cute.Tensor,\n",
|
||
" mB: cute.Tensor,\n",
|
||
" mC: cute.Tensor,\n",
|
||
"):\n",
|
||
" # mA layout: (M, N):(N, 1)\n",
|
||
" # TV layout map thread & value index to (64, 512) logical tile\n",
|
||
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
|
||
" # mode-1 for coalesced load-store\n",
|
||
" # - each thread load contiguous 16 bytes each row and load 16 rows\n",
|
||
" coalesced_ldst_bytes = 16\n",
|
||
"\n",
|
||
" # Compile time validation: expect same element type for all input tensors\n",
|
||
" assert all(t.element_type == mA.element_type for t in [mA, mB, mC])\n",
|
||
" dtype = mA.element_type\n",
|
||
"\n",
|
||
" thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))\n",
|
||
" val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))\n",
|
||
" val_layout = cute.recast_layout(dtype.width, 8, val_layout)\n",
|
||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||
"\n",
|
||
" print(f\"[DSL INFO] Tiler: {tiler_mn}\")\n",
|
||
" print(f\"[DSL INFO] TV Layout: {tv_layout}\")\n",
|
||
"\n",
|
||
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
"\n",
|
||
" # (RestM, RestN) -> (RestN, RestM)\n",
|
||
" remap_block = cute.make_ordered_layout(\n",
|
||
" cute.select(gA.shape[1], mode=[1, 0]), order=(1, 0)\n",
|
||
" )\n",
|
||
" gA = cute.composition(gA, (None, remap_block))\n",
|
||
" gB = cute.composition(gB, (None, remap_block))\n",
|
||
" gC = cute.composition(gC, (None, remap_block))\n",
|
||
"\n",
|
||
" print(\"Tiled Input Tensors:\")\n",
|
||
" print(\"[DSL INFO] Tiled Tensors:\")\n",
|
||
" print(f\"[DSL INFO] gA = {gA.type}\")\n",
|
||
" print(f\"[DSL INFO] gB = {gB.type}\")\n",
|
||
" print(f\"[DSL INFO] gC = {gC.type}\")\n",
|
||
"\n",
|
||
" # Launch the kernel asynchronously\n",
|
||
" # Async token(s) can also be specified as dependencies\n",
|
||
" elementwise_add_kernel(gA, gB, gC, tv_layout).launch(\n",
|
||
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
|
||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"\n",
|
||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||
"\n",
|
||
"elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n",
|
||
"elementwise_add_(a_, b_, c_)\n",
|
||
"\n",
|
||
"# verify correctness\n",
|
||
"torch.testing.assert_close(c, a + b)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"benchmark(compiled_func, a_, b_, c_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Using Lambda Function\n",
|
||
"\n",
|
||
"CuTe DSL is built on top of Python. It can leverage Python to implement meta-programming to generate flexible kernels.\n",
|
||
"E.g. we can write kernel template that take custom binary operations to generate kernels for arbitrary binary operations.\n",
|
||
"\n",
|
||
"\n",
|
||
"```python\n",
|
||
"@cute.jit\n",
|
||
"def elementwise_apply(\n",
|
||
" op: cutlass.Constexpr,\n",
|
||
" inputs,\n",
|
||
" result: cute.Tensor\n",
|
||
"):\n",
|
||
" ...\n",
|
||
"\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@cute.kernel\n",
|
||
"def elementwise_apply_kernel(\n",
|
||
" op: cutlass.Constexpr,\n",
|
||
" mInputs: List[cute.Tensor],\n",
|
||
" mC: cute.Tensor,\n",
|
||
" cC: cute.Tensor, # coordinate tensor\n",
|
||
" shape: cute.Shape,\n",
|
||
" tv_layout: cute.Layout, # (tid, vid) -> logic coord\n",
|
||
"):\n",
|
||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||
"\n",
|
||
" ###############################################################################\n",
|
||
" # Slice to local tile of thread block\n",
|
||
" ###############################################################################\n",
|
||
" blk_crd = ((None, None), bidx)\n",
|
||
"\n",
|
||
" # Leverage the meta-programming capability of the DSL to slice the tensors for each input\n",
|
||
" # All for loops below on input tensors would be fully unrolled automatically at compile time\n",
|
||
" # logical coord -> memory address\n",
|
||
" gInputs = [t[blk_crd] for t in mInputs] # (TileM, TileN)\n",
|
||
" gC = mC[blk_crd] # (TileM, TileN)\n",
|
||
" gCrd = cC[blk_crd] # (TileM, TileN)\n",
|
||
"\n",
|
||
" print(\"[DSL INFO] Sliced Tensors per thread block:\")\n",
|
||
" for i in cutlass.range_constexpr(len(gInputs)):\n",
|
||
" print(f\"[DSL INFO] ctaInputs{i} = {gInputs[i].type}\")\n",
|
||
" print(f\"[DSL INFO] gC = {gC.type}\")\n",
|
||
" print(f\"[DSL INFO] gCrd = {gCrd.type}\")\n",
|
||
"\n",
|
||
" ###############################################################################\n",
|
||
" # Compose with thread block TV layout to map thread & value indices to memory address\n",
|
||
" ###############################################################################\n",
|
||
" # (tid, vid) -> memory address\n",
|
||
" tidfrgInputs = [cute.composition(t, tv_layout) for t in gInputs]\n",
|
||
" tidfrgC = cute.composition(gC, tv_layout)\n",
|
||
" tidfrgCrd = cute.composition(gCrd, tv_layout)\n",
|
||
"\n",
|
||
" # repeat None like vid to remove hierarchy of layout\n",
|
||
" thr_crd = (tidx, cute.repeat_like(None, tidfrgInputs[0][1]))\n",
|
||
"\n",
|
||
" ###############################################################################\n",
|
||
" # Slice to local tile of thread\n",
|
||
" ###############################################################################\n",
|
||
" # vid -> address\n",
|
||
" thrInputs = [t[thr_crd] for t in tidfrgInputs] # (V)\n",
|
||
" thrC = tidfrgC[thr_crd] # (V)\n",
|
||
" thrCrd = tidfrgCrd[thr_crd]\n",
|
||
"\n",
|
||
" print(\"[DSL INFO] Sliced Tensors per thread:\")\n",
|
||
" for i in cutlass.range_constexpr(len(thrInputs)):\n",
|
||
" print(f\"[DSL INFO] thrInputs{i} = {thrInputs[i].type}\")\n",
|
||
" print(f\"[DSL INFO] thrC = {thrC.type}\")\n",
|
||
" print(f\"[DSL INFO] thrCrd = {thrCrd.type}\")\n",
|
||
"\n",
|
||
" ###############################################################################\n",
|
||
" # Compute predicate for out of boundary checks\n",
|
||
" ###############################################################################\n",
|
||
" frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)\n",
|
||
" print(f\"[DSL INFO] frgPred = {frgPred.type}\")\n",
|
||
"\n",
|
||
" for i in cutlass.range_constexpr(cute.size(frgPred)):\n",
|
||
" frgPred[i] = cute.elem_less(thrCrd[i], shape)\n",
|
||
"\n",
|
||
" # if tidx == 0 and bidx == 0:\n",
|
||
" # cute.print_tensor(frgPred)\n",
|
||
"\n",
|
||
" ##########################################################\n",
|
||
" # Load data and compute result\n",
|
||
" ##########################################################\n",
|
||
"\n",
|
||
" # Load data before use. The compiler will optimize the copy and load\n",
|
||
" # operations to convert some memory ld/st into register uses.\n",
|
||
" result = op(*[thrInput.load() for thrInput in thrInputs])\n",
|
||
" thrC.store(result)\n",
|
||
"\n",
|
||
"\n",
|
||
"@cute.jit\n",
|
||
"def elementwise_apply(op: cutlass.Constexpr, inputs, result: cute.Tensor):\n",
|
||
" # Use 128bit(16B) load as canonicalized form of val_layout then recast to target element-type\n",
|
||
" coalesced_ldst_bytes = 16\n",
|
||
"\n",
|
||
" # Compile time validation: expect same element type for all input tensors\n",
|
||
" assert all(t.element_type == inputs[0].element_type for t in inputs)\n",
|
||
" dtype = inputs[0].element_type\n",
|
||
"\n",
|
||
" thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))\n",
|
||
" val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))\n",
|
||
" val_layout = cute.recast_layout(dtype.width, 8, val_layout)\n",
|
||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||
"\n",
|
||
" mInputs = [cute.zipped_divide(input, tiler_mn) for input in inputs]\n",
|
||
" mC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||
"\n",
|
||
" # (RestM, RestN) -> (RestN, RestM)\n",
|
||
" remap_block = cute.make_ordered_layout(\n",
|
||
" cute.select(mInputs[0].shape[1], mode=[1, 0]), order=(1, 0)\n",
|
||
" )\n",
|
||
" for i, t in enumerate(mInputs):\n",
|
||
" mInputs[i] = cute.composition(t, (None, remap_block))\n",
|
||
"\n",
|
||
" mC = cute.composition(mC, (None, remap_block))\n",
|
||
"\n",
|
||
" idC = cute.make_identity_tensor(result.shape)\n",
|
||
" cC = cute.zipped_divide(idC, tiler=tiler_mn)\n",
|
||
"\n",
|
||
" # Launch the kernel asynchronously\n",
|
||
" # Group input tensors into a list as a single argument\n",
|
||
" elementwise_apply_kernel(op, mInputs, mC, cC, result.shape, tv_layout).launch(\n",
|
||
" grid=[cute.size(mC, mode=[1]), 1, 1],\n",
|
||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||
"\n",
|
||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||
"c_ = from_dlpack(c, assumed_align=16)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from operator import mul\n",
|
||
"\n",
|
||
"elementwise_apply(mul, [a_, b_], c_)\n",
|
||
"\n",
|
||
"# verify correctness\n",
|
||
"torch.testing.assert_close(c, mul(a, b))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Use customized function\n",
|
||
"\n",
|
||
"Custom operators can be more complex. For example, here's a function that performs\n",
|
||
"multiplication followed by ReLU:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def mul_relu(a, b):\n",
|
||
" tmp = a * b\n",
|
||
" return cute.where(tmp > 0, tmp, cute.full_like(tmp, 0))\n",
|
||
"\n",
|
||
"\n",
|
||
"# As we uses cute.where in customized operation, we need to create another relu function\n",
|
||
"def mul_relu_ref(a, b):\n",
|
||
" tmp = a * b\n",
|
||
" return torch.relu(tmp)\n",
|
||
"\n",
|
||
"\n",
|
||
"elementwise_apply(mul_relu, [a_, b_], c_)\n",
|
||
"\n",
|
||
"# verify correctness\n",
|
||
"torch.testing.assert_close(c, mul_relu_ref(a, b))"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv3_12",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.10"
|
||
},
|
||
"widgets": {
|
||
"application/vnd.jupyter.widget-state+json": {
|
||
"state": {},
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|