340
examples/python/00_basic_gemm.ipynb
Normal file
340
examples/python/00_basic_gemm.ipynb
Normal file
@ -0,0 +1,340 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1ef96b3f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Basic example of using the CUTLASS Python interface\n",
|
||||
"This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "962324fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We first import various packages needed for the example and construct the input and output tensors that will be used in our example.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e324219",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"\n",
|
||||
"# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n",
|
||||
"# omit this information.\n",
|
||||
"print_module = True\n",
|
||||
"\n",
|
||||
"m = 128\n",
|
||||
"n = m\n",
|
||||
"k = m\n",
|
||||
"\n",
|
||||
"dtype = np.float16\n",
|
||||
"type_A = np.float16\n",
|
||||
"type_B = np.float16\n",
|
||||
"type_C = np.float16\n",
|
||||
"type_D = np.float16\n",
|
||||
"\n",
|
||||
"np.random.seed(1234)\n",
|
||||
"random.seed(1234)\n",
|
||||
"scope_min = -4\n",
|
||||
"scope_max = 4\n",
|
||||
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
|
||||
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
|
||||
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
|
||||
"\n",
|
||||
"alpha = np.float16(1.)\n",
|
||||
"beta = np.float16(0.)\n",
|
||||
"\n",
|
||||
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f2c7bf48",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Declaring and running a GEMM\n",
|
||||
"To get started, one only needs to provide the tensors declared above to the `cutlass.op.Gemm` call.\n",
|
||||
"This sets up a default GEMM operation for the given device on which you are running.\n",
|
||||
"\n",
|
||||
"Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.\n",
|
||||
"\n",
|
||||
"Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0dfd8975",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,\n",
|
||||
"# specifying `element_accumulator` is not required if it is the same as `element`\n",
|
||||
"plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor, element_accumulator=np.float32)\n",
|
||||
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4a5856de",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"There are many other ways to construct a plan from `cutlass.op.Gemm` (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the `cutlass.op.Gemm` constructor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "945478ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We then compare the output to running the GEMM using NumPy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6b669de6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tensor_D_numpy = (alpha * (tensor_A @ tensor_B)) + (beta * tensor_C)\n",
|
||||
"np.testing.assert_array_equal(tensor_D, tensor_D_numpy)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ee5cbbbe",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b6c86493",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Changing operation modes\n",
|
||||
"By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to `cutlass.op.Gemm` is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.\n",
|
||||
"\n",
|
||||
"The operation mode currently in use can be returned via the `plan.opclass` property. In this case Tensor Core operations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "529fda93",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(plan.opclass)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6d27c575",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Suppose that we don't want to use Tensor Cores for this GEMM. One can change to using CUTLASS's SIMT GEMMs by setting the plan's `opclass` field.\n",
|
||||
"\n",
|
||||
"As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASS's SIMT GEMMs.\n",
|
||||
"\n",
|
||||
"Also notice that, this time around, we provided tensor parameters to `plan.run()`. One is free to provide different parameters to `plan.run()` than were passed in at the initial call to `cutlass.op.Gemm`, provided that the passed-in tensors have the same data type and layout as those passed in on intialization."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6a44d35b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tensor_D_simt = np.zeros(tensor_C.shape).astype(type_D)\n",
|
||||
"plan.opclass = cutlass.OpcodeClass.Simt\n",
|
||||
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D_simt, alpha, beta, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "639dcb59",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b480853",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.testing.assert_array_equal(tensor_D, tensor_D_simt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0cce1eae",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running cached kernels\n",
|
||||
"You may have noticed that the `plan.run()` calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.\n",
|
||||
"\n",
|
||||
"CUTLASS caches compiled binaries so that recompilation isn't necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call `plan.run()` again (with a different set of tensor parameters), you'll find the call to return much faster."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f8051e5e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"m = 2400\n",
|
||||
"n = 3232\n",
|
||||
"k = 4096\n",
|
||||
"\n",
|
||||
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
|
||||
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
|
||||
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
|
||||
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)\n",
|
||||
"\n",
|
||||
"alpha = np.float16(1.)\n",
|
||||
"beta = np.float16(2.)\n",
|
||||
"\n",
|
||||
"plan.opclass = cutlass.OpcodeClass.TensorOp\n",
|
||||
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "52a4e318",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running non-default GEMMs\n",
|
||||
"The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?\n",
|
||||
"\n",
|
||||
"Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1c593be1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tiles = plan.tile_descriptions()\n",
|
||||
"print('{} tile descriptions returned'.format(len(tiles)))\n",
|
||||
"num_print = 10\n",
|
||||
"print('First {} tile descriptions are:'.format(num_print))\n",
|
||||
"for td in tiles[:num_print]:\n",
|
||||
" print(td)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dc3ad875",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we'll pick one of these configurations at random and compile and run it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a8dc5287",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"idx = random.randint(0, len(tiles)-1)\n",
|
||||
"td = tiles[idx]\n",
|
||||
"print('Tile description {} is: {}'.format(idx, td))\n",
|
||||
"plan.compile(td)\n",
|
||||
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c5a8b534",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One can also change the swizzling function used by the kernel. For example, one can modify the kernel to use the stream K feature of CUTLASS via:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5e88d17",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Stream K is only supported pre-SM90 (at least when this example was written)\n",
|
||||
"if plan.cc != 90:\n",
|
||||
" plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n",
|
||||
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5a8ba2ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Handling errors\n",
|
||||
"The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n",
|
||||
"\n",
|
||||
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fe7d0e42",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# td = tiles[0]\n",
|
||||
"# td.stages = 8\n",
|
||||
"# plan.compile(td)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "0466d96796c9cd8f7a1cad264ff326ececc950ba2420e0256d5105fc1a3c6e70"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
202
examples/python/01_epilogue.ipynb
Normal file
202
examples/python/01_epilogue.ipynb
Normal file
@ -0,0 +1,202 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "5d24a692",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Example of using elementwise activation functions in the CUTLASS Python interface\n",
|
||||
"This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3ca993fe",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We first import various packages needed for the example and construct the input and output tensors that will be used in our example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "63a70a3c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"\n",
|
||||
"# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n",
|
||||
"# omit this information.\n",
|
||||
"print_module = True\n",
|
||||
"\n",
|
||||
"m = 256\n",
|
||||
"n = m\n",
|
||||
"k = m\n",
|
||||
"\n",
|
||||
"type_A = np.float16\n",
|
||||
"type_B = np.float16\n",
|
||||
"type_C = np.float16\n",
|
||||
"type_D = np.float16\n",
|
||||
"\n",
|
||||
"np.random.seed(1234)\n",
|
||||
"scope_min = -4\n",
|
||||
"scope_max = 4\n",
|
||||
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
|
||||
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
|
||||
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
|
||||
"\n",
|
||||
"alpha = np.float16(1.)\n",
|
||||
"beta = np.float16(0.)\n",
|
||||
"\n",
|
||||
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1eb0d95b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run a GEMM with an identity activation function\n",
|
||||
"To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation `D = alpha * (A @ B) + beta * C`. This is the default activation function used, and does not need to be specified."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8d257833",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n",
|
||||
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "54961694",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run a GEMM with a ReLU element-wise activation function\n",
|
||||
"CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise after the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:\n",
|
||||
"```\n",
|
||||
"D = alpha * (A @ B) + beta * C\n",
|
||||
"D = act(D)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n",
|
||||
"\n",
|
||||
"This is easy to do in CUTLASS. One only needs to set the plan's `activation` field."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5fe49443",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n",
|
||||
"plan.activation = cutlass.epilogue.relu\n",
|
||||
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "455d0a37",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now verify that the result of the GEMM that used a ReLU activation function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e32e7798",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n",
|
||||
"np.testing.assert_array_equal(relu_ref, tensor_D_relu)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf959171",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Other element-wise activation functions\n",
|
||||
"CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the `get_activations()` method."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e17d730",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"activations = plan.activations()\n",
|
||||
"for activation in activations:\n",
|
||||
" print(activation)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e4599fa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can then run each of them:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9c3598c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for activation in activations:\n",
|
||||
" print('=============================================================================================')\n",
|
||||
" print(f'Compiling and running activation {activation}')\n",
|
||||
" print('=============================================================================================')\n",
|
||||
" plan.activation = activation\n",
|
||||
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "751f8d92",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
264
examples/python/02_pytorch_extension_grouped_gemm.ipynb
Normal file
264
examples/python/02_pytorch_extension_grouped_gemm.ipynb
Normal file
@ -0,0 +1,264 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "6acbea5d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n",
|
||||
"This notebook walks through a basic example of using the CUTLASS Python interface to declare\n",
|
||||
"a grouped GEMM kernel and export it as a PyTorch CUDA extension.\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n",
|
||||
"\n",
|
||||
"## Background on grouped GEMM\n",
|
||||
"Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n",
|
||||
"in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n",
|
||||
"without the requirement that the sizes and strides of each GEMM be the same.\n",
|
||||
"\n",
|
||||
"For example, if one has `p` GEMMs with sizes:\n",
|
||||
"```text\n",
|
||||
"M_1 x N_1 x K_1\n",
|
||||
"M_2 x N_2 x K_2\n",
|
||||
"...\n",
|
||||
"M_p x N_p x K_p\n",
|
||||
"```\n",
|
||||
"CUTLASS's grouped GEMM will execute these in a single CUDA kernel.\n",
|
||||
"\n",
|
||||
"Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would\n",
|
||||
"insufficiently utilize the device in isolation.\n",
|
||||
"\n",
|
||||
"## Declaring a grouped GEMM via the CUTLASS Python interface\n",
|
||||
"A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n",
|
||||
"simply calls `cutlass.op.GroupedGemm`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fdcf21d8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"dtype = torch.float16\n",
|
||||
"plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "514f40a4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c2a7371e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"random.seed(2023)\n",
|
||||
"\n",
|
||||
"# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n",
|
||||
"def initialize(dtype, M, N, K):\n",
|
||||
" sizes = [(M, K), (K, N), (M, N), (M, N)]\n",
|
||||
" return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n",
|
||||
"\n",
|
||||
"# Utility function to generate `problems` GEMMs of random sizes\n",
|
||||
"def generate_problems(problems):\n",
|
||||
" valid_sizes = [128, 256, 512, 1024]\n",
|
||||
" As, Bs, Cs, Ds = [], [], [], []\n",
|
||||
" for _ in range(problems):\n",
|
||||
" M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n",
|
||||
" A, B, C, D = initialize(dtype, M, N, K)\n",
|
||||
" As.append(A)\n",
|
||||
" Bs.append(B)\n",
|
||||
" Cs.append(C)\n",
|
||||
" Ds.append(D)\n",
|
||||
" return As, Bs, Cs, Ds"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "590a3bc5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "776c9233",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"As, Bs, Cs, Ds, = generate_problems(50)\n",
|
||||
"\n",
|
||||
"plan.run(As, Bs, Cs, Ds, print_module=True)\n",
|
||||
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
|
||||
"\n",
|
||||
"for d, d_torch in zip(Ds, Ds_torch):\n",
|
||||
" assert torch.allclose(d, d_torch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "766e4f03",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n",
|
||||
"The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n",
|
||||
"\n",
|
||||
"The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n",
|
||||
"\n",
|
||||
"To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a98dee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"op = plan.construct()\n",
|
||||
"grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c8ca3991",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `cutlass.emit.pytorch` function emits:\n",
|
||||
"* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n",
|
||||
"* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n",
|
||||
"* `setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n",
|
||||
"\n",
|
||||
"The extension can be build from within the `module_output` directory by running:\n",
|
||||
"```bash\n",
|
||||
"TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n",
|
||||
"```\n",
|
||||
"Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n",
|
||||
"\n",
|
||||
"See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n",
|
||||
"\n",
|
||||
"The PyTorch CUDA extension could be built for this module by running:\n",
|
||||
"```bash\n",
|
||||
"cd out\n",
|
||||
"TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n",
|
||||
"```\n",
|
||||
"(assuming that one is building for SM80)\n",
|
||||
"\n",
|
||||
"One could then use the kernel in a later PyTorch module by running:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"import torch\n",
|
||||
"import grouped_gemm\n",
|
||||
"\n",
|
||||
"grouped_gemm.run(As, Bs)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n",
|
||||
"Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n",
|
||||
"and returns back the loaded extension.\n",
|
||||
"\n",
|
||||
"We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cecb26a4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Ds = grouped_gemm.run(As, Bs)\n",
|
||||
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
|
||||
"for d, d_torch in zip(Ds, Ds_torch):\n",
|
||||
" assert torch.allclose(d, d_torch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "50db80e4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, we can profile our grouped GEMM extension:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b76805d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_warmup = 20\n",
|
||||
"num_profile = 100\n",
|
||||
"\n",
|
||||
"# Warmup iterations\n",
|
||||
"for _ in range(num_warmup):\n",
|
||||
" Ds = grouped_gemm.run(As, Bs)\n",
|
||||
" Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
"# Timing iterations\n",
|
||||
"import time\n",
|
||||
"grouped = 0\n",
|
||||
"nongrouped = 0\n",
|
||||
"for _ in range(num_profile):\n",
|
||||
" start = time.time()\n",
|
||||
" Ds = grouped_gemm.run(As, Bs)\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" grouped += time.time() - start\n",
|
||||
"\n",
|
||||
" start = time.time()\n",
|
||||
" Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" nongrouped += time.time() - start\n",
|
||||
"\n",
|
||||
"print('Grouped: {:.3f} us'.format(grouped * 1e6/num_profile))\n",
|
||||
"print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n",
|
||||
"print('Speedup: {:.3f}'.format(nongrouped / grouped))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f22fc696",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
14
examples/python/README.md
Normal file
14
examples/python/README.md
Normal file
@ -0,0 +1,14 @@
|
||||
# Examples of using the CUTLASS Python interface
|
||||
|
||||
* [00_basic_gemm](/examples/python/00_basic_gemm.ipynb)
|
||||
|
||||
Shows how declare, configure, compile, and run a CUTLASS GEMM using the Python interface
|
||||
|
||||
* [01_epilogue](/examples/python/01_epilogue.ipynb)
|
||||
|
||||
Shows how to fuse elementwise activation functions to GEMMs via the Python interface
|
||||
|
||||
* [02_pytorch_extension_grouped_gemm](/examples/python/02_pytorch_extension_grouped_gemm.ipynb)
|
||||
|
||||
Shows how to declare, compile, and run a grouped GEMM operation via the Python interface,
|
||||
along with how the emitted kernel can be easily exported to a PyTorch CUDA extension.
|
||||
Reference in New Issue
Block a user