cutlass/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00


{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"from cutlass.cute.runtime import from_dlpack"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<style>\n",
"div.mermaid > svg {\n",
" width: 50% !important;\n",
" height: auto !important;\n",
"}\n",
"</style>\n",
"\n",
"# Tutorial: Warp Specialization with Async Pipeline in CuTe DSL\n",
"\n",
"This tutorial explores advanced CUDA programming techniques for implementing efficient producer-consumer \n",
"patterns using asynchronous communication primitives in the CuTe Domain Specific Language (DSL).\n",
"\n",
"## Foundation: Inter-Warp Communication Basics\n",
"\n",
"### Understanding CUDA Warps and Shared Memory\n",
"\n",
"A **warp** is the fundamental execution unit in CUDA, consisting of 32 threads that execute instructions in Single Instruction, \n",
"Multiple Thread (SIMT) fashion on a Streaming Multiprocessor (SM). Understanding warp-level programming is crucial for \n",
"achieving optimal GPU performance.\n",
"\n",
"**Key Concepts:**\n",
"- Warps execute in lockstep, making them ideal for SIMD operations\n",
"- Multiple warps within a thread block (CTA) can cooperate through shared memory\n",
"- Shared memory provides low-latency, high-bandwidth communication between threads\n",
"\n",
"### Shared Memory Architecture\n",
"\n",
"**Shared memory** serves as a programmer-managed cache with several important characteristics:\n",
"\n",
"- **Speed**: ~100x faster than global memory access\n",
"- **Scope**: Accessible by all threads within the same thread block\n",
"- **Organization**: Divided into banks (typically 32) to enable parallel access\n",
"- **Conflicts**: Bank conflicts occur when multiple threads access the same bank simultaneously\n",
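"\n",
"As a rough mental model (assuming 4-byte bank words and 32 banks; the exact mapping varies by architecture), the bank serving a given byte address can be sketched in plain Python:\n",
"\n",
"```python\n",
"def bank_of(byte_addr: int, word_bytes: int = 4, num_banks: int = 32) -> int:\n",
"    # Consecutive words map to consecutive banks, wrapping around.\n",
"    return (byte_addr // word_bytes) % num_banks\n",
"\n",
"# 32 threads reading consecutive floats hit 32 distinct banks: conflict-free.\n",
"conflict_free = len({bank_of(4 * t) for t in range(32)}) == 32\n",
"```\n",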
"\n",
"### Traditional Synchronous Communication\n",
"\n",
"The conventional approach for inter-warp communication relies on explicit synchronization barriers. The following sequence diagram \n",
"illustrates the typical producer-consumer pattern:\n",
"\n",
"```mermaid\n",
"sequenceDiagram\n",
" participant W0 as Producer Warp\n",
" participant SMEM as Shared Memory\n",
" participant W1 as Consumer Warp\n",
" \n",
" W0->>SMEM: Write data\n",
" critical Synchronization Barrier\n",
" W0-->W1: __syncthreads()\n",
" SMEM->>W1: Read data\n",
" W0-->W1: __syncthreads()\n",
" end\n",
"```\n",
"\n",
"**Limitations of Synchronous Communication:**\n",
"- All warps must wait at synchronization points\n",
"- No opportunity for overlapped computation\n",
"- Reduced overall throughput due to forced serialization"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def synced_producer_consumer(SharedStorage: cutlass.Constexpr, res: cute.Tensor):\n",
" warp_idx = cute.arch.warp_idx()\n",
" warp_idx = cute.arch.make_warp_uniform(warp_idx)\n",
"\n",
" smem = cutlass.utils.SmemAllocator()\n",
" storage = smem.allocate(SharedStorage, 64)\n",
"\n",
" staging_smem = storage.staging_buffer.get_tensor(cute.make_layout(1))\n",
" staging_smem.fill(0)\n",
" cute.arch.sync_threads()\n",
"\n",
" for i in cutlass.range(cute.size(res)):\n",
" if warp_idx == 0:\n",
" staging_smem[0] = i * 1.0\n",
" # mark enter of critical region\n",
" cute.arch.sync_threads()\n",
" if warp_idx == 1:\n",
" res[i] = staging_smem[0]\n",
" # mark exit of critical region\n",
" cute.arch.sync_threads()\n",
"\n",
"\n",
"@cute.jit\n",
"def run_synced_producer_consumer(res: cute.Tensor):\n",
" @cute.struct\n",
" class SharedStorage:\n",
" staging_buffer: cute.struct.Align[\n",
" cute.struct.MemRange[cutlass.Float32, 1], 1024\n",
" ]\n",
"\n",
" synced_producer_consumer(SharedStorage, res).launch(\n",
" grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()\n",
" )\n",
"\n",
"\n",
"res = torch.zeros((8,), device=\"cuda\")\n",
"run_synced_producer_consumer(from_dlpack(res))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='cuda:0')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res"
]
},
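{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two `cute.arch.sync_threads()` calls above act as a reusable barrier shared by both warps. As a CPU-side analogy (a sketch using Python's `threading.Barrier`, not the CUDA API; thread and variable names are illustrative), the same double-barrier handshake looks like this:\n",
"\n",
"```python\n",
"import threading\n",
"\n",
"barrier = threading.Barrier(2)  # plays the role of __syncthreads() for 2 \"warps\"\n",
"staging = [0.0]\n",
"consumed = []\n",
"\n",
"def producer(n):\n",
"    for i in range(n):\n",
"        staging[0] = float(i)\n",
"        barrier.wait()  # data written and ready for consumption\n",
"        barrier.wait()  # data consumed; safe to overwrite\n",
"\n",
"def consumer(n):\n",
"    for i in range(n):\n",
"        barrier.wait()  # wait until data is ready\n",
"        consumed.append(staging[0])\n",
"        barrier.wait()  # signal that the slot can be reused\n",
"\n",
"tp = threading.Thread(target=producer, args=(8,))\n",
"tc = threading.Thread(target=consumer, args=(8,))\n",
"tp.start(); tc.start(); tp.join(); tc.join()\n",
"```\n",
"\n",
"Both threads stall at every `barrier.wait()`, mirroring the forced serialization of the synchronous pattern."
]
},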
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"<style>\n",
"div.mermaid > svg {\n",
" width: 50% !important;\n",
" height: auto !important;\n",
"}\n",
"</style>\n",
"\n",
"## Asynchronous Communication: Breaking the Synchronization Bottleneck\n",
"\n",
"### The Problem with Synchronous Patterns\n",
"\n",
"The previous example demonstrates traditional synchronized communication between warps. While functional, this approach \n",
"has significant performance limitations:\n",
"\n",
"**Critical Section Analysis:**\n",
"- **First `__syncthreads()`**: Ensures data is written and ready for consumption\n",
"- **Second `__syncthreads()`**: Guarantees data has been consumed and memory can be safely overwritten\n",
"\n",
"**Performance Impact:**\n",
"- All warps are forced into lockstep execution\n",
"- No computational overlap between producer and consumer operations\n",
"- Wasted cycles as warps wait at synchronization barriers\n",
"\n",
"### Hopper Architecture: Enabling Asynchronous Primitives\n",
"\n",
"Starting with the Hopper architecture, CUDA introduced sophisticated asynchronous communication primitives that enable \n",
"**warp specialization**—allowing different warps to perform distinct, specialized roles while maintaining loose coupling.\n",
"\n",
"**Key Benefits:**\n",
"- **Overlapped Execution**: Producer and consumer warps can perform computations concurrently\n",
"- **Reduced Latency**: Eliminates unnecessary synchronization stalls\n",
"- **Better Resource Utilization**: Maximizes SM occupancy and throughput\n",
"\n",
"### Async Pipeline Communication Pattern\n",
"\n",
"The async pipeline abstraction provides an elegant solution for producer-consumer communication without rigid synchronization constraints:\n",
"\n",
"```mermaid\n",
"sequenceDiagram\n",
" participant W0 as Producer Warp\n",
" participant Pipeline as Async Pipeline\n",
" participant SMEM as Shared Memory \n",
" participant W1 as Consumer Warp\n",
" \n",
" W0->>Pipeline: Acquire (request write slot)\n",
" activate W1\n",
" Pipeline-->>W0: Grant access\n",
" deactivate W1\n",
" \n",
" W1->>Pipeline: Wait (for data availability)\n",
" activate Pipeline\n",
" \n",
" W0->>SMEM: Write data\n",
" W0->>Pipeline: Commit (signal data ready)\n",
" \n",
" Pipeline-->>W1: Data available\n",
" deactivate Pipeline\n",
" \n",
" activate W0\n",
" SMEM->>W1: Read data\n",
" deactivate W0\n",
" W1->>Pipeline: Release (mark slot available)\n",
"```\n",
"\n",
"**Async Pipeline Advantages:**\n",
"- **Non-blocking Operations**: Warps can perform other work while waiting\n",
"- **Fine-grained Control**: Explicit control over data readiness and consumption\n",
"- **Scalable**: Supports multiple producer-consumer pairs efficiently"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Async Pipeline API Reference\n",
"\n",
"The `PipelineAsync` abstraction in CuTe DSL provides a comprehensive set of primitives for implementing efficient producer-consumer patterns:\n",
"\n",
"#### Producer Operations\n",
"- **`PipelineProducer.acquire()`**: Blocks until a write slot becomes available (released by consumer)\n",
" - Returns immediately with a handle pointing to an available slot when one is free\n",
" - Enables backpressure control to prevent buffer overflow\n",
" - **`PipelineProducer.acquire_and_advance()`** additionally moves the producer's write index to the next buffer slot\n",
"\n",
"- **`PipelineProducer.commit(PipelineProducer.ImmutableProducerHandle)`** / **`PipelineProducer.ImmutableProducerHandle.commit()`**: Signals that data has been written to the handle-pointed slot and is ready for consumption\n",
" - Triggers waiting consumers\n",
" - Maintains data consistency guarantees\n",
" - If no handle is passed, **`PipelineProducer.commit()`** uses its internally maintained handle (pointing to the last slot it acquired)\n",
"\n",
"#### Consumer Operations \n",
"- **`PipelineConsumer.wait()`**: Blocks until data becomes available for reading\n",
" - Returns with a handle pointing to a committed slot once the producer commits new data\n",
" - Supports timeout and polling variants\n",
" - **`PipelineConsumer.wait_and_advance()`** additionally moves the consumer's read index to the next buffer slot\n",
"\n",
"- **`PipelineConsumer.release(PipelineConsumer.ImmutableConsumerHandle)`** / **`PipelineConsumer.ImmutableConsumerHandle.release()`**: Marks the handle-pointed slot as consumed and available for reuse\n",
" - Enables producers to acquire released slots\n",
" - Critical for preventing deadlock in circular buffers\n",
" - If no handle is passed, **`PipelineConsumer.release()`** uses its internally maintained handle (pointing to the last slot it waited for)\n",
"\n",
"#### Disclaimer\n",
"\n",
"The `pipeline` APIs provide abstractions for managing synchronization between warps, thread blocks, etc. They do not guarantee deadlock freedom: it remains the developer's responsibility to write correct code that avoids deadlock.\n",
"\n",
"#### Performance Characteristics\n",
"\n",
"**Computational Overlap**: This asynchronous communication pattern enables limited but significant computational overlap:\n",
"- **Producer**: Can perform preprocessing, data transformation, or prefetching while consumer processes previous data\n",
"- **Consumer**: Can execute post-processing, result computation, or output operations while producer prepares next data\n",
"\n",
"**Memory Efficiency**: Explicit slot management ensures optimal memory utilization without unnecessary copying or buffering."
]
},
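{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the CuTe DSL version, it may help to see the acquire/commit/wait/release protocol in a CPU-side sketch (plain Python `threading`, not the pipeline API; the comments map each step to the corresponding pipeline operation). Two semaphores count empty and full slots, which provides the same backpressure:\n",
"\n",
"```python\n",
"import threading\n",
"\n",
"stages = 2\n",
"empty = threading.Semaphore(stages)  # free slots; producer acquire() waits on this\n",
"full = threading.Semaphore(0)        # ready slots; consumer wait() waits on this\n",
"staging = [0.0] * stages\n",
"out = []\n",
"\n",
"def producer(n):\n",
"    for i in range(n):\n",
"        empty.acquire()               # acquire: block until a slot is free\n",
"        staging[i % stages] = float(i)\n",
"        full.release()                # commit: signal data is ready\n",
"\n",
"def consumer(n):\n",
"    for i in range(n):\n",
"        full.acquire()                # wait: block until data is available\n",
"        out.append(staging[i % stages])\n",
"        empty.release()               # release: mark the slot reusable\n",
"\n",
"tp = threading.Thread(target=producer, args=(8,))\n",
"tc = threading.Thread(target=consumer, args=(8,))\n",
"tp.start(); tc.start(); tp.join(); tc.join()\n",
"```\n",
"\n",
"Unlike a barrier, this lets the producer run up to `stages` items ahead of the consumer, which is exactly the overlap the async pipeline exploits."
]
},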
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def async_pipeline_kernel(res: cute.Tensor):\n",
" warp_idx = cute.arch.warp_idx()\n",
" warp_idx = cute.arch.make_warp_uniform(warp_idx)\n",
"\n",
" @cute.struct\n",
" class SharedStorage:\n",
" tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]\n",
" staging_buffer: cute.struct.Align[\n",
" cute.struct.MemRange[cutlass.Float32, 1], 1024\n",
" ]\n",
"\n",
" smem = cutlass.utils.SmemAllocator()\n",
" storage = smem.allocate(SharedStorage, 64)\n",
"\n",
" # Warp 0\n",
" producer_group = cutlass.pipeline.CooperativeGroup(\n",
" cutlass.pipeline.Agent.Thread, 32\n",
" )\n",
" # Warp 1\n",
" consumer_group = cutlass.pipeline.CooperativeGroup(\n",
" cutlass.pipeline.Agent.Thread, 32\n",
" )\n",
"\n",
" pipeline = cutlass.pipeline.PipelineAsync.create(\n",
" num_stages=1,\n",
" producer_group=producer_group,\n",
" consumer_group=consumer_group,\n",
" barrier_storage=storage.tma_mbar_ptr.data_ptr(),\n",
" )\n",
"\n",
" staging_smem = storage.staging_buffer.get_tensor(cute.make_layout(1))\n",
" staging_smem.fill(0)\n",
" cute.arch.sync_threads()\n",
"\n",
" producer, consumer = pipeline.make_participants()\n",
"\n",
" # Producer warp\n",
" if warp_idx == 0:\n",
" for i in cutlass.range(cute.size(res)):\n",
" # Producer: Wait until the staging buffer is available\n",
" handle = producer.acquire_and_advance()\n",
" # Producer: Write data to shared memory\n",
" staging_smem[handle.index] = 1.0 * i\n",
" # Producer: Signal data is ready for consumption\n",
" handle.commit()\n",
" producer.tail()\n",
"\n",
" # Consumer warp\n",
" if warp_idx == 1:\n",
" for i in cutlass.range(cute.size(res)):\n",
" # Consumer: Wait for producer to signal when data is available for use\n",
" handle = consumer.wait_and_advance()\n",
" # Consumer: Consume data\n",
" res[i] = staging_smem[handle.index]\n",
" # Consumer: Signal data buffer is ready for write\n",
" handle.release()\n",
"\n",
"\n",
"@cute.jit\n",
"def async_pipeline(res: cute.Tensor):\n",
" # Launch kernel with two warps: producer and consumer\n",
" async_pipeline_kernel(res).launch(grid=(1, 1, 1), block=(64, 1, 1))\n",
"\n",
"\n",
"res = torch.zeros((8,), device=\"cuda\")\n",
"async_pipeline(from_dlpack(res))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='cuda:0')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<style>\n",
"div.mermaid > svg {\n",
" width: 50% !important;\n",
" height: auto !important;\n",
"}\n",
"</style>\n",
"\n",
"## Advanced Pattern: Staged Async Pipeline with Circular Buffering\n",
"\n",
"### Limitations of Single-Stage Pipelines\n",
"\n",
"While async communication provides significant improvements over synchronous patterns, single-stage pipelines \n",
"still exhibit serialization bottlenecks:\n",
"\n",
"**Dependency Chain Analysis:**\n",
"```mermaid\n",
"sequenceDiagram\n",
" participant W0 as Producer\n",
" participant Pipeline as Pipeline\n",
" participant W1 as Consumer\n",
" \n",
" W0->>Pipeline: Acquire\n",
" Note over W0,W1: Producer waits here\n",
" W1->>Pipeline: Release\n",
" Pipeline-->>W0: Granted\n",
"```\n",
"\n",
"**Performance Bottleneck**: The producer must wait for the consumer to complete processing and release the buffer \n",
"before acquiring the next write slot. This creates a serialization point that limits overall throughput.\n",
"\n",
"### Multi-Stage Pipeline Architecture\n",
"\n",
"The **staged async pipeline** implements a circular buffer managed by an array of synchronization barriers, \n",
"enabling much higher degrees of parallelism:\n",
"\n",
"#### Core Concepts\n",
"\n",
"**Circular Buffer Management:**\n",
"- **Multiple Stages**: Support for N concurrent buffer slots (typically 2-8 stages)\n",
"- **Independent Indexing**: Producer and consumer maintain separate advancement indices\n",
"- **Barrier Array**: Each stage has an associated memory barrier for fine-grained synchronization\n",
"\n",
"#### Enhanced API Operations\n",
"\n",
"- **`PipelineProducer.advance()`**: Moves the producer's write index to the next buffer slot\n",
" - Enables round-robin buffer allocation\n",
" - Allows producer to continue without waiting for all previous data to be consumed\n",
" - Performed implicitly when calling **`PipelineProducer.acquire_and_advance()`**\n",
"\n",
"- **`PipelineConsumer.advance()`**: Moves the consumer's read index to the next buffer slot\n",
" - Maintains proper ordering of data consumption\n",
" - Signals availability of processed slots\n",
" - Performed implicitly when calling **`PipelineConsumer.wait_and_advance()`**\n",
"\n",
"- **`PipelineProducer.ImmutableProducerHandle.index`** / **`PipelineConsumer.ImmutableConsumerHandle.index`**: Returns the index of the pointed-to buffer slot\n",
" - Used for addressing specific staging buffer locations\n",
" - Enables direct slot-based data access\n",
"\n",
"### Circular Buffer State Visualization\n",
"\n",
"```\n",
"Legend:\n",
" W: Currently being written (producer active)\n",
" D: Data ready for consumption \n",
" R: Currently being read (consumer active)\n",
" X: Empty slot available for writing\n",
" \n",
" Advance Direction\n",
" <-------------------\n",
"\n",
" Producer Consumer\n",
" | ^\n",
" V |\n",
" +-----------------+\n",
" --|X|X|W|D|D|D|D|R|X|<-.\n",
" / +-----------------+ \\\n",
" | |\n",
" `------------------------' \n",
"```\n",
"\n",
"**Key Advantages:**\n",
"- **Increased Throughput**: Producer can stay ahead of consumer by multiple stages\n",
"- **Latency Hiding**: Consumer processing latency is hidden by buffered data\n",
"- **Better Resource Utilization**: Both warps can maintain high activity levels\n",
"- **Scalable Design**: Buffer depth can be tuned based on workload characteristics\n",
"\n",
"The following implementation demonstrates efficient multi-stage pipeline communication with proper circular buffer management:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def async_pipeline_staged_kernel(\n",
" SharedStorage: cutlass.Constexpr, res: cute.Tensor, staging: cute.Tensor\n",
"):\n",
" stages = cute.size(staging)\n",
"\n",
" warp_idx = cute.arch.warp_idx()\n",
" warp_idx = cute.arch.make_warp_uniform(warp_idx)\n",
"\n",
" smem = cutlass.utils.SmemAllocator()\n",
" storage = smem.allocate(SharedStorage, 64)\n",
"\n",
" # Warp 0\n",
" producer_group = cutlass.pipeline.CooperativeGroup(\n",
" cutlass.pipeline.Agent.Thread, 32\n",
" )\n",
" # Warp 1\n",
" consumer_group = cutlass.pipeline.CooperativeGroup(\n",
" cutlass.pipeline.Agent.Thread, 32\n",
" )\n",
"\n",
" pipeline = cutlass.pipeline.PipelineAsync.create(\n",
" num_stages=stages,\n",
" producer_group=producer_group,\n",
" consumer_group=consumer_group,\n",
" barrier_storage=storage.tma_mbar_ptr.data_ptr(),\n",
" )\n",
"\n",
" staging_smem = storage.staging_buffer.get_tensor(staging.layout)\n",
" staging_smem.fill(0)\n",
" cute.arch.sync_threads()\n",
"\n",
" producer, consumer = pipeline.make_participants()\n",
"\n",
" # Producer warp\n",
" if warp_idx == 0:\n",
" for i in cutlass.range(cute.size(res)):\n",
" handle = producer.acquire_and_advance()\n",
" staging_smem[handle.index] = 1.0 * i\n",
" handle.commit() # or producer.commit(handle)\n",
"\n",
" # prevents the CTA from retiring until all expected arrive signals have been received.\n",
" producer.tail()\n",
"\n",
" # Consumer warp\n",
" if warp_idx == 1:\n",
" for i in cutlass.range(cute.size(res)):\n",
" handle = consumer.wait_and_advance()\n",
" res[i] = staging_smem[handle.index]\n",
" handle.release() # or consumer.release(handle)\n",
"\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
" if tidx == 0:\n",
" staging.store(staging_smem.load())\n",
"\n",
"\n",
"@cute.jit\n",
"def async_pipeline_staged(res: cute.Tensor, staging: cute.Tensor):\n",
" stages = cute.size(staging)\n",
"\n",
" @cute.struct\n",
" class SharedStorage:\n",
" tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, stages * 2]\n",
" staging_buffer: cute.struct.Align[\n",
" cute.struct.MemRange[cutlass.Float32, stages], 1024\n",
" ]\n",
"\n",
" async_pipeline_staged_kernel(SharedStorage, res, staging).launch(\n",
" grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()\n",
" )\n",
"\n",
"\n",
"res = torch.zeros((8,), device=\"cuda\")\n",
"staging = torch.zeros((5,), device=\"cuda\")\n",
"async_pipeline_staged(from_dlpack(res), from_dlpack(staging))\n",
"torch.cuda.synchronize()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='cuda:0'),\n",
" tensor([5., 6., 7., 3., 4.], device='cuda:0'))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res, staging"
]
},
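{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `staging` result above follows directly from round-robin slot assignment: item `i` is written to slot `i % stages`, and later writes overwrite earlier ones. A plain-Python sketch (no GPU needed; the function name is illustrative) reproduces the final slot contents:\n",
"\n",
"```python\n",
"def final_slot_contents(num_items: int, stages: int) -> list[float]:\n",
"    # Item i lands in slot i % stages; later writes overwrite earlier ones.\n",
"    slots = [0.0] * stages\n",
"    for i in range(num_items):\n",
"        slots[i % stages] = float(i)\n",
"    return slots\n",
"\n",
"print(final_slot_contents(8, 5))  # [5.0, 6.0, 7.0, 3.0, 4.0]\n",
"```\n"
]
},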
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Try Acquire/Wait\n",
"\n",
"In some circumstances, developers may want to check the pipeline state without blocking, for example when independent instructions are available to hide the latency of the check. For such cases, the non-blocking APIs `try_acquire` and `try_wait` are provided."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}