{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Printing with CuTe DSL\n",
    "\n",
    "This notebook demonstrates the different ways to print values in CuTe and explains the important distinction between static (compile-time) and dynamic (runtime) values.\n",
    "\n",
    "## Key Concepts\n",
    "- Static values: Known at compile time\n",
    "- Dynamic values: Only known at runtime\n",
    "- Different printing methods for different scenarios\n",
    "- Layout representation in CuTe\n",
    "- Tensor visualization and formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import cutlass.cute as cute\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Print Example Function\n",
    "\n",
    "The `print_example` function demonstrates several important concepts:\n",
    "\n",
    "### 1. Python's `print` vs CuTe's `cute.printf`\n",
    "- `print`: Can only show static values at compile time\n",
    "- `cute.printf`: Can display both static and dynamic values at runtime\n",
    "\n",
    "### 2. Value Types\n",
    "- `a`: Dynamic `Int32` value (runtime)\n",
    "- `b`: Static `Constexpr[int]` value (compile-time)\n",
    "\n",
    "### 3. Layout Printing\n",
    "Shows how layouts are represented differently in static vs dynamic contexts:\n",
    "- Static context: Unknown values shown as `?`\n",
    "- Dynamic context: Actual values displayed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.jit\n",
    "def print_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
    "    \"\"\"\n",
    "    Demonstrates different printing methods in CuTe and how they handle static vs dynamic values.\n",
    "\n",
    "    This example shows:\n",
    "    1. How Python's `print` function works with static values at compile time but can't show dynamic values\n",
    "    2. How `cute.printf` can display both static and dynamic values at runtime\n",
    "    3. The difference between types in static vs dynamic contexts\n",
    "    4. How layouts are represented in both printing methods\n",
    "\n",
    "    Args:\n",
    "        a: A dynamic Int32 value that will be determined at runtime\n",
    "        b: A static (compile-time constant) integer value\n",
    "    \"\"\"\n",
    "    # Use Python `print` to print static information\n",
    "    print(\">>>\", b)  # => 2\n",
    "    # `a` is a dynamic value, so Python `print` can only show a placeholder\n",
    "    print(\">>>\", a)  # => ?\n",
    "\n",
    "    # Use `cute.printf` to print dynamic information\n",
    "    cute.printf(\">?? {}\", a)  # => 8\n",
    "    cute.printf(\">?? {}\", b)  # => 2\n",
    "\n",
    "    print(\">>>\", type(a))  # => Int32 (as shown in the recorded output below)\n",
    "    print(\">>>\", type(b))  # => <class 'int'>\n",
    "\n",
    "    layout = cute.make_layout((a, b))\n",
    "    print(\">>>\", layout)  # => (?,2):(1,?)\n",
    "    cute.printf(\">?? {}\", layout)  # => (8,2):(1,8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile and Run\n",
    "\n",
    "**Direct Compilation and Run**\n",
    "  - `print_example(cutlass.Int32(8), 2)`\n",
    "  - Compiling and running in one step executes both the static and the dynamic prints\n",
    "    * `>>>` marks static print output\n",
    "    * `>??` marks dynamic print output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> 2\n",
      ">>> ?\n",
      ">>> Int32\n",
      ">>> <class 'int'>\n",
      ">>> (?,2):(1,?)\n",
      ">?? 8\n",
      ">?? 2\n",
      ">?? (8,2):(1,8)\n"
     ]
    }
   ],
   "source": [
    "print_example(cutlass.Int32(8), 2)"
   ]
  },
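  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dynamic output `(8,2):(1,8)` is printed as `shape:stride`. As a reading aid, the plain-Python sketch below maps a coordinate to a linear offset for that shape and stride pair. It does not call CuTe: `layout_offset` is a hypothetical helper introduced only for illustration, and it assumes the usual convention that the offset is the sum of coordinate times stride over all modes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def layout_offset(coord, stride):\n",
    "    # Hypothetical helper (not a CuTe API): offset = sum(c_i * s_i)\n",
    "    return sum(c * s for c, s in zip(coord, stride))\n",
    "\n",
    "shape, stride = (8, 2), (1, 8)  # the layout printed by `cute.printf`\n",
    "for j in range(shape[1]):\n",
    "    # With stride (1, 8), each column `j` occupies 8 consecutive offsets\n",
    "    print([layout_offset((i, j), stride) for i in range(shape[0])])"
   ]
  },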
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile Function\n",
    "\n",
    "When the function is compiled with `cute.compile(print_example, cutlass.Int32(8), 2)`, the Python interpreter \n",
    "traces the code, evaluating only static expressions and printing only static information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> 2\n",
      ">>> ?\n",
      ">>> Int32\n",
      ">>> <class 'int'>\n",
      ">>> (?,2):(1,?)\n"
     ]
    }
   ],
   "source": [
    "print_example_compiled = cute.compile(print_example, cutlass.Int32(8), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Call compiled function\n",
    "\n",
    "Calling the compiled function prints only runtime information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">?? 8\n",
      ">?? 2\n",
      ">?? (8,2):(1,8)\n"
     ]
    }
   ],
   "source": [
    "print_example_compiled(cutlass.Int32(8))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Format String Example\n",
    "\n",
    "The `format_string_example` function shows an important limitation:\n",
    "- F-strings in CuTe are evaluated at compile time\n",
    "- This means dynamic values won't show their runtime values in f-strings\n",
    "- Use `cute.printf` when you need to see runtime values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Direct run output:\n",
      "a: ?, b: 2\n",
      "layout: (?,2):(1,?)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def format_string_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
    "    \"\"\"\n",
    "    Format string is evaluated at compile time.\n",
    "    \"\"\"\n",
    "    print(f\"a: {a}, b: {b}\")\n",
    "\n",
    "    layout = cute.make_layout((a, b))\n",
    "    print(f\"layout: {layout}\")\n",
    "\n",
    "print(\"Direct run output:\")\n",
    "format_string_example(cutlass.Int32(8), 2)"
   ]
  },
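  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the workaround for this limitation, reusing only calls already shown in this notebook: keep the f-string for static values and switch to `cute.printf` for dynamic ones. `format_string_fixed` is a hypothetical name chosen for illustration, and no output is recorded because this cell has not been run here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.jit\n",
    "def format_string_fixed(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
    "    # Static value: the f-string is evaluated while the function is traced\n",
    "    print(f\"b: {b}\")\n",
    "    # Dynamic value: formatted at runtime by `cute.printf`\n",
    "    cute.printf(\"a: {}\", a)\n",
    "\n",
    "format_string_fixed(cutlass.Int32(8), 2)"
   ]
  },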
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Printing Tensor Examples\n",
    "\n",
    "CuTe provides specialized functionality for printing tensors through the `print_tensor` operation. `cute.print_tensor` takes the following parameters:\n",
    "- `Tensor` (required): A CuTe tensor object that you want to print. The tensor must support load and store operations.\n",
    "- `verbose` (optional, default=False): A boolean flag that controls the level of detail in the output. When set to True, it prints the index (coordinate) of each element in the tensor.\n",
    "\n",
    "The example code below shows the difference between verbose mode on and off, and how to print a sub-range of a given tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cutlass.cute.runtime import from_dlpack\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_basic(x: cute.Tensor):\n",
    "    # Print the tensor\n",
    "    print(\"Basic output:\")\n",
    "    cute.print_tensor(x)\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_verbose(x: cute.Tensor):\n",
    "    # Print the tensor with verbose mode\n",
    "    print(\"Verbose output:\")\n",
    "    cute.print_tensor(x, verbose=True)\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_slice(x: cute.Tensor, coord: tuple):\n",
    "    # Slice a sub-tensor out of `x`; `None` in `coord` keeps that mode\n",
    "    sliced_data = cute.slice_(x, coord)\n",
    "    y = cute.make_fragment(sliced_data.layout, sliced_data.element_type)\n",
    "    # Load the sliced data as a TensorSSA value and store it into the fragment\n",
    "    y.store(sliced_data.load())\n",
    "    print(\"Slice output:\")\n",
    "    cute.print_tensor(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, `cute.print_tensor` outputs the tensor's data type, memory space, and CuTe layout, and prints the data in a torch-style format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Basic output:\n",
      "tensor(raw_ptr(0x000000000a5f1d50: f32, generic, align<4>) o (4,3,2):(6,2,1), data=\n",
      " [[[ 0.000000, 2.000000, 4.000000, ],\n",
      " [ 6.000000, 8.000000, 10.000000, ],\n",
      " [ 12.000000, 14.000000, 16.000000, ],\n",
      " [ 18.000000, 20.000000, 22.000000, ]],\n",
      "\n",
      " [[ 1.000000, 3.000000, 5.000000, ],\n",
      " [ 7.000000, 9.000000, 11.000000, ],\n",
      " [ 13.000000, 15.000000, 17.000000, ],\n",
      " [ 19.000000, 21.000000, 23.000000, ]]])\n"
     ]
    }
   ],
   "source": [
    "def tensor_print_example1():\n",
    "    shape = (4, 3, 2)\n",
    "\n",
    "    # Create [0, ..., 23] and reshape to (4, 3, 2)\n",
    "    data = np.arange(24, dtype=np.float32).reshape(*shape)\n",
    "\n",
    "    print_tensor_basic(from_dlpack(data))\n",
    "\n",
    "tensor_print_example1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Verbose printing shows the coordinate of each element in the tensor. The example below shows how elements are indexed in a 2D 4x3 tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Verbose output:\n",
      "tensor(raw_ptr(0x000000000a814cc0: f32, generic, align<4>) o (4,3):(3,1), data= (\n",
      "\t(0,0)= 0.000000\n",
      "\t(0,1)= 1.000000\n",
      "\t(0,2)= 2.000000\n",
      "\t(1,0)= 3.000000\n",
      "\t(1,1)= 4.000000\n",
      "\t(1,2)= 5.000000\n",
      "\t(2,0)= 6.000000\n",
      "\t(2,1)= 7.000000\n",
      "\t(2,2)= 8.000000\n",
      "\t(3,0)= 9.000000\n",
      "\t(3,1)= 10.000000\n",
      "\t(3,2)= 11.000000\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "def tensor_print_example2():\n",
    "    shape = (4, 3)\n",
    "\n",
    "    # Create [0, ..., 11] and reshape to (4, 3)\n",
    "    data = np.arange(12, dtype=np.float32).reshape(*shape)\n",
    "\n",
    "    print_tensor_verbose(from_dlpack(data))\n",
    "\n",
    "tensor_print_example2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To print a subset of the elements in a tensor, use `cute.slice_` to select a range, load the selected elements into registers, and then print the values with `cute.print_tensor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Slice output:\n",
      "tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (4):(3), data=\n",
      " [ 0.000000, ],\n",
      " [ 3.000000, ],\n",
      " [ 6.000000, ],\n",
      " [ 9.000000, ])\n",
      "Slice output:\n",
      "tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (3):(1), data=\n",
      " [ 3.000000, ],\n",
      " [ 4.000000, ],\n",
      " [ 5.000000, ])\n"
     ]
    }
   ],
   "source": [
    "def tensor_print_example3():\n",
    "    shape = (4, 3)\n",
    "\n",
    "    # Create [0, ..., 11] and reshape to (4, 3)\n",
    "    data = np.arange(12, dtype=np.float32).reshape(*shape)\n",
    "\n",
    "    print_tensor_slice(from_dlpack(data), (None, 0))\n",
    "    print_tensor_slice(from_dlpack(data), (1, None))\n",
    "\n",
    "tensor_print_example3()"
   ]
  },
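  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check, the same two selections can be written with plain NumPy indexing on the host array. This sketch assumes that `None` in a CuTe slice coordinate plays the role of `:` in NumPy, which is consistent with the values printed above (0, 3, 6, 9 for the first slice and 3, 4, 5 for the second)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "data = np.arange(12, dtype=np.float32).reshape(4, 3)\n",
    "print(data[:, 0])  # counterpart of coordinate (None, 0): first column\n",
    "print(data[1, :])  # counterpart of coordinate (1, None): second row"
   ]
  },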
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To print a tensor in device memory, use `cute.print_tensor` inside a CuTe JIT kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.kernel\n",
    "def print_tensor_gpu(src: cute.Tensor):\n",
    "    print(src)\n",
    "    cute.print_tensor(src)\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_host(src: cute.Tensor):\n",
    "    print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
      " [[-0.690547, -0.274619, -1.659539, ],\n",
      " [-1.843524, -1.648711, 1.163431, ],\n",
      " [-0.716668, -1.900705, 0.592515, ],\n",
      " [ 0.711333, -0.552422, 0.860237, ]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "def tensor_print_example4():\n",
    "    a = torch.randn(4, 3, device=\"cuda\")\n",
    "    cutlass.cuda.initialize_cuda_context()\n",
    "    print_tensor_host(from_dlpack(a))\n",
    "\n",
    "tensor_print_example4()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Currently, `cute.print_tensor` supports only tensors with integer data types and the `Float16`/`Float32`/`Float64` floating-point data types. Support for more data types is planned."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}