v4.1 release update v2. (#2481)
This commit is contained in:
@ -399,6 +399,70 @@
|
||||
"\n",
|
||||
"tensor_print_example3()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def print_tensor_gpu(src: cute.Tensor):\n",
|
||||
" print(src)\n",
|
||||
" cute.print_tensor(src)\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_host(src: cute.Tensor):\n",
|
||||
" print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
|
||||
" [[-0.690547, -0.274619, -1.659539, ],\n",
|
||||
" [-1.843524, -1.648711, 1.163431, ],\n",
|
||||
" [-0.716668, -1.900705, 0.592515, ],\n",
|
||||
" [ 0.711333, -0.552422, 0.860237, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"def tensor_print_example4():\n",
|
||||
" a = torch.randn(4, 3, device=\"cuda\")\n",
|
||||
" cutlass.cuda.initialize_cuda_context()\n",
|
||||
" print_tensor_host(from_dlpack(a))\n",
|
||||
"\n",
|
||||
"tensor_print_example4()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@ -256,16 +256,6 @@
|
||||
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
|
||||
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
|
||||
"\n",
|
||||
"@cute.kernel\n",
|
||||
"def print_tensor_gpu(ptr: cute.Pointer):\n",
|
||||
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
|
||||
" tensor = cute.make_tensor(ptr, layout)\n",
|
||||
"\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
"\n",
|
||||
" if tidx == 0:\n",
|
||||
" cute.print_tensor(tensor)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create a tensor with sequential data using torch\n",
|
||||
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",
|
||||
|
||||
Reference in New Issue
Block a user