CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@ -357,7 +357,7 @@
|
||||
"## Handling errors\n",
|
||||
"The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n",
|
||||
"\n",
|
||||
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user."
|
||||
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user. Uncomment and run the code below to see this error."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -371,6 +371,75 @@
|
||||
"# td.stages = 8\n",
|
||||
"# plan.compile(td)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Specializations for other data types\n",
|
||||
"\n",
|
||||
"Various CUTLASS kernels specialized for specific data types can also be run via the Python interface.\n",
|
||||
"\n",
|
||||
"For example, the code below shows how to declare and run a GEMM using the 3xTF32 feature (see corresponding C++ example [here](https://github.com/NVIDIA/cutlass/blob/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu))."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass.backend.utils.device import device_cc\n",
|
||||
"\n",
|
||||
"# 3xTF32 requires SM80 or higher\n",
|
||||
"if device_cc() >= 80:\n",
|
||||
" plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n",
|
||||
" plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n",
|
||||
"\n",
|
||||
" # Create input/output tensors in FP32\n",
|
||||
" A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n",
|
||||
" C, D = [np.zeros((128, 128)).astype(np.float32) for _ in range(2)]\n",
|
||||
"\n",
|
||||
" # Run the GEMM\n",
|
||||
" plan.run(A, B, C, D, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"try:\n",
|
||||
" import torch\n",
|
||||
"except ImportError:\n",
|
||||
" print(\"PyTorch is not available. Skipping FP8 example\")\n",
|
||||
" import sys; sys.exit(0)\n",
|
||||
"\n",
|
||||
"if not hasattr(torch, \"float8_e4m3fn\"):\n",
|
||||
" print(\"Version of PyTorch does not have the float8_e4m3fn data type. Skipping FP8 example\")\n",
|
||||
" import sys; sys.exit(0)\n",
|
||||
"\n",
|
||||
"# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n",
|
||||
"if device_cc() >= 90:\n",
|
||||
" plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n",
|
||||
" layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n",
|
||||
" layout_C=cutlass.LayoutType.ColumnMajor)\n",
|
||||
"\n",
|
||||
" # Create input/output tensors in FP8\n",
|
||||
" A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n",
|
||||
" C, D = [torch.zeros((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n",
|
||||
"\n",
|
||||
" # Run the GEMM\n",
|
||||
" plan.run(A, B, C, D, print_module=print_module)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@ -134,7 +134,7 @@
|
||||
"id": "590a3bc5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
|
||||
"We'll next run a group of 20 GEMMs via the CUTLASS Python interface and via PyTorch."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -144,7 +144,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"As, Bs, Cs, Ds, = generate_problems(50)\n",
|
||||
"As, Bs, Cs, Ds, = generate_problems(20)\n",
|
||||
"\n",
|
||||
"plan.run(As, Bs, Cs, Ds, print_module=True)\n",
|
||||
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
|
||||
|
||||
Reference in New Issue
Block a user