{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"from functools import partial\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"from cutlass.cute.runtime import from_dlpack"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial: Elementwise Add Kernel in CuTe DSL\n",
"\n",
"This tutorial demonstrates how to implement a simple elementwise\n",
"addition kernel using the CuTe DSL (Domain Specific Language).\n",
"\n",
"\n",
"\n",
"Elementwise Addition\n",
"---------------------\n",
"\n",
"Elementwise addition is a fundamental operation in linear algebra.\n",
"Given two tensors of the same shape, the operation performs element-wise\n",
"addition to produce a result tensor of the same shape.\n",
"\n",
"For two 2D tensors :math:`A` and :math:`B` of shape :math:`(M, N)`,\n",
"the elementwise addition operation :math:`C = A + B` is defined as:\n",
"\n",
"$\n",
" C_{i,j} = A_{i,j} + B_{i,j}\n",
"$\n",
"\n",
"where:\n",
"\n",
"- $i \\in [0, M-1]$ represents the row index\n",
"- $j \\in [0, N-1]$ represents the column index\n",
"- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ \n",
" in tensors $A$, $B$, and $C$ respectively\n",
"\n",
"This operation is performed independently for each element position,\n",
"making it highly parallelizable and well-suited for GPU implementation.\n",
"\n",
"Naive Elementwise Add Kernel\n",
"-----------------------------\n",
"\n",
"Let's start with a naive implementation that loads each element from\n",
"$A$ and $B$, adds them, and stores the result back to $C$."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def naive_elementwise_add_kernel(\n",
" gA: cute.Tensor,\n",
" gB: cute.Tensor,\n",
" gC: cute.Tensor,\n",
"):\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
" bidx, _, _ = cute.arch.block_idx()\n",
" bdim, _, _ = cute.arch.block_dim()\n",
"\n",
" thread_idx = bidx * bdim + tidx\n",
"\n",
" # Map thread index to logical index of input tensor\n",
" m, n = gA.shape\n",
" ni = thread_idx % n\n",
" mi = thread_idx // n\n",
"\n",
" # Map logical index to physical address via tensor layout\n",
" a_val = gA[mi, ni]\n",
" b_val = gB[mi, ni]\n",
"\n",
" # Perform element-wise addition\n",
" gC[mi, ni] = a_val + b_val"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structure of the Kernel\n",
"\n",
"The naive kernel simply maps each thread to one element with a 1-to-1 mapping.\n",
"In this kernel, we don't use CuTe layout algebra but only use basic\n",
"addressing to index the tensor.\n",
"\n",
"We can launch the kernel with the following JIT function:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"@cute.jit\n",
"def naive_elementwise_add(\n",
" mA: cute.Tensor,\n",
" mB: cute.Tensor,\n",
" mC: cute.Tensor\n",
"):\n",
" num_threads_per_block = 256\n",
"\n",
" m, n = mA.shape\n",
" kernel = naive_elementwise_add_kernel(mA, mB, mC)\n",
" kernel.launch(grid=((m * n) // num_threads_per_block, 1, 1),\n",
" block=(num_threads_per_block, 1, 1))\n",
"\n",
"M, N = 2048, 2048\n",
"\n",
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"a_ = from_dlpack(a, assumed_align=16)\n",
"b_ = from_dlpack(b, assumed_align=16)\n",
"c_ = from_dlpack(c, assumed_align=16)\n",
"\n",
"# Compile kernel\n",
"naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n",
"naive_elementwise_add_(a_, b_, c_)\n",
"\n",
"# verify correctness\n",
"torch.testing.assert_close(c, a + b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Benchmark performance\n",
"\n",
"Here's a utility function to benchmark our kernel implementations:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def benchmark(callable, *, num_warmups, num_iterations):\n",
" start_event = torch.cuda.Event(enable_timing=True)\n",
" end_event = torch.cuda.Event(enable_timing=True)\n",
"\n",
" torch.cuda.synchronize()\n",
"\n",
" for _ in range(num_warmups):\n",
" callable()\n",
"\n",
" start_event.record(stream=torch.cuda.current_stream())\n",
" for _ in range(num_iterations):\n",
" callable()\n",
" end_event.record(stream=torch.cuda.current_stream())\n",
" torch.cuda.synchronize()\n",
"\n",
" elapsed_time = start_event.elapsed_time(end_event)\n",
" avg_time = elapsed_time / num_iterations\n",
"\n",
" print(f\"Average execution time: {avg_time:.4f} ms\")\n",
" print(f\"Throughput: {(3 * a.numel() * 2) / (avg_time / 1000) / 1e9:.2f} GB/s\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average execution time: 0.0385 ms\n",
"Throughput: 653.44 GB/s\n"
]
}
],
"source": [
"benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Performance Analysis\n",
"\n",
"While our naive implementation maps thread indices to contiguous tensor\n",
"dimensions for coalesced memory access, it doesn't have enough\n",
"in-flight load & store operations to hide memory latency.\n",
"\n",
"According to Little's Law:\n",
"\n",
"$ L = \\lambda \\times W $\n",
"\n",
"Where:\n",
"- $L$ is the average number of items in a system\n",
"- $\\lambda$ is the average arrival rate of items (bandwidth)\n",
"- $W$ is the average time an item spends in the system (latency)\n",
"\n",
"For our elementwise addition kernel:\n",
"\n",
"1. $L$: The number of load & store operations in-flight\n",
"2. $\\lambda$ (Bandwidth): Data transfer rate between memory and compute units\n",
"3. $W$ (Latency): Round-trip delay of memory requests\n",
"\n",
"For memory-bound operations like elementwise addition, performance is\n",
"limited by the number of in-flight load & store operations.\n",
"\n",
"## Vectorized Load and Store\n",
"\n",
"To improve performance according to Little's Law, we need to increase the number\n",
"of in-flight requests. We can do this by increasing the number of bytes handled\n",
"in each load & store operation per thread through vectorized memory access.\n",
"\n",
"Since Ampere GPUs support up to 128-bit per load/store and each element is 32-bit,\n",
"we can load 4 elements per vectorized operation on contiguous rows.\n",
"CuTe tiling operations make this vectorization straightforward.\n",
"\n",
"Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
"``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
"as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).\n",
"Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
"\n",
"```python\n",
"mA : cute.Tensor # (2048,2048):(2048,1)\n",
"gA = cute.zipped_divide(a, tiler=(1, 4)) # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))\n",
"```\n",
"\n",
"$\n",
" \\begin{array}{ccccc}\n",
" & ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\\\\n",
" & \\underbrace{\\phantom{(1,4)}}_{tiler} & & \\underbrace{\\phantom{(2048,512)}}_{threads} & \\\\\n",
" & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
" \\end{array}\n",
"$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def vectorized_elementwise_add_kernel(\n",
" gA: cute.Tensor,\n",
" gB: cute.Tensor,\n",
" gC: cute.Tensor,\n",
"):\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
" bidx, _, _ = cute.arch.block_idx()\n",
" bdim, _, _ = cute.arch.block_dim()\n",
"\n",
" thread_idx = bidx * bdim + tidx\n",
"\n",
" # Map thread index to logical index of input tensor\n",
" m, n = gA.shape[1] # thread-domain\n",
" ni = thread_idx % n\n",
" mi = thread_idx // n\n",
"\n",
" # Map logical index to physical address via tensor layout\n",
" a_val = gA[(None, (mi, ni))].load()\n",
" b_val = gB[(None, (mi, ni))].load()\n",
" print(f\"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}\")\n",
" print(f\"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}\")\n",
"\n",
" # Perform element-wise addition\n",
" gC[(None, (mi, ni))] = a_val + b_val"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
"with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
"we can extract a `(1,4)` sub-tensor from `gA`, `gB` and `gC` like \n",
"\n",
"```python\n",
"gA[(None, (mi, ni))]\n",
"\n",
"```\n",
"\n",
"Then tensor data can be loaded into vector via the `.load()` method.\n",
"\n",
"\n",
"```\n",
" slice\n",
" ((1,4),(2048,512)):((0,1),(2048,4)) ==> ((1,4)):((0,1))\n",
" ^ ^ ^\n",
" | | |\n",
" (None, (mi, ni))\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[DSL INFO] Tiled Tensors:\n",
"[DSL INFO] gA = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
"[DSL INFO] gB = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
"[DSL INFO] gC = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
"[DSL INFO] sliced gA = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n",
"[DSL INFO] sliced gB = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n"
]
}
],
"source": [
"@cute.jit\n",
"def vectorized_elementwise_add(\n",
" mA: cute.Tensor,\n",
" mB: cute.Tensor,\n",
" mC: cute.Tensor\n",
"):\n",
" threads_per_block = 256\n",
"\n",
" gA = cute.zipped_divide(mA, (1, 4))\n",
" gB = cute.zipped_divide(mB, (1, 4))\n",
" gC = cute.zipped_divide(mC, (1, 4))\n",
"\n",
" print(f\"[DSL INFO] Tiled Tensors:\")\n",
" print(f\"[DSL INFO] gA = {gA}\")\n",
" print(f\"[DSL INFO] gB = {gB}\")\n",
" print(f\"[DSL INFO] gC = {gC}\")\n",
"\n",
" vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
" grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n",
" block=(threads_per_block, 1, 1),\n",
" )\n",
"\n",
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"a_ = from_dlpack(a, assumed_align=16)\n",
"b_ = from_dlpack(b, assumed_align=16)\n",
"c_ = from_dlpack(c, assumed_align=16)\n",
"\n",
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
"compiled_func(a_, b_, c_)\n",
"\n",
"# verify correctness\n",
"torch.testing.assert_close(c, a + b)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average execution time: 0.0202 ms\n",
"Throughput: 1244.98 GB/s\n"
]
}
],
"source": [
"benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TV Layout\n",
"\n",
"Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
"to physical addresses:\n",
"\n",
"Step 1: Map thread index to logical M/N coordinates\n",
"\n",
"```python\n",
" mi = thread_idx // n\n",
" ni = thread_idx % n\n",
"```\n",
"\n",
"Step 2: Map logical M/N coordinates to physical addresses using the tensor layout\n",
"\n",
"```python\n",
" a[(None, (mi, ni))].load()\n",
"```\n",
"\n",
"CuTe uses TV layout to represent this mapping from thread index and value index\n",
"(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n",
"By configuring different TV layouts, we can experiment with different memory access\n",
"patterns with minimal code changes.\n",
"\n",
"The following example demonstrates two levels of tiling: at the thread-block level\n",
"and at the thread level.\n",
"\n",
"For thread-block level tiling, each input & output tensor is first divided\n",
"into a group of ``(TileM, TileN)`` sub-tensors at the host side.\n",
"\n",
"Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor\n",
"(``gA[((None, None), bidx)]``), which returns a thread-block local view of\n",
"a single ``(TileM, TileN)`` sub-tensor.\n",
"\n",
"For thread level tiling, we compose the sub-tensor (which maps from logical coordinates\n",
"to physical addresses) with the TV layout (which maps from thread & value indices to\n",
"logical coordinates). This gives us a tiled sub-tensor that maps from thread & value\n",
"indices directly to physical addresses.\n",
"\n",
"We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)\n",
"to get a thread-local view of the data each thread accesses. Note that the thread index\n",
"is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def elementwise_add_kernel(\n",
" gA: cute.Tensor,\n",
" gB: cute.Tensor,\n",
" gC: cute.Tensor,\n",
" tv_layout: cute.Layout\n",
"):\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
" bidx, _, _ = cute.arch.block_idx()\n",
"\n",
" #--------------------------------\n",
" # slice for thread-block level view\n",
" #--------------------------------\n",
" blk_coord = ((None, None), bidx)\n",
"\n",
" # logical coord -> address\n",
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
"\n",
" #--------------------------------\n",
" # compose for thread-index & value-index to physical mapping\n",
" #--------------------------------\n",
" # blockA: (TileM, TileN) -> physical address\n",
" # tv_layout: (tid, vid) -> (TileM, TileN)\n",
" # tidfrgA = blkA o tv_layout\n",
" # tidfrgA: (tid, vid) -> physical address\n",
" tidfrgA = cute.composition(blkA, tv_layout)\n",
" tidfrgB = cute.composition(blkB, tv_layout)\n",
" tidfrgC = cute.composition(blkC, tv_layout)\n",
"\n",
" print(f\"Composed with TV layout:\")\n",
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
"\n",
" #--------------------------------\n",
" # slice for thread-level view\n",
" #--------------------------------\n",
" # `None` represent slice of the entire per-thread data\n",
" thr_coord = (tidx, None)\n",
"\n",
" # slice for threads: vid -> address\n",
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
"\n",
" thrC[None] = thrA.load() + thrB.load()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we take a closer look at the layout of zipped divided input tensor `gA`:\n",
"\n",
"```\n",
"Tiled to Thread Block:\n",
"\n",
" ((16,256),(128,8)) : ((2048,1),(32768,256))\n",
" ~~~~~~~~ ~~~~~~ ~~~~~~~~\n",
" | | |\n",
" | | |\n",
" | `------------------------> Number of Thread Blocks\n",
" | |\n",
" | |\n",
" `--------------------'\n",
" |\n",
" V\n",
" Thread Block\n",
" Tile\n",
"\n",
"Sliced to Thread-Block local sub-tensor (a (16, 256) tile): gA[((None, None), bidx)]\n",
"\n",
" (16,256) : (2048,1)\n",
" ~~~~~~ ~~~~~~\n",
" | | Tiled/Composed with TV Layout\n",
" | | \n",
" | | o ((32,4),(8,4)):((128,4),(16,1))\n",
" V V \n",
"~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~ \n",
"((32,4), (8,4)) : ((4,8192),(1,2048))\n",
" | |\n",
" | `--------> per thread fragment\n",
" |\n",
"Thread Block\n",
" Shape\n",
"\n",
"Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n",
"\n",
"```\n",
"\n",
"The host code below shows the construction of the TV layout. By composing\n",
"a thread layout of ``(4,32):(32,1)`` (32 threads read contiguous elements on the row dimension,\n",
"then 4 warps read different rows) with a value layout of ``(4,8):(8,1)`` (each thread reads\n",
"8 contiguous elements on the row dimension across 4 contiguous rows),\n",
"we obtain the TV layout shown in the figure above."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tiler: (16, 256)\n",
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
"Tiled Input Tensors:\n",
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
"Composed with TV layout:\n",
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
]
}
],
"source": [
"@cute.jit\n",
"def elementwise_add(\n",
" mA: cute.Tensor,\n",
" mB: cute.Tensor,\n",
" mC: cute.Tensor,\n",
"):\n",
" # mA layout: (M, N):(N, 1)\n",
" # TV layout map thread & value index to (16, 256) logical tile\n",
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
" # mode-1 for coalesced load-store\n",
" # - each thread load 8 contiguous element each row and load 4 rows\n",
" thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
" val_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
" print(f\"Tiler: {tiler_mn}\")\n",
" print(f\"TV Layout: {tv_layout}\")\n",
"\n",
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
"\n",
" print(f\"Tiled Input Tensors:\")\n",
" print(f\" gA: {gA.type}\")\n",
" print(f\" gB: {gB.type}\")\n",
" print(f\" gC: {gC.type}\")\n",
"\n",
" # Launch the kernel asynchronously\n",
" # Async token(s) can also be specified as dependencies\n",
" elementwise_add_kernel(\n",
" gA, gB, gC, tv_layout\n",
" ).launch(\n",
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
" )\n",
"\n",
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"a_ = from_dlpack(a, assumed_align=16)\n",
"b_ = from_dlpack(b, assumed_align=16)\n",
"c_ = from_dlpack(c, assumed_align=16)\n",
"\n",
"elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n",
"elementwise_add_(a_, b_, c_)\n",
"\n",
"# verify correctness\n",
"torch.testing.assert_close(c, a + b)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average execution time: 0.0222 ms\n",
"Throughput: 1133.58 GB/s\n"
]
}
],
"source": [
"benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using Lambda Function\n",
"\n",
"CuTe DSL is built on top of Python. It can leverage Python to implement meta-programming to generate flexible kernels.\n",
"E.g. we can write kernel template that take custom binary operations to generate kernels for arbitrary binary operations.\n",
"\n",
"\n",
"```python\n",
"@cute.jit\n",
"def elementwise_apply(\n",
" op: cutlass.Constexpr,\n",
" mA: cute.Tensor,\n",
" mB: cute.Tensor,\n",
" mC: cute.Tensor\n",
"):\n",
" ...\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tiler: (16, 256)\n",
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
"Tiled Input Tensors:\n",
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
"Composed with TV layout:\n",
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
]
}
],
"source": [
"@cute.kernel\n",
"def elementwise_apply_kernel(\n",
" op: cutlass.Constexpr, # lambda function must be const expr to generate code at compile time\n",
" gA: cute.Tensor,\n",
" gB: cute.Tensor,\n",
" gC: cute.Tensor,\n",
" tv_layout: cute.Layout\n",
"):\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
" bidx, _, _ = cute.arch.block_idx()\n",
"\n",
" blk_coord = ((None, None), bidx)\n",
"\n",
" # logical coord -> address\n",
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
"\n",
" tidfrgA = cute.composition(blkA, tv_layout)\n",
" tidfrgB = cute.composition(blkB, tv_layout)\n",
" tidfrgC = cute.composition(blkC, tv_layout)\n",
"\n",
" print(f\"Composed with TV layout:\")\n",
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
"\n",
" thr_coord = (tidx, None)\n",
"\n",
" # slice for threads: vid -> address\n",
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
"\n",
" #--------------------------------\n",
" # apply custom operation\n",
" #--------------------------------\n",
" thrC[None] = op(thrA.load(), thrB.load())\n",
"\n",
"\n",
"@cute.jit\n",
"def elementwise_op(\n",
" op: cutlass.Constexpr,\n",
" mA: cute.Tensor,\n",
" mB: cute.Tensor,\n",
" mC: cute.Tensor,\n",
"):\n",
" # mA layout: (M, N):(N, 1)\n",
" # TV layout map thread & value index to (16, 256) logical tile\n",
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
" # mode-1 for coalesced load-store\n",
" # - each thread load 8 contiguous element each row and load 4 rows\n",
" thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
" val_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
" print(f\"Tiler: {tiler_mn}\")\n",
" print(f\"TV Layout: {tv_layout}\")\n",
"\n",
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
"\n",
" print(f\"Tiled Input Tensors:\")\n",
" print(f\" gA: {gA.type}\")\n",
" print(f\" gB: {gB.type}\")\n",
" print(f\" gC: {gC.type}\")\n",
"\n",
" # Launch the kernel asynchronously\n",
" # Async token(s) can also be specified as dependencies\n",
" elementwise_apply_kernel(\n",
" op, gA, gB, gC, tv_layout\n",
" ).launch(\n",
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
" )\n",
"\n",
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"a_ = from_dlpack(a, assumed_align=16)\n",
"b_ = from_dlpack(b, assumed_align=16)\n",
"c_ = from_dlpack(c, assumed_align=16)\n",
"\n",
"from operator import mul\n",
"\n",
"elementwise_op(mul, a_, b_, c_)\n",
"\n",
"# verify correctness\n",
"torch.testing.assert_close(c, mul(a, b))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Custom operators can be more complex. For example, here's a function that performs\n",
"multiplication followed by ReLU:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tiler: (16, 256)\n",
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
"Tiled Input Tensors:\n",
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
"Composed with TV layout:\n",
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
]
}
],
"source": [
"def mul_relu(a, b):\n",
" tmp = a * b\n",
" return cute.where(tmp > 0, tmp, cute.full_like(tmp, 0))\n",
"\n",
"\n",
"# As we uses cute.where in customized operation, we need to create another relu function\n",
"def mul_relu_ref(a, b):\n",
" tmp = a * b\n",
" return torch.relu(tmp)\n",
"\n",
"\n",
"elementwise_op(mul_relu, a_, b_, c_)\n",
"\n",
"# verify correctness\n",
"torch.testing.assert_close(c, mul_relu_ref(a, b))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}