649 lines
17 KiB
Plaintext
649 lines
17 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0e95f0df-4d1a-4e2e-92ff-90539bb4c517",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Example 06: CUDA Graphs\n",
|
|
"\n",
|
|
"In this example we demonstrate how to use CUDA graphs through PyTorch with CuTe DSL.\n",
|
|
"The process of interacting with PyTorch's CUDA graph implementation requires exposing PyTorch's CUDA streams to CUTLASS.\n",
|
|
"\n",
|
|
"To use CUDA graphs with Blackwell requires a version of PyTorch that supports Blackwell.\n",
|
|
"This can be obtained through:\n",
|
|
"- The [PyTorch NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)\n",
|
|
"- [PyTorch 2.7 with CUDA 12.8 or later](https://pytorch.org/) (e.g., `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`)\n",
|
|
"- Building PyTorch directly with your version of CUDA."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "46b8fb6f-9ac5-4a3d-b765-b6476f182bf7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import torch for CUDA graphs\n",
|
|
"import torch\n",
|
|
"import cutlass\n",
|
|
"import cutlass.cute as cute\n",
|
|
"# import CUstream type from the cuda driver bindings\n",
|
|
"from cuda.bindings.driver import CUstream\n",
|
|
"# import the current_stream function from torch\n",
|
|
"from torch.cuda import current_stream"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bcf5e06e-1f5b-4d72-ad73-9b36efb78ca0",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Kernel Creation\n",
|
|
"\n",
|
|
"We create a kernel which prints \"Hello world\" as well as a host function to launch the kernel.\n",
|
|
"We then compile the kernel for use in our graph, by passing in a default stream.\n",
|
|
"\n",
|
|
"Kernel compilation before graph capture is required since CUDA graphs cannot JIT compile kernels during graph execution."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "0c2a6ca8-98d7-4837-b91f-af769ca8fcd8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@cute.kernel\n",
|
|
"def hello_world_kernel():\n",
|
|
" \"\"\"\n",
|
|
" A kernel that prints hello world\n",
|
|
" \"\"\"\n",
|
|
" cute.printf(\"Hello world\")\n",
|
|
"\n",
|
|
"@cute.jit\n",
|
|
"def hello_world(stream : CUstream):\n",
|
|
" \"\"\"\n",
|
|
" Host function that launches our (1,1,1), (1,1,1) grid in stream\n",
|
|
" \"\"\"\n",
|
|
" hello_world_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)\n",
|
|
"\n",
|
|
"# Grab a stream from PyTorch, this will also initialize our context\n",
|
|
"# so we can omit cutlass.cuda.initialize_cuda_context()\n",
|
|
"stream = current_stream()\n",
|
|
"hello_world_compiled = cute.compile(hello_world, CUstream(stream.cuda_stream))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ecc850af-09f8-4a29-9c93-ff31fbb9326f",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Creating and replaying a CUDA Graph\n",
|
|
"\n",
|
|
"We create a stream through torch as well as a graph.\n",
|
|
"When we create the graph we can pass the stream we want to capture to torch. We similarly run the compiled kernel with the stream passed as a CUstream.\n",
|
|
"\n",
|
|
"Finally we can replay our graph and synchronize."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "f673e5ae-42bb-44d0-b652-3280606181c4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Hello world\n",
|
|
"Hello world\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Create a CUDA Graph\n",
|
|
"g = torch.cuda.CUDAGraph()\n",
|
|
"# Capture our graph\n",
|
|
"with torch.cuda.graph(g):\n",
|
|
" # Turn our torch Stream into a cuStream stream.\n",
|
|
" # This is done by getting the underlying CUstream with .cuda_stream\n",
|
|
" graph_stream = CUstream(current_stream().cuda_stream)\n",
|
|
" # Run 2 iterations of our compiled kernel\n",
|
|
" for _ in range(2):\n",
|
|
" # Run our kernel in the stream\n",
|
|
" hello_world_compiled(graph_stream)\n",
|
|
"\n",
|
|
"# Replay our graph\n",
|
|
"g.replay()\n",
|
|
"# Synchronize all streams (equivalent to cudaDeviceSynchronize() in C++)\n",
|
|
"torch.cuda.synchronize()"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "db76d9c3-7617-4bf2-b326-11982e6803bf",
|
|
"metadata": {},
|
|
"source": [
|
|
"Our run results in the following execution when viewed in NSight Systems:\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"We can observe the launch of the two kernels followed by a `cudaDeviceSynchronize()`.\n",
|
|
"\n",
|
|
"Now we can confirm that this minimizes some launch overhead:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "3ebe15bf-dc97-42e9-913c-224ecfb472e8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n",
|
|
"Hello world\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Get our CUDA stream from PyTorch\n",
|
|
"stream = CUstream(current_stream().cuda_stream)\n",
|
|
"\n",
|
|
"# Create a larger CUDA Graph of 100 iterations\n",
|
|
"g = torch.cuda.CUDAGraph()\n",
|
|
"# Capture our graph\n",
|
|
"with torch.cuda.graph(g):\n",
|
|
" # Turn our torch Stream into a cuStream stream.\n",
|
|
" # This is done by getting the underlying CUstream with .cuda_stream\n",
|
|
" graph_stream = CUstream(current_stream().cuda_stream)\n",
|
|
" # Run 2 iterations of our compiled kernel\n",
|
|
" for _ in range(100):\n",
|
|
" # Run our kernel in the stream\n",
|
|
" hello_world_compiled(graph_stream)\n",
|
|
"\n",
|
|
"# Create CUDA events for measuring performance\n",
|
|
"start = torch.cuda.Event(enable_timing=True)\n",
|
|
"end = torch.cuda.Event(enable_timing=True)\n",
|
|
"\n",
|
|
"# Run our kernel to warm up the GPU\n",
|
|
"for _ in range(100):\n",
|
|
" hello_world_compiled(stream)\n",
|
|
"\n",
|
|
"# Record our start time\n",
|
|
"start.record()\n",
|
|
"# Run 100 kernels\n",
|
|
"for _ in range(100):\n",
|
|
" hello_world_compiled(stream)\n",
|
|
"# Record our end time\n",
|
|
"end.record()\n",
|
|
"# Synchronize (cudaDeviceSynchronize())\n",
|
|
"torch.cuda.synchronize()\n",
|
|
"\n",
|
|
"# Calculate the time spent when launching kernels in a stream\n",
|
|
"# Results are in ms\n",
|
|
"stream_time = start.elapsed_time(end) \n",
|
|
"\n",
|
|
"# Warmup our GPU again\n",
|
|
"g.replay()\n",
|
|
"# Record our start time\n",
|
|
"start.record()\n",
|
|
"# Run our graph\n",
|
|
"g.replay()\n",
|
|
"# Record our end time\n",
|
|
"end.record()\n",
|
|
"# Synchronize (cudaDeviceSynchronize())\n",
|
|
"torch.cuda.synchronize()\n",
|
|
"\n",
|
|
"# Calculate the time spent when launching kernels in a graph\n",
|
|
"# units are ms\n",
|
|
"graph_time = start.elapsed_time(end)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "12b8151a-46b3-4c99-9945-301f6b628131",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"8.94% speedup when using CUDA graphs for this kernel!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Print out speedup when using CUDA graphs\n",
|
|
"percent_speedup = (stream_time - graph_time) / graph_time\n",
|
|
"print(f\"{percent_speedup * 100.0:.2f}% speedup when using CUDA graphs for this kernel!\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|