{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from functools import partial\n", "\n", "import cutlass\n", "import cutlass.cute as cute\n", "from cutlass.cute.runtime import from_dlpack" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial: Elementwise Add Kernel in CuTe DSL\n", "\n", "This tutorial demonstrates how to implement a simple elementwise\n", "addition kernel using the CuTe DSL (Domain Specific Language).\n", "\n", "\n", "\n", "Elementwise Addition\n", "---------------------\n", "\n", "Elementwise addition is a fundamental operation in linear algebra.\n", "Given two tensors of the same shape, the operation performs element-wise\n", "addition to produce a result tensor of the same shape.\n", "\n", "For two 2D tensors $A$ and $B$ of shape $(M, N)$,\n", "the elementwise addition operation $C = A + B$ is defined as:\n", "\n", "$\n", " C_{i,j} = A_{i,j} + B_{i,j}\n", "$\n", "\n", "where:\n", "\n", "- $i \\in [0, M-1]$ represents the row index\n", "- $j \\in [0, N-1]$ represents the column index\n", "- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ \n", " in tensors $A$, $B$, and $C$ respectively\n", "\n", "This operation is performed independently for each element position,\n", "making it highly parallelizable and well-suited for GPU implementation.\n", "\n", "Naive Elementwise Add Kernel\n", "-----------------------------\n", "\n", "Let's start with a naive implementation that loads each element from\n", "$A$ and $B$, adds them, and stores the result back to $C$." 
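,
"\n",
"\n",
"To make the definition concrete, here is a minimal pure-Python reference\n",
"(illustrative only; the GPU kernels below compute the same thing, just in parallel):\n",
"\n",
"```python\n",
"def elementwise_add_ref(A, B):\n",
"    # C[i][j] = A[i][j] + B[i][j], computed independently for every (i, j)\n",
"    M, N = len(A), len(A[0])\n",
"    return [[A[i][j] + B[i][j] for j in range(N)] for i in range(M)]\n",
"\n",
"C = elementwise_add_ref([[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]])\n",
"assert C == [[11.0, 22.0], [33.0, 44.0]]\n",
"```"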
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "@cute.kernel\n", "def naive_elementwise_add_kernel(\n", " gA: cute.Tensor,\n", " gB: cute.Tensor,\n", " gC: cute.Tensor,\n", "):\n", " tidx, _, _ = cute.arch.thread_idx()\n", " bidx, _, _ = cute.arch.block_idx()\n", " bdim, _, _ = cute.arch.block_dim()\n", "\n", " thread_idx = bidx * bdim + tidx\n", "\n", " # Map thread index to logical index of input tensor\n", " m, n = gA.shape\n", " ni = thread_idx % n\n", " mi = thread_idx // n\n", "\n", " # Map logical index to physical address via tensor layout\n", " a_val = gA[mi, ni]\n", " b_val = gB[mi, ni]\n", "\n", " # Perform element-wise addition\n", " gC[mi, ni] = a_val + b_val" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Structure of the Kernel\n", "\n", "The naive kernel maps each thread to exactly one element.\n", "It doesn't use CuTe layout algebra yet; it relies only on basic\n", "integer indexing into the tensor.\n", "\n", "We can launch the kernel with the following JIT function:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "@cute.jit\n", "def naive_elementwise_add(\n", " mA: cute.Tensor,\n", " mB: cute.Tensor,\n", " mC: cute.Tensor\n", "):\n", " num_threads_per_block = 256\n", "\n", " m, n = mA.shape\n", " kernel = naive_elementwise_add_kernel(mA, mB, mC)\n", " kernel.launch(grid=((m * n) // num_threads_per_block, 1, 1),\n", " block=(num_threads_per_block, 1, 1))\n", "\n", "M, N = 2048, 2048\n", "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", "\n", "a_ = from_dlpack(a, assumed_align=16)\n", "b_ = from_dlpack(b, assumed_align=16)\n", "c_ = from_dlpack(c, assumed_align=16)\n", "\n", "# Compile kernel\n", "naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, 
c_)\n", "naive_elementwise_add_(a_, b_, c_)\n", "\n", "# verify correctness\n", "torch.testing.assert_close(c, a + b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Benchmark performance\n", "\n", "Here's a utility function to benchmark our kernel implementations:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def benchmark(fn, *, num_warmups, num_iterations):\n", " start_event = torch.cuda.Event(enable_timing=True)\n", " end_event = torch.cuda.Event(enable_timing=True)\n", "\n", " torch.cuda.synchronize()\n", "\n", " for _ in range(num_warmups):\n", " fn()\n", "\n", " start_event.record(stream=torch.cuda.current_stream())\n", " for _ in range(num_iterations):\n", " fn()\n", " end_event.record(stream=torch.cuda.current_stream())\n", " torch.cuda.synchronize()\n", "\n", " elapsed_time = start_event.elapsed_time(end_event)\n", " avg_time = elapsed_time / num_iterations\n", "\n", " print(f\"Average execution time: {avg_time:.4f} ms\")\n", " # Bytes moved = 3 tensors (A, B, C) x numel x 2 bytes per fp16 element (uses the global tensor `a`)\n", " print(f\"Throughput: {(3 * a.numel() * 2) / (avg_time / 1000) / 1e9:.2f} GB/s\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average execution time: 0.0385 ms\n", "Throughput: 653.44 GB/s\n" ] } ], "source": [ "benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Performance Analysis\n", "\n", "While our naive implementation maps thread indices to contiguous tensor\n", "dimensions for coalesced memory access, it doesn't have enough\n", "in-flight load & store operations to hide memory latency.\n", "\n", "According to Little's Law:\n", "\n", "$ L = \\lambda \\times W $\n", "\n", "Where:\n", "- $L$ is the average number of items in a system\n", "- $\\lambda$ is the average arrival rate of items (bandwidth)\n", "- $W$ is the average time an item spends in the 
system (latency)\n", "\n", "For our elementwise addition kernel:\n", "\n", "1. $L$: The number of load & store operations in-flight\n", "2. $\\lambda$ (Bandwidth): Data transfer rate between memory and compute units\n", "3. $W$ (Latency): Round-trip delay of memory requests\n", "\n", "For memory-bound operations like elementwise addition, performance is\n", "limited by the number of in-flight load & store operations.\n", "\n", "## Vectorized Load and Store\n", "\n", "To improve performance according to Little's Law, we need to increase the number\n", "of in-flight requests. We can do this by increasing the number of bytes handled\n", "in each load & store operation per thread through vectorized memory access.\n", "\n", "Ampere GPUs support up to 128-bit loads and stores, and each element here is\n", "16-bit (``float16``), so loading 4 contiguous elements of a row turns four\n", "16-bit accesses into a single 64-bit vectorized access per thread.\n", "CuTe tiling operations make this vectorization straightforward.\n", "\n", "Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n", "``tensor`` into groups of ``tiler`` blocks. 
For vectorization, we specify ``tiler``\n", "as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).\n", "Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n", "\n", "```python\n", "mA : cute.Tensor # (2048,2048):(2048,1)\n", "gA = cute.zipped_divide(mA, tiler=(1, 4)) # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))\n", "```\n", "\n", "$\n", " \\begin{array}{ccccc}\n", " & ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\\\\n", " & \\underbrace{\\phantom{(1,4)}}_{tiler} & & \\underbrace{\\phantom{(2048,512)}}_{threads} & \\\\\n", " & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n", " \\end{array}\n", "$" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "@cute.kernel\n", "def vectorized_elementwise_add_kernel(\n", " gA: cute.Tensor,\n", " gB: cute.Tensor,\n", " gC: cute.Tensor,\n", "):\n", " tidx, _, _ = cute.arch.thread_idx()\n", " bidx, _, _ = cute.arch.block_idx()\n", " bdim, _, _ = cute.arch.block_dim()\n", "\n", " thread_idx = bidx * bdim + tidx\n", "\n", " # Map thread index to logical index of input tensor\n", " m, n = gA.shape[1] # thread-domain: shape of the 2nd mode, (2048, 512)\n", " ni = thread_idx % n\n", " mi = thread_idx // n\n", "\n", " # Map logical index to physical address via tensor layout\n", " a_val = gA[(None, (mi, ni))].load()\n", " b_val = gB[(None, (mi, ni))].load()\n", " print(f\"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}\")\n", " print(f\"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}\")\n", "\n", " # Perform element-wise addition\n", " gC[(None, (mi, ni))] = a_val + b_val" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n", "with one key difference: the tensor slicing pattern. 
By using `(None, (mi, ni))` as the slice indices,\n", "we can extract a `(1,4)` sub-tensor from `gA`, `gB`, and `gC`:\n", "\n", "```python\n", "gA[(None, (mi, ni))]\n", "\n", "```\n", "\n", "The tensor data can then be loaded into a vector via the `.load()` method.\n", "\n", "\n", "```\n", " slice\n", " ((1,4),(2048,512)):((0,1),(2048,4)) ==> ((1,4)):((0,1))\n", " ^ ^ ^\n", " | | |\n", " (None, (mi, ni))\n", "```" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[DSL INFO] Tiled Tensors:\n", "[DSL INFO] gA = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n", "[DSL INFO] gB = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n", "[DSL INFO] gC = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n", "[DSL INFO] sliced gA = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n", "[DSL INFO] sliced gB = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n" ] } ], "source": [ "@cute.jit\n", "def vectorized_elementwise_add(\n", " mA: cute.Tensor,\n", " mB: cute.Tensor,\n", " mC: cute.Tensor\n", "):\n", " threads_per_block = 256\n", "\n", " gA = cute.zipped_divide(mA, (1, 4))\n", " gB = cute.zipped_divide(mB, (1, 4))\n", " gC = cute.zipped_divide(mC, (1, 4))\n", "\n", " print(f\"[DSL INFO] Tiled Tensors:\")\n", " print(f\"[DSL INFO] gA = {gA}\")\n", " print(f\"[DSL INFO] gB = {gB}\")\n", " print(f\"[DSL INFO] gC = {gC}\")\n", "\n", " vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n", " grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n", " block=(threads_per_block, 1, 1),\n", " )\n", "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", "\n", "a_ = from_dlpack(a, assumed_align=16)\n", "b_ = from_dlpack(b, assumed_align=16)\n", "c_ = from_dlpack(c, assumed_align=16)\n", "\n", "compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n", "compiled_func(a_, b_, c_)\n", 
"\n", "# verify correctness\n", "torch.testing.assert_close(c, a + b)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average execution time: 0.0202 ms\n", "Throughput: 1244.98 GB/s\n" ] } ], "source": [ "benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TV Layout\n", "\n", "Both the naive and vectorized kernels follow a common pattern to map thread indices\n", "to physical addresses:\n", "\n", "Step 1: Map thread index to logical M/N coordinates\n", "\n", "```python\n", " mi = thread_idx // n\n", " ni = thread_idx % n\n", "```\n", "\n", "Step 2: Map logical M/N coordinates to physical addresses using the tensor layout\n", "\n", "```python\n", " a[(None, (mi, ni))].load()\n", "```\n", "\n", "CuTe uses a TV layout to represent this mapping from thread index and value index\n", "(e.g., the 4 elements loaded per thread in the vectorized kernel above) to the logical coordinate space of a tensor.\n", "By configuring different TV layouts, we can experiment with different memory access\n", "patterns with minimal code changes.\n", "\n", "The following example demonstrates two levels of tiling: at the thread-block level\n", "and at the thread level.\n", "\n", "For thread-block level tiling, each input & output tensor is first divided\n", "into a group of ``(TileM, TileN)`` sub-tensors on the host side.\n", "\n", "Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor\n", "(``gA[((None, None), bidx)]``), which returns a thread-block local view of\n", "a single ``(TileM, TileN)`` sub-tensor.\n", "\n", "For thread level tiling, we compose the sub-tensor (which maps from logical coordinates\n", "to physical addresses) with the TV layout (which maps from thread & value indices to\n", "logical coordinates). 
This gives us a tiled sub-tensor that maps from thread & value\n", "indices directly to physical addresses.\n", "\n", "We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)\n", "to get a thread-local view of the data each thread accesses. Note that the thread index\n", "is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "@cute.kernel\n", "def elementwise_add_kernel(\n", " gA: cute.Tensor,\n", " gB: cute.Tensor,\n", " gC: cute.Tensor,\n", " tv_layout: cute.Layout\n", "):\n", " tidx, _, _ = cute.arch.thread_idx()\n", " bidx, _, _ = cute.arch.block_idx()\n", "\n", " #--------------------------------\n", " # slice for thread-block level view\n", " #--------------------------------\n", " blk_coord = ((None, None), bidx)\n", "\n", " # logical coord -> address\n", " blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n", " blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n", " blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n", "\n", " #--------------------------------\n", " # compose for thread-index & value-index to physical mapping\n", " #--------------------------------\n", " # blkA: (TileM, TileN) -> physical address\n", " # tv_layout: (tid, vid) -> (TileM, TileN)\n", " # tidfrgA = blkA o tv_layout\n", " # tidfrgA: (tid, vid) -> physical address\n", " tidfrgA = cute.composition(blkA, tv_layout)\n", " tidfrgB = cute.composition(blkB, tv_layout)\n", " tidfrgC = cute.composition(blkC, tv_layout)\n", "\n", " print(f\"Composed with TV layout:\")\n", " print(f\" tidfrgA: {tidfrgA.type}\")\n", "\n", " #--------------------------------\n", " # slice for thread-level view\n", " #--------------------------------\n", " # `None` represents a slice of the entire per-thread data\n", " thr_coord = (tidx, None)\n", "\n", " # slice for threads: vid -> address\n", " thrA = tidfrgA[thr_coord] # (V) 
-> physical address\n", " thrB = tidfrgB[thr_coord] # (V) -> physical address\n", " thrC = tidfrgC[thr_coord] # (V) -> physical address\n", "\n", " thrC[None] = thrA.load() + thrB.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we take a closer look at the layout of the zipped-divided input tensor `gA`:\n", "\n", "```\n", "Tiled to Thread Block:\n", "\n", " ((16,256),(128,8)) : ((2048,1),(32768,256))\n", " ~~~~~~~~ ~~~~~~ ~~~~~~~~\n", " | | |\n", " | | |\n", " | `------------------------> Number of Thread Blocks\n", " | |\n", " | |\n", " `--------------------'\n", " |\n", " V\n", " Thread Block\n", " Tile\n", "\n", "Sliced to Thread-Block local sub-tensor (a (16, 256) tile): gA[((None, None), bidx)]\n", "\n", " (16,256) : (2048,1)\n", " ~~~~~~ ~~~~~~\n", " | | Tiled/Composed with TV Layout\n", " | | \n", " | | o ((32,4),(8,4)):((128,4),(16,1))\n", " V V \n", "~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~ \n", "((32,4), (8,4)) : ((8,8192),(1,2048))\n", " | |\n", " | `--------> per thread fragment\n", " |\n", "Thread Block\n", " Shape\n", "\n", "Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n", "\n", "```\n", "\n", "The host code below shows the construction of the TV layout. By composing\n", "a thread layout of ``(4,32):(32,1)`` (32 threads read contiguous elements on the row dimension,\n", "then 4 warps read different rows) with a value layout of ``(4,8):(8,1)`` (each thread reads\n", "8 contiguous elements on the row dimension across 4 contiguous rows),\n", "we obtain the TV layout shown in the figure above." 
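,
"\n",
"The offset arithmetic behind these ``Shape:Stride`` layouts can be modeled in a few lines\n",
"of plain Python (an illustrative sketch; ``layout_offset`` is a hypothetical helper, not part\n",
"of the DSL, assuming CuTe's column-major unfolding of the linear index):\n",
"\n",
"```python\n",
"def layout_offset(idx, shape, stride):\n",
"    # Unfold the linear index column-major over the shape,\n",
"    # then dot the resulting coordinate with the strides.\n",
"    offset = 0\n",
"    for s, d in zip(shape, stride):\n",
"        offset += (idx % s) * d\n",
"        idx //= s\n",
"    return offset\n",
"\n",
"# Row-major tensor layout (2048,2048):(2048,1):\n",
"# linear index 2048 wraps to coordinate (0,1), i.e. offset 1\n",
"assert layout_offset(2048, (2048, 2048), (2048, 1)) == 1\n",
"\n",
"# Flattened composed layout (32,4,8,4):(8,8192,1,2048) from the figure:\n",
"# consecutive threads start 8 elements apart ...\n",
"assert layout_offset(1, (32, 4, 8, 4), (8, 8192, 1, 2048)) == 8\n",
"# ... and each thread's first 8 values are contiguous (stride 1)\n",
"assert layout_offset(128, (32, 4, 8, 4), (8, 8192, 1, 2048)) == 1\n",
"```"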
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tiler: (16, 256)\n", "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n", "Tiled Input Tensors:\n", " gA: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", " gB: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", " gC: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", "Composed with TV layout:\n", " tidfrgA: !cute.memref<!cute.ptr<f16, gmem>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n" ] } ], "source": [ "@cute.jit\n", "def elementwise_add(\n", " mA: cute.Tensor,\n", " mB: cute.Tensor,\n", " mC: cute.Tensor,\n", "):\n", " # mA layout: (M, N):(N, 1)\n", " # The TV layout maps thread & value indices to a (16, 256) logical tile\n", " # - contiguous thread indices map to mode-1 because the input layout is contiguous on\n", " # mode-1 for coalesced load-store\n", " # - each thread loads 8 contiguous elements per row, across 4 rows\n", " thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n", " val_layout = cute.make_layout((4, 8), stride=(8, 1))\n", " tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n", " print(f\"Tiler: {tiler_mn}\")\n", " print(f\"TV Layout: {tv_layout}\")\n", "\n", " gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", " gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", " gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", "\n", " print(f\"Tiled Input Tensors:\")\n", " print(f\" gA: {gA.type}\")\n", " print(f\" gB: {gB.type}\")\n", " print(f\" gC: {gC.type}\")\n", "\n", " # Launch the kernel asynchronously\n", " # Async token(s) can also be specified as dependencies\n", " elementwise_add_kernel(\n", " gA, gB, gC, tv_layout\n", " ).launch(\n", " grid=[cute.size(gC, mode=[1]), 1, 1],\n", " block=[cute.size(tv_layout, mode=[0]), 1, 1],\n", " )\n", "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = 
torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", "\n", "a_ = from_dlpack(a, assumed_align=16)\n", "b_ = from_dlpack(b, assumed_align=16)\n", "c_ = from_dlpack(c, assumed_align=16)\n", "\n", "elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n", "elementwise_add_(a_, b_, c_)\n", "\n", "# verify correctness\n", "torch.testing.assert_close(c, a + b)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average execution time: 0.0222 ms\n", "Throughput: 1133.58 GB/s\n" ] } ], "source": [ "benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using a Lambda Function\n", "\n", "CuTe DSL is built on top of Python, so we can leverage Python metaprogramming to generate flexible kernels.\n", "For example, we can write a kernel template that takes a custom binary operation and generates a kernel for any binary operation.\n", "\n", "\n", "```python\n", "@cute.jit\n", "def elementwise_apply(\n", " op: cutlass.Constexpr,\n", " mA: cute.Tensor,\n", " mB: cute.Tensor,\n", " mC: cute.Tensor\n", "):\n", " ...\n", "\n", "```" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tiler: (16, 256)\n", "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n", "Tiled Input Tensors:\n", " gA: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", " gB: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", " gC: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", "Composed with TV layout:\n", " tidfrgA: !cute.memref<!cute.ptr<f16, gmem>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n" ] } ], "source": [ "@cute.kernel\n", "def elementwise_apply_kernel(\n", " op: cutlass.Constexpr, # lambda function must be const expr to generate code at compile 
time\n", " gA: cute.Tensor,\n", " gB: cute.Tensor,\n", " gC: cute.Tensor,\n", " tv_layout: cute.Layout\n", "):\n", " tidx, _, _ = cute.arch.thread_idx()\n", " bidx, _, _ = cute.arch.block_idx()\n", "\n", " blk_coord = ((None, None), bidx)\n", "\n", " # logical coord -> address\n", " blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n", " blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n", " blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n", "\n", " tidfrgA = cute.composition(blkA, tv_layout)\n", " tidfrgB = cute.composition(blkB, tv_layout)\n", " tidfrgC = cute.composition(blkC, tv_layout)\n", "\n", " print(f\"Composed with TV layout:\")\n", " print(f\" tidfrgA: {tidfrgA.type}\")\n", "\n", " thr_coord = (tidx, None)\n", "\n", " # slice for threads: vid -> address\n", " thrA = tidfrgA[thr_coord] # (V) -> physical address\n", " thrB = tidfrgB[thr_coord] # (V) -> physical address\n", " thrC = tidfrgC[thr_coord] # (V) -> physical address\n", "\n", " #--------------------------------\n", " # apply custom operation\n", " #--------------------------------\n", " thrC[None] = op(thrA.load(), thrB.load())\n", "\n", "\n", "@cute.jit\n", "def elementwise_op(\n", " op: cutlass.Constexpr,\n", " mA: cute.Tensor,\n", " mB: cute.Tensor,\n", " mC: cute.Tensor,\n", "):\n", " # mA layout: (M, N):(N, 1)\n", " # The TV layout maps thread & value indices to a (16, 256) logical tile\n", " # - contiguous thread indices map to mode-1 because the input layout is contiguous on\n", " # mode-1 for coalesced load-store\n", " # - each thread loads 8 contiguous elements per row, across 4 rows\n", " thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n", " val_layout = cute.make_layout((4, 8), stride=(8, 1))\n", " tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n", " print(f\"Tiler: {tiler_mn}\")\n", " print(f\"TV Layout: {tv_layout}\")\n", "\n", " gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", " gB = cute.zipped_divide(mB, tiler_mn) 
# ((TileM, TileN), (RestM, RestN))\n", " gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n", "\n", " print(f\"Tiled Input Tensors:\")\n", " print(f\" gA: {gA.type}\")\n", " print(f\" gB: {gB.type}\")\n", " print(f\" gC: {gC.type}\")\n", "\n", " # Launch the kernel asynchronously\n", " # Async token(s) can also be specified as dependencies\n", " elementwise_apply_kernel(\n", " op, gA, gB, gC, tv_layout\n", " ).launch(\n", " grid=[cute.size(gC, mode=[1]), 1, 1],\n", " block=[cute.size(tv_layout, mode=[0]), 1, 1],\n", " )\n", "\n", "a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n", "c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n", "\n", "a_ = from_dlpack(a, assumed_align=16)\n", "b_ = from_dlpack(b, assumed_align=16)\n", "c_ = from_dlpack(c, assumed_align=16)\n", "\n", "from operator import mul\n", "\n", "elementwise_op(mul, a_, b_, c_)\n", "\n", "# verify correctness\n", "torch.testing.assert_close(c, mul(a, b))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Custom operators can be more complex. 
For example, here's a function that performs\n", "multiplication followed by ReLU:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tiler: (16, 256)\n", "TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n", "Tiled Input Tensors:\n", " gA: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", " gB: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", " gC: !cute.memref<!cute.ptr<f16, gmem>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n", "Composed with TV layout:\n", " tidfrgA: !cute.memref<!cute.ptr<f16, gmem>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n" ] } ], "source": [ "def mul_relu(a, b):\n", " tmp = a * b\n", " return cute.where(tmp > 0, tmp, cute.full_like(tmp, 0))\n", "\n", "\n", "# Since the custom op uses cute.where on device, we need a separate PyTorch reference function\n", "def mul_relu_ref(a, b):\n", " tmp = a * b\n", " return torch.relu(tmp)\n", "\n", "\n", "elementwise_op(mul_relu, a_, b_, c_)\n", "\n", "# verify correctness\n", "torch.testing.assert_close(c, mul_relu_ref(a, b))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 }