{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import cutlass\n",
"import cutlass.cute as cute\n",
"from cutlass.cute.runtime import from_dlpack\n",
"\n",
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction to the TensorSSA in CuTe DSL\n",
"\n",
"This tutorial introduces what is the `TensorSSA` and why we need it. We also give some examples to show how to use `TensorSSA`.\n",
"\n",
"## What is TensorSSA\n",
"\n",
"`TensorSSA` is a Python class that represents a tensor value in Static Single Assignment (SSA) form within the CuTe DSL. You can think of it as a tensor residing in a (simulated) register.\n",
"\n",
"## Why TensorSSA\n",
"\n",
"`TensorSSA` encapsulates the underlying MLIR tensor value into an object that's easier to manipulate in Python. By overloading numerous Python operators (like `+`, `-`, `*`, `/`, `[]`, etc.), it allows users to express tensor computations (primarily element-wise operations and reductions) in a more Pythonic way. These element-wise operations are then translated into optimized vectorization instructions.\n",
"\n",
"It's part of the CuTe DSL, serving as a bridge between the user-described computational logic and the lower-level MLIR IR, particularly for representing and manipulating register-level data.\n",
"\n",
"## When to use TensorSSA\n",
"\n",
"`TensorSSA` is primarily used in the following scenarios:\n",
"\n",
"### Load from memory and store to memory"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
"b_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
"tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n",
" [[ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
" [ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
" [ 2.000000, 2.000000, 2.000000, 2.000000, ]])\n"
]
}
],
"source": [
"@cute.jit\n",
"def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
" \"\"\"\n",
" Load data from memory and store the result to memory.\n",
"\n",
" :param res: The destination tensor to store the result.\n",
" :param a: The source tensor to be loaded.\n",
" :param b: The source tensor to be loaded.\n",
" \"\"\"\n",
" a_vec = a.load()\n",
" print(f\"a_vec: {a_vec}\") # prints `a_vec: vector<12xf32> o (3, 4)`\n",
" b_vec = b.load()\n",
" print(f\"b_vec: {b_vec}\") # prints `b_vec: vector<12xf32> o (3, 4)`\n",
" res.store(a_vec + b_vec)\n",
" cute.print_tensor(res)\n",
"\n",
"a = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
"b = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
"c = np.zeros(12).reshape((3, 4)).astype(np.float32)\n",
"load_and_store(from_dlpack(c), from_dlpack(a), from_dlpack(b))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Register-Level Tensor Operations\n",
"\n",
"When writing kernel logic, various computations, transformations, slicing, etc., are performed on data loaded into registers."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>\n",
"tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n",
" [[ 3.000000, 4.000000, 5.000000, ],\n",
" [ 9.000000, 10.000000, 11.000000, ],\n",
" [ 15.000000, 16.000000, 17.000000, ],\n",
" [ 21.000000, 22.000000, 23.000000, ]])\n"
]
}
],
"source": [
"@cute.jit\n",
"def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n",
" \"\"\"\n",
" Apply slice operation on the src tensor and store the result to the dst tensor.\n",
"\n",
" :param src: The source tensor to be sliced.\n",
" :param dst: The destination tensor to store the result.\n",
" :param indices: The indices to slice the source tensor.\n",
" \"\"\"\n",
" src_vec = src.load()\n",
" dst_vec = src_vec[indices]\n",
" print(f\"{src_vec} -> {dst_vec}\")\n",
" if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):\n",
" dst.store(dst_vec)\n",
" cute.print_tensor(dst)\n",
" else:\n",
" dst[0] = dst_vec\n",
" cute.print_tensor(dst)\n",
"\n",
"def slice_1():\n",
" src_shape = (4, 2, 3)\n",
" dst_shape = (4, 3)\n",
" indices = (None, 1, None)\n",
"\n",
" \"\"\"\n",
" a:\n",
" [[[ 0. 1. 2.]\n",
" [ 3. 4. 5.]]\n",
"\n",
" [[ 6. 7. 8.]\n",
" [ 9. 10. 11.]]\n",
"\n",
" [[12. 13. 14.]\n",
" [15. 16. 17.]]\n",
"\n",
" [[18. 19. 20.]\n",
" [21. 22. 23.]]]\n",
" \"\"\"\n",
" a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
" dst = np.random.randn(*dst_shape).astype(np.float32)\n",
" apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
"\n",
"slice_1()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor_value<vector<24xf32> o (4, 2, 3)> -> ?\n",
"tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n",
" [ 10.000000, ])\n"
]
}
],
"source": [
"def slice_2():\n",
" src_shape = (4, 2, 3)\n",
" dst_shape = (1,)\n",
" indices = 10\n",
" a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
" dst = np.random.randn(*dst_shape).astype(np.float32)\n",
" apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
"\n",
"slice_2()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Arithmetic Operations\n",
"\n",
"As we mentioned earlier, there're many tensor operations whose operands are `TensorSSA`. And they are all element-wise operations. We give some examples below.\n",
"\n",
"### Binary Operations\n",
"\n",
"For binary operations, the LHS operand is `TensorSSA` and the RHS operand can be either `TensorSSA` or `Numeric`. When the RHS is `Numeric`, it will be broadcast to a `TensorSSA`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 3.000000, ],\n",
" [ 3.000000, ],\n",
" [ 3.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [-1.000000, ],\n",
" [-1.000000, ],\n",
" [-1.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 2.000000, ],\n",
" [ 2.000000, ],\n",
" [ 2.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.500000, ],\n",
" [ 0.500000, ],\n",
" [ 0.500000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.000000, ],\n",
" [ 0.000000, ],\n",
" [ 0.000000, ])\n",
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
" [ 1.000000, ],\n",
" [ 1.000000, ],\n",
" [ 1.000000, ])\n"
]
}
],
"source": [
"@cute.jit\n",
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
" a_vec = a.load()\n",
" b_vec = b.load()\n",
"\n",
" add_res = a_vec + b_vec\n",
" res.store(add_res)\n",
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
"\n",
" sub_res = a_vec - b_vec\n",
" res.store(sub_res)\n",
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
"\n",
" mul_res = a_vec * b_vec\n",
" res.store(mul_res)\n",
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
"\n",
" div_res = a_vec / b_vec\n",
" res.store(div_res)\n",
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
"\n",
" floor_div_res = a_vec // b_vec\n",
" res.store(floor_div_res)\n",
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
"\n",
" mod_res = a_vec % b_vec\n",
" res.store(mod_res)\n",
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
"\n",
"\n",
"a = np.empty((3,), dtype=np.float32)\n",
"a.fill(1.0)\n",
"b = np.empty((3,), dtype=np.float32)\n",
"b.fill(2.0)\n",
"res = np.empty((3,), dtype=np.float32)\n",
"binary_op_1(from_dlpack(res), from_dlpack(a), from_dlpack(b))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 3.000000, ],\n",
" [ 3.000000, ],\n",
" [ 3.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [-1.000000, ],\n",
" [-1.000000, ],\n",
" [-1.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 2.000000, ],\n",
" [ 2.000000, ],\n",
" [ 2.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.500000, ],\n",
" [ 0.500000, ],\n",
" [ 0.500000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 0.000000, ],\n",
" [ 0.000000, ],\n",
" [ 0.000000, ])\n",
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
" [ 1.000000, ],\n",
" [ 1.000000, ],\n",
" [ 1.000000, ])\n"
]
}
],
"source": [
"@cute.jit\n",
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
" a_vec = a.load()\n",
"\n",
" add_res = a_vec + c\n",
" res.store(add_res)\n",
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
"\n",
" sub_res = a_vec - c\n",
" res.store(sub_res)\n",
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
"\n",
" mul_res = a_vec * c\n",
" res.store(mul_res)\n",
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
"\n",
" div_res = a_vec / c\n",
" res.store(div_res)\n",
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
"\n",
" floor_div_res = a_vec // c\n",
" res.store(floor_div_res)\n",
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
"\n",
" mod_res = a_vec % c\n",
" res.store(mod_res)\n",
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
"\n",
"a = np.empty((3,), dtype=np.float32)\n",
"a.fill(1.0)\n",
"c = 2.0\n",
"res = np.empty((3,), dtype=np.float32)\n",
"binary_op_2(from_dlpack(res), from_dlpack(a), c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[False True False]\n"
]
}
],
"source": [
"@cute.jit\n",
"def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
" a_vec = a.load()\n",
" b_vec = b.load()\n",
"\n",
" gt_res = a_vec > b_vec\n",
" res.store(gt_res)\n",
"\n",
" \"\"\"\n",
" ge_res = a_ >= b_ # [False, True, False]\n",
" lt_res = a_ < b_ # [True, False, True]\n",
" le_res = a_ <= b_ # [True, False, True]\n",
" eq_res = a_ == b_ # [False, False, False]\n",
" \"\"\"\n",
"\n",
"a = np.array([1, 2, 3], dtype=np.float32)\n",
"b = np.array([2, 1, 4], dtype=np.float32)\n",
"res = np.empty((3,), dtype=np.bool_)\n",
"binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
"print(res) # prints [False, True, False]\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[3 0 7]\n"
]
}
],
"source": [
"@cute.jit\n",
"def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
" a_vec = a.load()\n",
" b_vec = b.load()\n",
"\n",
" xor_res = a_vec ^ b_vec\n",
" res.store(xor_res)\n",
"\n",
" # or_res = a_vec | b_vec\n",
" # res.store(or_res) # prints [3, 2, 7]\n",
"\n",
" # and_res = a_vec & b_vec\n",
" # res.store(and_res) # prints [0, 2, 0]\n",
"\n",
"a = np.array([1, 2, 3], dtype=np.int32)\n",
"b = np.array([2, 2, 4], dtype=np.int32)\n",
"res = np.empty((3,), dtype=np.int32)\n",
"binary_op_4(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
"print(res) # prints [3, 0, 7]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Unary Operations"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
" [ 2.000000, ],\n",
" [ 2.000000, ],\n",
" [ 2.000000, ])\n",
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
" [-0.756802, ],\n",
" [-0.756802, ],\n",
" [-0.756802, ])\n",
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
" [ 16.000000, ],\n",
" [ 16.000000, ],\n",
" [ 16.000000, ])\n"
]
}
],
"source": [
"@cute.jit\n",
"def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n",
" a_vec = a.load()\n",
"\n",
" sqrt_res = cute.math.sqrt(a_vec)\n",
" res.store(sqrt_res)\n",
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
"\n",
" sin_res = cute.math.sin(a_vec)\n",
" res.store(sin_res)\n",
" cute.print_tensor(res) # prints [-0.756802, -0.756802, -0.756802]\n",
"\n",
" exp2_res = cute.math.exp2(a_vec)\n",
" res.store(exp2_res)\n",
" cute.print_tensor(res) # prints [16.000000, 16.000000, 16.000000]\n",
"\n",
"a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n",
"res = np.empty((3,), dtype=np.float32)\n",
"unary_op_1(from_dlpack(res), from_dlpack(a))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Reduction Operation\n",
"\n",
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"21.000000\n",
"tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n",
" [ 6.000000, ],\n",
" [ 15.000000, ])\n",
"tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n",
" [ 6.000000, ],\n",
" [ 8.000000, ],\n",
" [ 10.000000, ])\n"
]
}
],
"source": [
"@cute.jit\n",
"def reduction_op(a: cute.Tensor):\n",
" \"\"\"\n",
" Apply reduction operation on the src tensor.\n",
"\n",
" :param src: The source tensor to be reduced.\n",
" \"\"\"\n",
" a_vec = a.load()\n",
" red_res = a_vec.reduce(\n",
" cute.ReductionOp.ADD,\n",
" 0.0,\n",
" reduction_profile=0\n",
" )\n",
" cute.printf(red_res) # prints 21.000000\n",
"\n",
" red_res = a_vec.reduce(\n",
" cute.ReductionOp.ADD,\n",
" 0.0,\n",
" reduction_profile=(None, 1)\n",
" )\n",
" # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n",
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
" res.store(red_res)\n",
" cute.print_tensor(res) # prints [6.000000, 15.000000]\n",
"\n",
" red_res = a_vec.reduce(\n",
" cute.ReductionOp.ADD,\n",
" 1.0,\n",
" reduction_profile=(1, None)\n",
" )\n",
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
" res.store(red_res)\n",
" cute.print_tensor(res) # prints [6.000000, 8.000000, 10.000000]\n",
"\n",
"\n",
"a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n",
"reduction_op(from_dlpack(a))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}