v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -399,6 +399,70 @@
"\n",
"tensor_print_example3()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def print_tensor_gpu(src: cute.Tensor):\n",
" print(src)\n",
" cute.print_tensor(src)\n",
"\n",
"@cute.jit\n",
"def print_tensor_host(src: cute.Tensor):\n",
" print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
" [[-0.690547, -0.274619, -1.659539, ],\n",
" [-1.843524, -1.648711, 1.163431, ],\n",
" [-0.716668, -1.900705, 0.592515, ],\n",
" [ 0.711333, -0.552422, 0.860237, ]])\n"
]
}
],
"source": [
"import torch\n",
"def tensor_print_example4():\n",
" a = torch.randn(4, 3, device=\"cuda\")\n",
" cutlass.cuda.initialize_cuda_context()\n",
" print_tensor_host(from_dlpack(a))\n",
"\n",
"tensor_print_example4()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
]
}
],
"metadata": {

View File

@ -256,16 +256,6 @@
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
"\n",
"@cute.kernel\n",
"def print_tensor_gpu(ptr: cute.Pointer):\n",
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
" tensor = cute.make_tensor(ptr, layout)\n",
"\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
"\n",
" if tidx == 0:\n",
" cute.print_tensor(tensor)\n",
"\n",
"\n",
"# Create a tensor with sequential data using torch\n",
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",