{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using CuTe Layout Algebra With Python DSL\n",
    "\n",
    "Drawing on the [01_layout.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md) and [02_layout_algebra.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md) documentation from CuTe C++, we summarize:\n",
    "\n",
    "In CuTe, a `Layout`:\n",
    "- is defined by a pair of `Shape` and `Stride`,\n",
    "- maps coordinate space(s) to an index space,\n",
    "- supports both static (compile-time) and dynamic (runtime) values.\n",
    "\n",
    "CuTe also provides a powerful set of operations, the *Layout Algebra*, for combining and manipulating layouts, including:\n",
    "- Layout composition: Functional composition of layouts,\n",
    "- Layout \"divide\": Splitting a layout into two component layouts,\n",
    "- Layout \"product\": Reproducing a layout according to another layout.\n",
    "\n",
    "In this notebook, we will demonstrate:\n",
    "1. How to use CuTe’s key layout algebra operations with the Python DSL.\n",
    "2. How static and dynamic layouts behave when printed or manipulated within the Python DSL.\n",
    "\n",
    "We use examples from [02_layout_algebra.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md), which we recommend to the reader for additional details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import cutlass.cute as cute"
   ]
  },
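  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The summary at the top of this notebook describes a `Layout` as a map from coordinates to an index. As a rough illustration, the following cell is a minimal pure-Python sketch of that map for static integer shapes and strides. It does not call the CuTe DSL: `flatten` and `layout_index` are hypothetical helpers introduced only for this illustration, and the sketch assumes the colexicographic convention (leftmost mode varies fastest) that 01_layout.md describes for 1-D coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(x):\n",
    "    \"\"\"Flatten a (possibly nested) tuple of ints into a flat list.\"\"\"\n",
    "    if isinstance(x, tuple):\n",
    "        return [leaf for item in x for leaf in flatten(item)]\n",
    "    return [x]\n",
    "\n",
    "\n",
    "def layout_index(shape, stride, i):\n",
    "    \"\"\"Evaluate the layout shape:stride at 1-D coordinate i (hypothetical helper, not a CuTe API).\"\"\"\n",
    "    index = 0\n",
    "    for extent, step in zip(flatten(shape), flatten(stride)):\n",
    "        index += (i % extent) * step  # coordinate in this mode times its stride\n",
    "        i //= extent  # carry the remainder to the next (slower) mode\n",
    "    return index\n",
    "\n",
    "\n",
    "# The same (2,3) shape with column-major and row-major strides\n",
    "col_major = [layout_index((2, 3), (1, 2), i) for i in range(6)]\n",
    "row_major = [layout_index((2, 3), (3, 1), i) for i in range(6)]\n",
    "assert col_major == [0, 1, 2, 3, 4, 5]\n",
    "assert row_major == [0, 3, 1, 4, 2, 5]\n",
    "print(\"(2,3):(1,2) ->\", col_major)\n",
    "print(\"(2,3):(3,1) ->\", row_major)"
   ]
  },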
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layout Algebra Operations\n",
    "\n",
    "These operations form the foundation of CuTe's layout manipulation capabilities, enabling:\n",
    "- Efficient data tiling and partitioning,\n",
    "- Separation of thread and data layouts with a canonical type to represent both,\n",
    "- Native description and manipulation of hierarchical tensors of threads and data, which is crucial for tensor core programs,\n",
    "- Mixed static/dynamic layout transformations,\n",
    "- Seamless integration of layout algebra with tensor operations,\n",
    "- Expression of complex MMA and copy operations as canonical loops."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Coalesce\n",
    "\n",
    "The `coalesce` operation simplifies a layout by flattening and combining modes when possible, without changing its size or behavior as a function on the integers.\n",
    "\n",
    "It ensures the post-conditions:\n",
    "- Preserve size: `cute.size(layout) == cute.size(result)`,\n",
    "- Flattened: `depth(result) <= 1`,\n",
    "- Preserve function: for all `i` with `0 <= i < cute.size(layout)`, `layout(i) == result(i)`.\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Basic Coalesce Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Original: (2,(1,6)):(1,(?,2))\n",
      ">>> Coalesced: 12:1\n",
      ">?? Original: (2,(1,6)):(1,(6,2))\n",
      ">?? Coalesced: 12:1\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def coalesce_example():\n",
    "    \"\"\"\n",
    "    Demonstrates coalesce operation flattening and combining modes\n",
    "    \"\"\"\n",
    "    layout = cute.make_layout((2, (1, 6)), stride=(1, (cutlass.Int32(6), 2)))  # Dynamic stride\n",
    "    result = cute.coalesce(layout)\n",
    "\n",
    "    print(\">>> Original:\", layout)\n",
    "    cute.printf(\">?? Original: {}\", layout)\n",
    "    print(\">>> Coalesced:\", result)\n",
    "    cute.printf(\">?? Coalesced: {}\", result)\n",
    "\n",
    "coalesce_example()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Original: ((2,(3,4)),(3,2),1):((4,(8,24)),(2,6),12)\n",
      ">>> Coalesced: (24,6):(4,2)\n",
      ">>> Checking post-conditions:\n",
      ">>> 1. Checking size remains the same after the coalesce operation:\n",
      "Original size: 144, Coalesced size: 144\n",
      ">>> 2. Checking depth of coalesced layout <= 1:\n",
      "Depth of coalesced layout: 1\n",
      ">>> 3. Checking layout functionality remains the same after the coalesce operation:\n",
      "Index 0: original 0, coalesced 0\n",
      "Index 1: original 4, coalesced 4\n",
      "Index 2: original 8, coalesced 8\n",
      "Index 3: original 12, coalesced 12\n",
      "Index 4: original 16, coalesced 16\n",
      "Index 5: original 20, coalesced 20\n",
      "Index 6: original 24, coalesced 24\n",
      "Index 7: original 28, coalesced 28\n",
      "Index 8: original 32, coalesced 32\n",
      "Index 9: original 36, coalesced 36\n",
      "Index 10: original 40, coalesced 40\n",
      "Index 11: original 44, coalesced 44\n",
      "Index 12: original 48, coalesced 48\n",
      "Index 13: original 52, coalesced 52\n",
      "Index 14: original 56, coalesced 56\n",
      "Index 15: original 60, coalesced 60\n",
      "Index 16: original 64, coalesced 64\n",
      "Index 17: original 68, coalesced 68\n",
      "Index 18: original 72, coalesced 72\n",
      "Index 19: original 76, coalesced 76\n",
      "Index 20: original 80, coalesced 80\n",
      "Index 21: original 84, coalesced 84\n",
      "Index 22: original 88, coalesced 88\n",
      "Index 23: original 92, coalesced 92\n",
      "Index 24: original 2, coalesced 2\n",
      "Index 25: original 6, coalesced 6\n",
      "Index 26: original 10, coalesced 10\n",
      "Index 27: original 14, coalesced 14\n",
      "Index 28: original 18, coalesced 18\n",
      "Index 29: original 22, coalesced 22\n",
      "Index 30: original 26, coalesced 26\n",
      "Index 31: original 30, coalesced 30\n",
      "Index 32: original 34, coalesced 34\n",
      "Index 33: original 38, coalesced 38\n",
      "Index 34: original 42, coalesced 42\n",
      "Index 35: original 46, coalesced 46\n",
      "Index 36: original 50, coalesced 50\n",
      "Index 37: original 54, coalesced 54\n",
      "Index 38: original 58, coalesced 58\n",
      "Index 39: original 62, coalesced 62\n",
      "Index 40: original 66, coalesced 66\n",
      "Index 41: original 70, coalesced 70\n",
      "Index 42: original 74, coalesced 74\n",
      "Index 43: original 78, coalesced 78\n",
      "Index 44: original 82, coalesced 82\n",
      "Index 45: original 86, coalesced 86\n",
      "Index 46: original 90, coalesced 90\n",
      "Index 47: original 94, coalesced 94\n",
      "Index 48: original 4, coalesced 4\n",
      "Index 49: original 8, coalesced 8\n",
      "Index 50: original 12, coalesced 12\n",
      "Index 51: original 16, coalesced 16\n",
      "Index 52: original 20, coalesced 20\n",
      "Index 53: original 24, coalesced 24\n",
      "Index 54: original 28, coalesced 28\n",
      "Index 55: original 32, coalesced 32\n",
      "Index 56: original 36, coalesced 36\n",
      "Index 57: original 40, coalesced 40\n",
      "Index 58: original 44, coalesced 44\n",
      "Index 59: original 48, coalesced 48\n",
      "Index 60: original 52, coalesced 52\n",
      "Index 61: original 56, coalesced 56\n",
      "Index 62: original 60, coalesced 60\n",
      "Index 63: original 64, coalesced 64\n",
      "Index 64: original 68, coalesced 68\n",
      "Index 65: original 72, coalesced 72\n",
      "Index 66: original 76, coalesced 76\n",
      "Index 67: original 80, coalesced 80\n",
      "Index 68: original 84, coalesced 84\n",
      "Index 69: original 88, coalesced 88\n",
      "Index 70: original 92, coalesced 92\n",
      "Index 71: original 96, coalesced 96\n",
      "Index 72: original 6, coalesced 6\n",
      "Index 73: original 10, coalesced 10\n",
      "Index 74: original 14, coalesced 14\n",
      "Index 75: original 18, coalesced 18\n",
      "Index 76: original 22, coalesced 22\n",
      "Index 77: original 26, coalesced 26\n",
      "Index 78: original 30, coalesced 30\n",
      "Index 79: original 34, coalesced 34\n",
      "Index 80: original 38, coalesced 38\n",
      "Index 81: original 42, coalesced 42\n",
      "Index 82: original 46, coalesced 46\n",
      "Index 83: original 50, coalesced 50\n",
      "Index 84: original 54, coalesced 54\n",
      "Index 85: original 58, coalesced 58\n",
      "Index 86: original 62, coalesced 62\n",
      "Index 87: original 66, coalesced 66\n",
      "Index 88: original 70, coalesced 70\n",
      "Index 89: original 74, coalesced 74\n",
      "Index 90: original 78, coalesced 78\n",
      "Index 91: original 82, coalesced 82\n",
      "Index 92: original 86, coalesced 86\n",
      "Index 93: original 90, coalesced 90\n",
      "Index 94: original 94, coalesced 94\n",
      "Index 95: original 98, coalesced 98\n",
      "Index 96: original 8, coalesced 8\n",
      "Index 97: original 12, coalesced 12\n",
      "Index 98: original 16, coalesced 16\n",
      "Index 99: original 20, coalesced 20\n",
      "Index 100: original 24, coalesced 24\n",
      "Index 101: original 28, coalesced 28\n",
      "Index 102: original 32, coalesced 32\n",
      "Index 103: original 36, coalesced 36\n",
      "Index 104: original 40, coalesced 40\n",
      "Index 105: original 44, coalesced 44\n",
      "Index 106: original 48, coalesced 48\n",
      "Index 107: original 52, coalesced 52\n",
      "Index 108: original 56, coalesced 56\n",
      "Index 109: original 60, coalesced 60\n",
      "Index 110: original 64, coalesced 64\n",
      "Index 111: original 68, coalesced 68\n",
      "Index 112: original 72, coalesced 72\n",
      "Index 113: original 76, coalesced 76\n",
      "Index 114: original 80, coalesced 80\n",
      "Index 115: original 84, coalesced 84\n",
      "Index 116: original 88, coalesced 88\n",
      "Index 117: original 92, coalesced 92\n",
      "Index 118: original 96, coalesced 96\n",
      "Index 119: original 100, coalesced 100\n",
      "Index 120: original 10, coalesced 10\n",
      "Index 121: original 14, coalesced 14\n",
      "Index 122: original 18, coalesced 18\n",
      "Index 123: original 22, coalesced 22\n",
      "Index 124: original 26, coalesced 26\n",
      "Index 125: original 30, coalesced 30\n",
      "Index 126: original 34, coalesced 34\n",
      "Index 127: original 38, coalesced 38\n",
      "Index 128: original 42, coalesced 42\n",
      "Index 129: original 46, coalesced 46\n",
      "Index 130: original 50, coalesced 50\n",
      "Index 131: original 54, coalesced 54\n",
      "Index 132: original 58, coalesced 58\n",
      "Index 133: original 62, coalesced 62\n",
      "Index 134: original 66, coalesced 66\n",
      "Index 135: original 70, coalesced 70\n",
      "Index 136: original 74, coalesced 74\n",
      "Index 137: original 78, coalesced 78\n",
      "Index 138: original 82, coalesced 82\n",
      "Index 139: original 86, coalesced 86\n",
      "Index 140: original 90, coalesced 90\n",
      "Index 141: original 94, coalesced 94\n",
      "Index 142: original 98, coalesced 98\n",
      "Index 143: original 102, coalesced 102\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def coalesce_post_conditions():\n",
    "    \"\"\"\n",
    "    Demonstrates coalesce operation's 3 post-conditions:\n",
    "    1. size(result) == size(layout)\n",
    "    2. depth(result) <= 1\n",
    "    3. for all i, 0 <= i < size(layout), result(i) == layout(i)\n",
    "    \"\"\"\n",
    "    layout = cute.make_layout(\n",
    "        ((2, (3, 4)), (3, 2), 1),\n",
    "        stride=((4, (8, 24)), (2, 6), 12)\n",
    "    )\n",
    "    result = cute.coalesce(layout)\n",
    "\n",
    "    print(\">>> Original:\", layout)\n",
    "    print(\">>> Coalesced:\", result)\n",
    "\n",
    "    print(\">>> Checking post-conditions:\")\n",
    "    print(\">>> 1. Checking size remains the same after the coalesce operation:\")\n",
    "    original_size = cute.size(layout)\n",
    "    coalesced_size = cute.size(result)\n",
    "    print(f\"Original size: {original_size}, Coalesced size: {coalesced_size}\")\n",
    "    assert coalesced_size == original_size, \\\n",
    "        f\"Size mismatch: original {original_size}, coalesced {coalesced_size}\"\n",
    "\n",
    "    print(\">>> 2. Checking depth of coalesced layout <= 1:\")\n",
    "    depth = cute.depth(result)\n",
    "    print(f\"Depth of coalesced layout: {depth}\")\n",
    "    assert depth <= 1, f\"Depth of coalesced layout should be <= 1, got {depth}\"\n",
    "\n",
    "    print(\">>> 3. Checking layout functionality remains the same after the coalesce operation:\")\n",
    "    for i in cutlass.range_constexpr(original_size):\n",
    "        original_value = layout(i)\n",
    "        coalesced_value = result(i)\n",
    "        print(f\"Index {i}: original {original_value}, coalesced {coalesced_value}\")\n",
    "        assert coalesced_value == original_value, \\\n",
    "            f\"Value mismatch at index {i}: original {original_value}, coalesced {coalesced_value}\"\n",
    "\n",
    "coalesce_post_conditions()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- By-mode Coalesce Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Original: (2,(1,6)):(1,(6,2))\n",
      ">>> Coalesced Result: (2,6):(1,2)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def bymode_coalesce_example():\n",
    "    \"\"\"\n",
    "    Demonstrates by-mode coalescing\n",
    "    \"\"\"\n",
    "    layout = cute.make_layout((2, (1, 6)), stride=(1, (6, 2)))\n",
    "\n",
    "    # Coalesce with mode-wise profile (1,1): each of the two top-level modes is coalesced separately\n",
    "    result = cute.coalesce(layout, target_profile=(1, 1))\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Original: \", layout)\n",
    "    print(\">>> Coalesced Result: \", result)\n",
    "\n",
    "bymode_coalesce_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Composition\n",
    "\n",
    "`Composition` of Layout `A` with Layout `B` creates a new layout `R = A ◦ B` where:\n",
    "- The shape of `B` is compatible with the shape of `R` so that all coordinates of `B` can also be used as coordinates of `R`,\n",
    "- `R(c) = A(B(c))` for all coordinates `c` in `B`'s domain.\n",
    "\n",
    "Layout composition is very useful for reshaping and reordering layouts.\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Basic Composition Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout A: (6,2):(?,2)\n",
      ">>> Layout B: (4,3):(3,1)\n",
      ">>> Composition R = A ◦ B: ((2,2),3):((?{div=3},2),?)\n",
      ">?? Layout A: (6,2):(8,2)\n",
      ">?? Layout B: (4,3):(3,1)\n",
      ">?? Composition R: ((2,2),3):((24,2),8)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def composition_example():\n",
    "    \"\"\"\n",
    "    Demonstrates basic layout composition R = A ◦ B\n",
    "    \"\"\"\n",
    "    A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2))  # Dynamic stride\n",
    "    B = cute.make_layout((4, 3), stride=(3, 1))\n",
    "    R = cute.composition(A, B)\n",
    "\n",
    "    # Print static and dynamic information\n",
    "    print(\">>> Layout A:\", A)\n",
    "    cute.printf(\">?? Layout A: {}\", A)\n",
    "    print(\">>> Layout B:\", B)\n",
    "    cute.printf(\">?? Layout B: {}\", B)\n",
    "    print(\">>> Composition R = A ◦ B:\", R)\n",
    "    cute.printf(\">?? Composition R: {}\", R)\n",
    "\n",
    "composition_example()"
   ]
  },
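  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To cross-check the defining property `R(c) = A(B(c))`, the following pure-Python sketch re-evaluates the three layouts from the basic composition example, using the runtime values shown in its `>??` output lines. `flatten` and `layout_index` are hypothetical helpers (not CuTe APIs); they assume static integers and colexicographic 1-D coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(x):\n",
    "    \"\"\"Flatten a (possibly nested) tuple of ints into a flat list.\"\"\"\n",
    "    if isinstance(x, tuple):\n",
    "        return [leaf for item in x for leaf in flatten(item)]\n",
    "    return [x]\n",
    "\n",
    "\n",
    "def layout_index(shape, stride, i):\n",
    "    \"\"\"Evaluate the layout shape:stride at 1-D coordinate i (hypothetical helper).\"\"\"\n",
    "    index = 0\n",
    "    for extent, step in zip(flatten(shape), flatten(stride)):\n",
    "        index += (i % extent) * step\n",
    "        i //= extent\n",
    "    return index\n",
    "\n",
    "\n",
    "# (shape, stride) pairs copied from the runtime output of composition_example\n",
    "A = ((6, 2), (8, 2))\n",
    "B = ((4, 3), (3, 1))\n",
    "R = (((2, 2), 3), ((24, 2), 8))\n",
    "\n",
    "for c in range(12):  # size(B) == 4 * 3\n",
    "    assert layout_index(*R, c) == layout_index(*A, layout_index(*B, c))\n",
    "print(\"R(c) == A(B(c)) holds for all 12 coordinates of B\")"
   ]
  },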
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Comparing Composition with static and dynamic layouts:\n",
    "\n",
    "In this case, the results may look different but are mathematically the same. The 1s in the shape don't affect the layout as a mathematical function on the integers. In the dynamic case, CuTe cannot coalesce the dynamic size-1 modes to \"simplify\" the layout because doing so is not valid for every value those parameters could take at runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Static composition:\n",
      ">>> A_static: (10,2):(16,4)\n",
      ">>> B_static: (5,4):(1,5)\n",
      ">>> R_static: (5,(2,2)):(16,(80,4))\n",
      ">?? Dynamic composition:\n",
      ">?? A_dynamic: (10,2):(16,4)\n",
      ">?? B_dynamic: (5,4):(1,5)\n",
      ">?? R_dynamic: ((5,1),(2,2)):((16,4),(80,4))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def composition_static_vs_dynamic_layout():\n",
    "    \"\"\"\n",
    "    Shows difference between static and dynamic composition results\n",
    "    \"\"\"\n",
    "    # Static version - using compile-time values\n",
    "    A_static = cute.make_layout(\n",
    "        (10, 2),\n",
    "        stride=(16, 4)\n",
    "    )\n",
    "    B_static = cute.make_layout(\n",
    "        (5, 4),\n",
    "        stride=(1, 5)\n",
    "    )\n",
    "    R_static = cute.composition(A_static, B_static)\n",
    "\n",
    "    # Static print shows compile-time info\n",
    "    print(\">>> Static composition:\")\n",
    "    print(\">>> A_static: \", A_static)\n",
    "    print(\">>> B_static: \", B_static)\n",
    "    print(\">>> R_static: \", R_static)\n",
    "\n",
    "    # Dynamic version - using runtime Int32 values\n",
    "    A_dynamic = cute.make_layout(\n",
    "        (cutlass.Int32(10), cutlass.Int32(2)),\n",
    "        stride=(cutlass.Int32(16), cutlass.Int32(4))\n",
    "    )\n",
    "    B_dynamic = cute.make_layout(\n",
    "        (cutlass.Int32(5), cutlass.Int32(4)),\n",
    "        stride=(cutlass.Int32(1), cutlass.Int32(5))\n",
    "    )\n",
    "    R_dynamic = cute.composition(A_dynamic, B_dynamic)\n",
    "\n",
    "    # Dynamic printf shows runtime values\n",
    "    cute.printf(\">?? Dynamic composition:\")\n",
    "    cute.printf(\">?? A_dynamic: {}\", A_dynamic)\n",
    "    cute.printf(\">?? B_dynamic: {}\", B_dynamic)\n",
    "    cute.printf(\">?? R_dynamic: {}\", R_dynamic)\n",
    "\n",
    "composition_static_vs_dynamic_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- By-mode Composition Example:\n",
    "\n",
    "By-mode composition allows us to apply composition operations to individual modes of a layout. This is particularly useful when you want to manipulate specific modes of a layout independently (e.g., rows and columns).\n",
    "\n",
    "In the context of CuTe, by-mode composition is achieved by using a `Tiler`, which can be a layout or a tuple of layouts. The leaves of the `Tiler` tuple specify how the corresponding mode of the target layout should be composed, allowing for sublayouts to be treated independently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout A: (?,(?,?)):(?,(?,?))\n",
      ">>> Tiler: (3, 8)\n",
      ">>> By-mode Composition Result: (3,(?,?)):(?,(?,?))\n",
      ">?? Layout A: (12,(4,8)):(59,(13,1))\n",
      ">?? Tiler: (3,8)\n",
      ">?? By-mode Composition Result: (3,(4,2)):(59,(13,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def bymode_composition_example():\n",
    "    \"\"\"\n",
    "    Demonstrates by-mode composition using a tiler\n",
    "    \"\"\"\n",
    "    # Define the original layout A\n",
    "    A = cute.make_layout(\n",
    "        (cutlass.Int32(12), (cutlass.Int32(4), cutlass.Int32(8))),\n",
    "        stride=(cutlass.Int32(59), (cutlass.Int32(13), cutlass.Int32(1)))\n",
    "    )\n",
    "\n",
    "    # Define the tiler for by-mode composition\n",
    "    tiler = (3, 8)  # Apply 3:1 to mode-0 and 8:1 to mode-1\n",
    "\n",
    "    # Apply by-mode composition\n",
    "    result = cute.composition(A, tiler)\n",
    "\n",
    "    # Print static and dynamic information\n",
    "    print(\">>> Layout A:\", A)\n",
    "    cute.printf(\">?? Layout A: {}\", A)\n",
    "    print(\">>> Tiler:\", tiler)\n",
    "    cute.printf(\">?? Tiler: {}\", tiler)\n",
    "    print(\">>> By-mode Composition Result:\", result)\n",
    "    cute.printf(\">?? By-mode Composition Result: {}\", result)\n",
    "\n",
    "bymode_composition_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Division (Splitting into Tiles)\n",
    "\n",
    "The Division operation in CuTe is used to split a layout into tiles, which is particularly useful for partitioning data across threads or memory hierarchies.\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Logical Divide:\n",
    "\n",
    "When applied to two Layouts, `logical_divide` splits a layout into two modes: the first mode contains the elements pointed to by the tiler, and the second mode contains the remaining elements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (4,2,3):(2,1,8)\n",
      ">>> Tiler : 4:2\n",
      ">>> Logical Divide Result: ((2,2),(2,3)):((4,1),(2,8))\n",
      ">?? Logical Divide Result: ((2,2),(2,3)):((4,1),(2,8))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def logical_divide_1d_example():\n",
    "    \"\"\"\n",
    "    Demonstrates 1D logical divide\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((4, 2, 3), stride=(2, 1, 8))  # (4,2,3):(2,1,8)\n",
    "\n",
    "    # Define the tiler\n",
    "    tiler = cute.make_layout(4, stride=2)  # Apply to layout 4:2\n",
    "\n",
    "    # Apply logical divide\n",
    "    result = cute.logical_divide(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Logical Divide Result:\", result)\n",
    "    cute.printf(\">?? Logical Divide Result: {}\", result)\n",
    "\n",
    "logical_divide_1d_example()"
   ]
  },
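  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the statement that the first mode of a logical divide holds the elements pointed to by the tiler, the following pure-Python sketch re-evaluates the layouts printed by the 1D logical divide example. `flatten` and `layout_index` are hypothetical helpers (not CuTe APIs) that assume static integers and colexicographic 1-D coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(x):\n",
    "    \"\"\"Flatten a (possibly nested) tuple of ints into a flat list.\"\"\"\n",
    "    if isinstance(x, tuple):\n",
    "        return [leaf for item in x for leaf in flatten(item)]\n",
    "    return [x]\n",
    "\n",
    "\n",
    "def layout_index(shape, stride, i):\n",
    "    \"\"\"Evaluate the layout shape:stride at 1-D coordinate i (hypothetical helper).\"\"\"\n",
    "    index = 0\n",
    "    for extent, step in zip(flatten(shape), flatten(stride)):\n",
    "        index += (i % extent) * step\n",
    "        i //= extent\n",
    "    return index\n",
    "\n",
    "\n",
    "# (shape, stride) pairs copied from the output of logical_divide_1d_example\n",
    "layout = ((4, 2, 3), (2, 1, 8))\n",
    "tiler = (4, 2)\n",
    "result = (((2, 2), (2, 3)), ((4, 1), (2, 8)))\n",
    "\n",
    "# Elements of `layout` selected by the tiler: layout(tiler(i)) for i in 0..3\n",
    "tile = [layout_index(*layout, layout_index(*tiler, i)) for i in range(4)]\n",
    "# The first mode of the result (size 4) walks over exactly those elements\n",
    "first_mode = [layout_index(*result, i) for i in range(4)]\n",
    "assert first_mode == tile == [0, 4, 1, 5]\n",
    "\n",
    "# Both modes together still reach the same 24 indices as the original layout\n",
    "all_result = sorted(layout_index(*result, i) for i in range(24))\n",
    "all_layout = sorted(layout_index(*layout, i) for i in range(24))\n",
    "assert all_result == all_layout\n",
    "print(\"tile elements:\", tile)"
   ]
  },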
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When applied to a Layout and a `Tiler` tuple, `logical_divide` applies itself to the leaves of the `Tiler` and the corresponding mode of the target Layout. This means that the sublayouts are split independently according to the layouts within the `Tiler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc95a4ca7b0>, <cutlass.cute.core._Layout object at 0x7fc958160f50>)\n",
      ">>> Logical Divide Result: ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))\n",
      ">?? Logical Divide Result: ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def logical_divide_2d_example():\n",
    "    \"\"\"\n",
    "    Demonstrates 2D logical divide:\n",
    "        Layout Shape : (M, N, L, ...)\n",
    "        Tiler Shape  : <TileM, TileN>\n",
    "        Result Shape : ((TileM,RestM), (TileN,RestN), L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "\n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),             # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))   # Apply to mode-1 layout (2,4):(1,8)\n",
    "\n",
    "    # Apply logical divide\n",
    "    result = cute.logical_divide(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Logical Divide Result:\", result)\n",
    "    cute.printf(\">?? Logical Divide Result: {}\", result)\n",
    "\n",
    "logical_divide_2d_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Zipped, tiled, and flat divide are flavors of `logical_divide` that potentially rearrange modes into more convenient forms.\n",
    "\n",
    "- Zipped Divide:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc95a4ca7b0>, <cutlass.cute.core._Layout object at 0x7fc9581611f0>)\n",
      ">>> Zipped Divide Result: ((3,(2,4)),(3,(2,2))):((177,(13,2)),(59,(26,1)))\n",
      ">?? Zipped Divide Result: ((3,(2,4)),(3,(2,2))):((177,(13,2)),(59,(26,1)))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def zipped_divide_example():\n",
    "    \"\"\"\n",
    "    Demonstrates zipped divide:\n",
    "        Layout Shape : (M, N, L, ...)\n",
    "        Tiler Shape  : <TileM, TileN>\n",
    "        Result Shape : ((TileM,TileN), (RestM,RestN,L,...))\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "\n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),             # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))   # Apply to mode-1 layout (2,4):(1,8)\n",
    "\n",
    "    # Apply zipped divide\n",
    "    result = cute.zipped_divide(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Zipped Divide Result:\", result)\n",
    "    cute.printf(\">?? Zipped Divide Result: {}\", result)\n",
    "\n",
    "zipped_divide_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Tiled Divide:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc9581610d0>, <cutlass.cute.core._Layout object at 0x7fc958161070>)\n",
      ">>> Tiled Divide Result: ((3,(2,4)),3,(2,2)):((177,(13,2)),59,(26,1))\n",
      ">?? Tiled Divide Result: ((3,(2,4)),3,(2,2)):((177,(13,2)),59,(26,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def tiled_divide_example():\n",
    "    \"\"\"\n",
    "    Demonstrates tiled divide:\n",
    "        Layout Shape : (M, N, L, ...)\n",
    "        Tiler Shape  : <TileM, TileN>\n",
    "        Result Shape : ((TileM,TileN), RestM, RestN, L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "\n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),             # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))   # Apply to mode-1 layout (2,4):(1,8)\n",
    "\n",
    "    # Apply tiled divide\n",
    "    result = cute.tiled_divide(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Tiled Divide Result:\", result)\n",
    "    cute.printf(\">?? Tiled Divide Result: {}\", result)\n",
    "\n",
    "tiled_divide_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Flat Divide:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (9,(4,8)):(59,(13,1))\n",
      ">>> Tiler : (<cutlass.cute.core._Layout object at 0x7fc958161430>, <cutlass.cute.core._Layout object at 0x7fc9581610d0>)\n",
      ">>> Flat Divide Result: (3,(2,4),3,(2,2)):(177,(13,2),59,(26,1))\n",
      ">?? Flat Divide Result: (3,(2,4),3,(2,2)):(177,(13,2),59,(26,1))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def flat_divide_example():\n",
    "    \"\"\"\n",
    "    Demonstrates flat divide:\n",
    "        Layout Shape : (M, N, L, ...)\n",
    "        Tiler Shape  : <TileM, TileN>\n",
    "        Result Shape : (TileM, TileN, RestM, RestN, L, ...)\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))\n",
    "\n",
    "    # Define the tiler\n",
    "    tiler = (cute.make_layout(3, stride=3),             # Apply to mode-0 layout 3:3\n",
    "             cute.make_layout((2, 4), stride=(1, 8)))   # Apply to mode-1 layout (2,4):(1,8)\n",
    "\n",
    "    # Apply flat divide\n",
    "    result = cute.flat_divide(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Flat Divide Result:\", result)\n",
    "    cute.printf(\">?? Flat Divide Result: {}\", result)\n",
    "\n",
    "flat_divide_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Product (Reproducing a Tile)\n",
    "\n",
    "The Product operation in CuTe is used to reproduce one layout according to another layout. It creates a new layout where:\n",
    "- The first mode is the original layout A.\n",
    "- The second mode is a restrided layout B that points to the origin of a \"unique replication\" of A.\n",
    "\n",
    "This is particularly useful for repeating a layout of threads across a tile of data to create \"repeat\" patterns.\n",
    "\n",
    "#### Examples\n",
    "\n",
    "- Logical Product:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> Layout: (2,2):(4,1)\n",
      ">>> Tiler : 6:1\n",
      ">>> Logical Product Result: ((2,2),(2,3)):((4,1),(2,8))\n",
      ">?? Logical Product Result: ((2,2),(2,3)):((4,1),(2,8))\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def logical_product_1d_example():\n",
    "    \"\"\"\n",
    "    Demonstrates 1D logical product\n",
    "    \"\"\"\n",
    "    # Define the original layout\n",
    "    layout = cute.make_layout((2, 2), stride=(4, 1))  # (2,2):(4,1)\n",
    "\n",
    "    # Define the tiler\n",
    "    tiler = cute.make_layout(6, stride=1)  # Apply to layout 6:1\n",
    "\n",
    "    # Apply logical product\n",
    "    result = cute.logical_product(layout, tiler=tiler)\n",
    "\n",
    "    # Print results\n",
    "    print(\">>> Layout:\", layout)\n",
    "    print(\">>> Tiler :\", tiler)\n",
    "    print(\">>> Logical Product Result:\", result)\n",
    "    cute.printf(\">?? Logical Product Result: {}\", result)\n",
    "\n",
    "logical_product_1d_example()"
   ]
  },
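  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what \"unique replication\" means in the logical product example, the following pure-Python sketch re-evaluates the printed result: the first mode reproduces layout A, and the second mode yields one origin per replica, with no two replicas overlapping. `flatten` and `layout_index` are hypothetical helpers (not CuTe APIs) that assume static integers and colexicographic 1-D coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(x):\n",
    "    \"\"\"Flatten a (possibly nested) tuple of ints into a flat list.\"\"\"\n",
    "    if isinstance(x, tuple):\n",
    "        return [leaf for item in x for leaf in flatten(item)]\n",
    "    return [x]\n",
    "\n",
    "\n",
    "def layout_index(shape, stride, i):\n",
    "    \"\"\"Evaluate the layout shape:stride at 1-D coordinate i (hypothetical helper).\"\"\"\n",
    "    index = 0\n",
    "    for extent, step in zip(flatten(shape), flatten(stride)):\n",
    "        index += (i % extent) * step\n",
    "        i //= extent\n",
    "    return index\n",
    "\n",
    "\n",
    "# (shape, stride) pairs copied from the output of logical_product_1d_example\n",
    "A = ((2, 2), (4, 1))\n",
    "result = (((2, 2), (2, 3)), ((4, 1), (2, 8)))\n",
    "\n",
    "# First mode: the original layout A\n",
    "tile = [layout_index(*A, i) for i in range(4)]\n",
    "assert [layout_index(*result, i) for i in range(4)] == tile\n",
    "\n",
    "# Second mode: the origin of each of the 6 replicas (step by size(A) == 4)\n",
    "origins = [layout_index(*result, 4 * k) for k in range(6)]\n",
    "assert origins == [0, 2, 8, 10, 16, 18]\n",
    "\n",
    "# The replicas are disjoint and together cover indices 0..23\n",
    "covered = sorted(origin + offset for origin in origins for offset in tile)\n",
    "assert covered == list(range(24))\n",
    "print(\"replica origins:\", origins)"
   ]
  },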
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Blocked and Raked Product:\n",
    "\n",
    "  - Blocked Product: Combines the modes of A and B in a block-like fashion, preserving the semantic meaning of the modes by reassociating them after the product.\n",
    "  - Raked Product: Combines the modes of A and B in an interleaved or \"raked\" fashion, creating a cyclic distribution of the tiles."
   ]
  },
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
">>> Layout: (2,5):(5,1)\n",
|
||
">>> Tiler : (3,4):(1,3)\n",
|
||
">>> Blocked Product Result: ((2,3),(5,4)):((5,10),(1,30))\n",
|
||
">>> Raked Product Result: ((3,2),(4,5)):((10,5),(30,1))\n",
|
||
">?? Blocked Product Result: ((2,3),(5,4)):((5,10),(1,30))\n",
|
||
">?? Raked Product Result: ((3,2),(4,5)):((10,5),(30,1))\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"@cute.jit\n",
"def blocked_raked_product_example():\n",
"    \"\"\"\n",
"    Demonstrates blocked and raked products\n",
"    \"\"\"\n",
"    # Define the original layout\n",
"    layout = cute.make_layout((2, 5), stride=(5, 1))\n",
"\n",
"    # Define the tiler\n",
"    tiler = cute.make_layout((3, 4), stride=(1, 3))\n",
"\n",
"    # Apply blocked product\n",
"    blocked_result = cute.blocked_product(layout, tiler=tiler)\n",
"\n",
"    # Apply raked product\n",
"    raked_result = cute.raked_product(layout, tiler=tiler)\n",
"\n",
"    # Print results\n",
"    print(\">>> Layout:\", layout)\n",
"    print(\">>> Tiler :\", tiler)\n",
"    print(\">>> Blocked Product Result:\", blocked_result)\n",
"    print(\">>> Raked Product Result:\", raked_result)\n",
"    cute.printf(\">?? Blocked Product Result: {}\", blocked_result)\n",
"    cute.printf(\">?? Raked Product Result: {}\", raked_result)\n",
"\n",
"blocked_raked_product_example()"
]
},
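{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between the two results printed above is easiest to see along a single mode. The following is a minimal pure-Python sketch, not part of the CuTe DSL; `mode_offsets` is a hypothetical helper introduced only for this illustration. It walks the row mode of each result and reports which copy of the layout each row falls in.\n",
"\n",
"```python\n",
"def mode_offsets(shape, stride):\n",
"    '''Offsets of a two-entry mode, walking its first entry fastest.'''\n",
"    offsets = []\n",
"    for idx in range(shape[0] * shape[1]):\n",
"        inner, outer = idx % shape[0], idx // shape[0]\n",
"        offsets.append(inner * stride[0] + outer * stride[1])\n",
"    return offsets\n",
"\n",
"# Row modes of the blocked and raked results printed by the cell above\n",
"blocked_rows = mode_offsets((2, 3), (5, 10))\n",
"raked_rows = mode_offsets((3, 2), (10, 5))\n",
"\n",
"# Along this mode each copy of the layout starts 10 offsets after the previous one,\n",
"# so offset // 10 identifies the copy a row belongs to.\n",
"blocked_copies = [o // 10 for o in blocked_rows]\n",
"raked_copies = [o // 10 for o in raked_rows]\n",
"\n",
"print(blocked_rows, blocked_copies)  # [0, 5, 10, 15, 20, 25] [0, 0, 1, 1, 2, 2]\n",
"print(raked_rows, raked_copies)      # [0, 10, 20, 5, 15, 25] [0, 1, 2, 0, 1, 2]\n",
"```\n",
"\n",
"In the blocked result consecutive rows finish one copy before moving to the next, while in the raked result consecutive rows cycle through the copies."
]
},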
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Zipped, Tiled, and Flat Product:\n",
"\n",
"  - Similar to the divide operations, the zipped, tiled, and flat products are flavors of `logical_product` that potentially rearrange modes into more convenient forms."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
">>> Layout: (2,5):(5,1)\n",
">>> Tiler : (3,4):(1,3)\n",
">>> Zipped Product Result: ((2,5),(3,4)):((5,1),(10,30))\n",
">>> Tiled Product Result: ((2,5),3,4):((5,1),10,30)\n",
">>> Flat Product Result: (2,5,3,4):(5,1,10,30)\n",
">?? Zipped Product Result: ((2,5),(3,4)):((5,1),(10,30))\n",
">?? Tiled Product Result: ((2,5),3,4):((5,1),10,30)\n",
">?? Flat Product Result: (2,5,3,4):(5,1,10,30)\n"
]
}
],
"source": [
"@cute.jit\n",
"def zipped_tiled_flat_product_example():\n",
"    \"\"\"\n",
"    Demonstrates zipped, tiled, and flat products\n",
"    Layout Shape : (M, N, L, ...)\n",
"    Tiler Shape  : <TileM, TileN>\n",
"\n",
"    zipped_product : ((M,N), (TileM,TileN,L,...))\n",
"    tiled_product  : ((M,N), TileM, TileN, L, ...)\n",
"    flat_product   : (M, N, TileM, TileN, L, ...)\n",
"    \"\"\"\n",
"    # Define the original layout\n",
"    layout = cute.make_layout((2, 5), stride=(5, 1))\n",
"\n",
"    # Define the tiler\n",
"    tiler = cute.make_layout((3, 4), stride=(1, 3))\n",
"\n",
"    # Apply zipped product\n",
"    zipped_result = cute.zipped_product(layout, tiler=tiler)\n",
"\n",
"    # Apply tiled product\n",
"    tiled_result = cute.tiled_product(layout, tiler=tiler)\n",
"\n",
"    # Apply flat product\n",
"    flat_result = cute.flat_product(layout, tiler=tiler)\n",
"\n",
"    # Print results\n",
"    print(\">>> Layout:\", layout)\n",
"    print(\">>> Tiler :\", tiler)\n",
"    print(\">>> Zipped Product Result:\", zipped_result)\n",
"    print(\">>> Tiled Product Result:\", tiled_result)\n",
"    print(\">>> Flat Product Result:\", flat_result)\n",
"    cute.printf(\">?? Zipped Product Result: {}\", zipped_result)\n",
"    cute.printf(\">?? Tiled Product Result: {}\", tiled_result)\n",
"    cute.printf(\">?? Flat Product Result: {}\", flat_result)\n",
"\n",
"zipped_tiled_flat_product_example()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pythondsl_venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}