CUTLASS 3.1 Python interface documentation (#917)

* Add 12.1 Dockerfile

* Add 3.1 docs
This commit is contained in:
Jack Kosaian
2023-04-18 15:11:35 -04:00
committed by GitHub
parent 54bebe417d
commit 9a83bd3381
81 changed files with 18997 additions and 10 deletions

821
python/docs/externals/00_basic_gemm.html vendored Normal file
View File

@ -0,0 +1,821 @@
<!doctype html>
<html class="no-js" lang="en">
<head><meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<meta name="color-scheme" content="light dark"><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" />
<link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" /><link rel="next" title="Contributing" href="../contribute.html" /><link rel="prev" title="Installation" href="../install.html" />
<link rel="canonical" href="docs/externals/00_basic_gemm.html" />
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
<title>Basic example of using the CUTLASS Python interface - CUTLASS Python</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" />
<link rel="stylesheet" type="text/css" href="../_static/nbsphinx-code-cells.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
<style>
body {
--color-code-background: #eeffcc;
--color-code-foreground: black;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media not print {
body[data-theme="dark"] {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media (prefers-color-scheme: dark) {
body:not([data-theme="light"]) {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
}
}
</style></head>
<body>
<script>
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
</script>
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
<symbol id="svg-toc" viewBox="0 0 24 24">
<title>Contents</title>
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
</svg>
</symbol>
<symbol id="svg-menu" viewBox="0 0 24 24">
<title>Menu</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
<line x1="3" y1="12" x2="21" y2="12"></line>
<line x1="3" y1="6" x2="21" y2="6"></line>
<line x1="3" y1="18" x2="21" y2="18"></line>
</svg>
</symbol>
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
<title>Expand</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
<polyline points="9 18 15 12 9 6"></polyline>
</svg>
</symbol>
<symbol id="svg-sun" viewBox="0 0 24 24">
<title>Light mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
<circle cx="12" cy="12" r="5"></circle>
<line x1="12" y1="1" x2="12" y2="3"></line>
<line x1="12" y1="21" x2="12" y2="23"></line>
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
<line x1="1" y1="12" x2="3" y2="12"></line>
<line x1="21" y1="12" x2="23" y2="12"></line>
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
</svg>
</symbol>
<symbol id="svg-moon" viewBox="0 0 24 24">
<title>Dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
</svg>
</symbol>
<symbol id="svg-sun-half" viewBox="0 0 24 24">
<title>Auto light/dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<circle cx="12" cy="12" r="9" />
<path d="M13 12h5" />
<path d="M13 15h4" />
<path d="M13 18h1" />
<path d="M13 9h4" />
<path d="M13 6h1" />
</svg>
</symbol>
</svg>
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
<label class="overlay sidebar-overlay" for="__navigation">
<div class="visually-hidden">Hide navigation sidebar</div>
</label>
<label class="overlay toc-overlay" for="__toc">
<div class="visually-hidden">Hide table of contents sidebar</div>
</label>
<div class="page">
<header class="mobile-header">
<div class="header-left">
<label class="nav-overlay-icon" for="__navigation">
<div class="visually-hidden">Toggle site navigation sidebar</div>
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
</label>
</div>
<div class="header-center">
<a href="../index.html"><div class="brand">CUTLASS Python</div></a>
</div>
<div class="header-right">
<div class="theme-toggle-container theme-toggle-header">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-header-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
</header>
<aside class="sidebar-drawer">
<div class="sidebar-container">
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">
<div class="sidebar-logo-container">
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt="Light Logo"/>
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt="Dark Logo"/>
</div>
<span class="sidebar-brand-text">CUTLASS Python</span>
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
<input type="hidden" name="check_keywords" value="yes">
<input type="hidden" name="area" value="default">
</form>
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
<ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li>
</ul>
</li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
<ul class="current">
<li class="toctree-l1 current has-children"><a class="reference internal" href="../examples.html">Examples</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul class="current">
<li class="toctree-l2 current current-page"><a class="current reference internal" href="#">Basic GEMM</a></li>
<li class="toctree-l2"><a class="reference internal" href="01_epilogue.html">Epilogue</a></li>
<li class="toctree-l2"><a class="reference internal" href="02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
</ul>
</div>
</div>
</div>
</div>
</aside>
<div class="main">
<div class="content">
<div class="article-container">
<a href="#" class="back-to-top muted-link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
</svg>
<span>Back to top</span>
</a>
<div class="content-icon-container">
<div class="theme-toggle-container theme-toggle-content">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-content-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
<article role="main">
<section id="Basic-example-of-using-the-CUTLASS-Python-interface">
<h1>Basic example of using the CUTLASS Python interface<a class="headerlink" href="#Basic-example-of-using-the-CUTLASS-Python-interface" title="Permalink to this heading">#</a></h1>
<p>This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.</p>
<p><a class="reference external" href="https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p>
<p>We first import various packages needed for the example and construct the input and output tensors that will be used in our example.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">cutlass</span>
<span class="c1"># This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to</span>
<span class="c1"># omit this information.</span>
<span class="n">print_module</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">m</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">n</span> <span class="o">=</span> <span class="n">m</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">m</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">1234</span><span class="p">)</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">1234</span><span class="p">)</span>
<span class="n">scope_min</span> <span class="o">=</span> <span class="o">-</span><span class="mi">4</span>
<span class="n">scope_max</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_A</span><span class="p">))</span>
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_B</span><span class="p">))</span>
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_C</span><span class="p">))</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span>
<span class="n">beta</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">tensor_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area stderr docutils container">
<div class="highlight"><pre>
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
</pre></div></div>
</div>
<section id="Declaring-and-running-a-GEMM">
<h2>Declaring and running a GEMM<a class="headerlink" href="#Declaring-and-running-a-GEMM" title="Permalink to this heading">#</a></h2>
<p>To get started, one only needs to provide the tensors declared above to the <code class="docutils literal notranslate"><span class="pre">cutlass.op.Gemm</span></code> call. This sets up a default GEMM operation for the given device on which you are running.</p>
<p>Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.</p>
<p>Calling <code class="docutils literal notranslate"><span class="pre">plan.run()</span></code> will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting <code class="docutils literal notranslate"><span class="pre">print_module</span></code> to <code class="docutils literal notranslate"><span class="pre">true</span></code>, the C++ code that is emitted is printed.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,</span>
<span class="c1"># specifying `element_accumulator` is not required if it is the same as `element`</span>
<span class="n">plan</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">Gemm</span><span class="p">(</span><span class="n">element</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">,</span> <span class="n">element_accumulator</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, float, float&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;cutlass.backend.gemm_operation.GemmArguments2x at 0x7f79cc556070&gt;
</pre></div></div>
</div>
<p>There are many other ways to construct a plan from <code class="docutils literal notranslate"><span class="pre">cutlass.op.Gemm</span></code> (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the <code class="docutils literal notranslate"><span class="pre">cutlass.op.Gemm</span></code> constructor.</p>
<p>We then compare the output to running the GEMM using NumPy.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">tensor_D_numpy</span> <span class="o">=</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">*</span> <span class="p">(</span><span class="n">tensor_A</span> <span class="o">@</span> <span class="n">tensor_B</span><span class="p">))</span> <span class="o">+</span> <span class="p">(</span><span class="n">beta</span> <span class="o">*</span> <span class="n">tensor_C</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_array_equal</span><span class="p">(</span><span class="n">tensor_D</span><span class="p">,</span> <span class="n">tensor_D_numpy</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy.</p>
</section>
<section id="Changing-operation-modes">
<h2>Changing operation modes<a class="headerlink" href="#Changing-operation-modes" title="Permalink to this heading">#</a></h2>
<p>By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to <code class="docutils literal notranslate"><span class="pre">cutlass.op.Gemm</span></code> is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.</p>
<p>The operation mode currently in use can be returned via the <code class="docutils literal notranslate"><span class="pre">plan.opclass</span></code> property. In this case Tensor Core operations.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">plan</span><span class="o">.</span><span class="n">opclass</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
OpcodeClass.TensorOp
</pre></div></div>
</div>
<p>Suppose that we dont want to use Tensor Cores for this GEMM. One can change to using CUTLASSs SIMT GEMMs by setting the plans <code class="docutils literal notranslate"><span class="pre">opclass</span></code> field.</p>
<p>As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASSs SIMT GEMMs.</p>
<p>Also notice that, this time around, we provided tensor parameters to <code class="docutils literal notranslate"><span class="pre">plan.run()</span></code>. One is free to provide different parameters to <code class="docutils literal notranslate"><span class="pre">plan.run()</span></code> than were passed in at the initial call to <code class="docutils literal notranslate"><span class="pre">cutlass.op.Gemm</span></code>, provided that the passed-in tensors have the same data type and layout as those passed in on intialization.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">tensor_D_simt</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">opclass</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">Simt</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D_simt</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1
using cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;128, 128, 8&gt;,
cutlass::gemm::GemmShape&lt;32, 64, 8&gt;,
cutlass::gemm::GemmShape&lt;1, 1, 1&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 1, float, float&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
2,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_type :
public cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_base { };
</pre></div></div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]:
</pre></div>
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;cutlass.backend.gemm_operation.GemmArguments2x at 0x7f7b3075abe0&gt;
</pre></div></div>
</div>
<p>If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[6]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_array_equal</span><span class="p">(</span><span class="n">tensor_D</span><span class="p">,</span> <span class="n">tensor_D_simt</span><span class="p">)</span>
</pre></div>
</div>
</div>
</section>
<section id="Running-cached-kernels">
<h2>Running cached kernels<a class="headerlink" href="#Running-cached-kernels" title="Permalink to this heading">#</a></h2>
<p>You may have noticed that the <code class="docutils literal notranslate"><span class="pre">plan.run()</span></code> calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.</p>
<p>CUTLASS caches compiled binaries so that recompilation isnt necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call <code class="docutils literal notranslate"><span class="pre">plan.run()</span></code> again (with a different set of tensor parameters), youll find the call to return much faster.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[7]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">m</span> <span class="o">=</span> <span class="mi">2400</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">3232</span>
<span class="n">k</span> <span class="o">=</span> <span class="mi">4096</span>
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_A</span><span class="p">))</span>
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_B</span><span class="p">))</span>
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_C</span><span class="p">))</span>
<span class="n">tensor_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span>
<span class="n">beta</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">2.</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">opclass</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">OpcodeClass</span><span class="o">.</span><span class="n">TensorOp</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, float, float&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[7]:
</pre></div>
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;cutlass.backend.gemm_operation.GemmArguments2x at 0x7f7b30fb9880&gt;
</pre></div></div>
</div>
</section>
<section id="Running-non-default-GEMMs">
<h2>Running non-default GEMMs<a class="headerlink" href="#Running-non-default-GEMMs" title="Permalink to this heading">#</a></h2>
<p>The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?</p>
<p>Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape).</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[8]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">tiles</span> <span class="o">=</span> <span class="n">plan</span><span class="o">.</span><span class="n">tile_descriptions</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="si">{}</span><span class="s1"> tile descriptions returned&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">tiles</span><span class="p">)))</span>
<span class="n">num_print</span> <span class="o">=</span> <span class="mi">10</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;First </span><span class="si">{}</span><span class="s1"> tile descriptions are:&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">num_print</span><span class="p">))</span>
<span class="k">for</span> <span class="n">td</span> <span class="ow">in</span> <span class="n">tiles</span><span class="p">[:</span><span class="n">num_print</span><span class="p">]:</span>
<span class="nb">print</span><span class="p">(</span><span class="n">td</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
132 tile descriptions returned
First 10 tile descriptions are:
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [256, 128, 64]
WarpCount: [4, 2, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [128, 256, 64]
WarpCount: [2, 4, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [256, 128, 64]
WarpCount: [4, 2, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [128, 256, 64]
WarpCount: [2, 4, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [256, 128, 32]
WarpCount: [4, 2, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [128, 256, 32]
WarpCount: [2, 4, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [256, 64, 64]
WarpCount: [4, 1, 1]
Stages: 4
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [64, 256, 64]
WarpCount: [1, 4, 1]
Stages: 4
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [128, 128, 64]
WarpCount: [2, 2, 1]
Stages: 4
Kernel schedule: ScheduleAuto
}
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [256, 64, 64]
WarpCount: [4, 1, 1]
Stages: 3
Kernel schedule: ScheduleAuto
}
</pre></div></div>
</div>
<p>Next, well pick one of these configurations at random and compile and run it.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[9]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">idx</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">tiles</span><span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">td</span> <span class="o">=</span> <span class="n">tiles</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Tile description </span><span class="si">{}</span><span class="s1"> is: </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">td</span><span class="p">))</span>
<span class="n">plan</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">td</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
Tile description 112 is:
{
ClusterShape: [1, 1, 1]
ThreadblockShape: [128, 128, 32]
WarpCount: [2, 2, 1]
Stages: 4
Kernel schedule: ScheduleAuto
}
// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8
using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;128, 128, 32&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 32&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, float, float&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
4,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_type :
public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base { };
</pre></div></div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[9]:
</pre></div>
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;cutlass.backend.gemm_operation.GemmArguments2x at 0x7f79cc58de20&gt;
</pre></div></div>
</div>
<p>One can also change the swizzling function used by the kernel. For example, one can modify the kernel to use the stream K feature of CUTLASS via:</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[10]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Stream K is only supported pre-SM90 (at least when this example was written)</span>
<span class="k">if</span> <span class="n">plan</span><span class="o">.</span><span class="n">cc</span> <span class="o">!=</span> <span class="mi">90</span><span class="p">:</span>
<span class="n">plan</span><span class="o">.</span><span class="n">swizzling_functor</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">swizzle</span><span class="o">.</span><span class="n">ThreadblockSwizzleStreamK</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8
using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;128, 128, 32&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 32&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, float, float&gt;,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
4,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_type :
public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base { };
</pre></div></div>
</div>
</section>
<section id="Handling-errors">
<h2>Handling errors<a class="headerlink" href="#Handling-errors" title="Permalink to this heading">#</a></h2>
<p>The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.</p>
<p>Heres an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[11]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># td = tiles[0]</span>
<span class="c1"># td.stages = 8</span>
<span class="c1"># plan.compile(td)</span>
</pre></div>
</div>
</div>
</section>
</section>
</article>
</div>
<footer>
<div class="related-pages">
<a class="next-page" href="../contribute.html">
<div class="page-info">
<div class="context">
<span>Next</span>
</div>
<div class="title">Contributing</div>
</div>
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
</a>
<a class="prev-page" href="../install.html">
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
<div class="page-info">
<div class="context">
<span>Previous</span>
</div>
<div class="title">Installation</div>
</div>
</a>
</div>
<div class="bottom-of-page">
<div class="left-details">
<div class="copyright">
Copyright &#169; 2023, NVIDIA
</div>
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
<a href="https://github.com/pradyunsg/furo">Furo</a>
</div>
<div class="right-details">
<div class="icons">
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<aside class="toc-drawer">
<div class="toc-sticky toc-scroll">
<div class="toc-title-container">
<span class="toc-title">
On this page
</span>
</div>
<div class="toc-tree-container">
<div class="toc-tree">
<ul>
<li><a class="reference internal" href="#">Basic example of using the CUTLASS Python interface</a><ul>
<li><a class="reference internal" href="#Declaring-and-running-a-GEMM">Declaring and running a GEMM</a></li>
<li><a class="reference internal" href="#Changing-operation-modes">Changing operation modes</a></li>
<li><a class="reference internal" href="#Running-cached-kernels">Running cached kernels</a></li>
<li><a class="reference internal" href="#Running-non-default-GEMMs">Running non-default GEMMs</a></li>
<li><a class="reference internal" href="#Handling-errors">Handling errors</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</aside>
</div>
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/sphinx_highlight.js"></script>
<script src="../_static/scripts/furo.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script src="../_static/tabs.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>window.MathJax = {"tex": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true}, "options": {"ignoreHtmlClass": "tex2jax_ignore|mathjax_ignore|document", "processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
<script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>

View File

@ -0,0 +1,727 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1ef96b3f",
"metadata": {},
"source": [
"# Basic example of using the CUTLASS Python interface\n",
"This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n"
]
},
{
"cell_type": "markdown",
"id": "962324fd",
"metadata": {},
"source": [
"We first import various packages needed for the example and construct the input and output tensors that will be used in our example.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0e324219",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:39.749457Z",
"iopub.status.busy": "2023-04-18T17:59:39.748884Z",
"iopub.status.idle": "2023-04-18T17:59:43.907956Z",
"shell.execute_reply": "2023-04-18T17:59:43.907069Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"import random\n",
"\n",
"import cutlass\n",
"\n",
"# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n",
"# omit this information.\n",
"print_module = True\n",
"\n",
"m = 128\n",
"n = m\n",
"k = m\n",
"\n",
"dtype = np.float16\n",
"type_A = np.float16\n",
"type_B = np.float16\n",
"type_C = np.float16\n",
"type_D = np.float16\n",
"\n",
"np.random.seed(1234)\n",
"random.seed(1234)\n",
"scope_min = -4\n",
"scope_max = 4\n",
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
"\n",
"alpha = np.float16(1.)\n",
"beta = np.float16(0.)\n",
"\n",
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
]
},
{
"cell_type": "markdown",
"id": "f2c7bf48",
"metadata": {},
"source": [
"## Declaring and running a GEMM\n",
"To get started, one only needs to provide the tensors declared above to the `cutlass.op.Gemm` call.\n",
"This sets up a default GEMM operation for the given device on which you are running.\n",
"\n",
"Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.\n",
"\n",
"Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0dfd8975",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:43.911740Z",
"iopub.status.busy": "2023-04-18T17:59:43.911512Z",
"iopub.status.idle": "2023-04-18T17:59:49.103941Z",
"shell.execute_reply": "2023-04-18T17:59:49.103231Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f79cc556070>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,\n",
"# specifying `element_accumulator` is not required if it is the same as `element`\n",
"plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor, element_accumulator=np.float32)\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "4a5856de",
"metadata": {},
"source": [
"There are many other ways to construct a plan from `cutlass.op.Gemm` (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the `cutlass.op.Gemm` constructor."
]
},
{
"cell_type": "markdown",
"id": "945478ef",
"metadata": {},
"source": [
"We then compare the output to running the GEMM using NumPy."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6b669de6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:49.107492Z",
"iopub.status.busy": "2023-04-18T17:59:49.107284Z",
"iopub.status.idle": "2023-04-18T17:59:49.138511Z",
"shell.execute_reply": "2023-04-18T17:59:49.137837Z"
}
},
"outputs": [],
"source": [
"tensor_D_numpy = (alpha * (tensor_A @ tensor_B)) + (beta * tensor_C)\n",
"np.testing.assert_array_equal(tensor_D, tensor_D_numpy)"
]
},
{
"cell_type": "markdown",
"id": "ee5cbbbe",
"metadata": {},
"source": [
"Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy."
]
},
{
"cell_type": "markdown",
"id": "b6c86493",
"metadata": {},
"source": [
"## Changing operation modes\n",
"By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to `cutlass.op.Gemm` is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.\n",
"\n",
"The operation mode currently in use can be returned via the `plan.opclass` property. In this case Tensor Core operations."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "529fda93",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:49.141458Z",
"iopub.status.busy": "2023-04-18T17:59:49.141305Z",
"iopub.status.idle": "2023-04-18T17:59:49.145005Z",
"shell.execute_reply": "2023-04-18T17:59:49.144332Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpcodeClass.TensorOp\n"
]
}
],
"source": [
"print(plan.opclass)"
]
},
{
"cell_type": "markdown",
"id": "6d27c575",
"metadata": {},
"source": [
"Suppose that we don't want to use Tensor Cores for this GEMM. One can change to using CUTLASS's SIMT GEMMs by setting the plan's `opclass` field.\n",
"\n",
"As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASS's SIMT GEMMs.\n",
"\n",
"Also notice that, this time around, we provided tensor parameters to `plan.run()`. One is free to provide different parameters to `plan.run()` than were passed in at the initial call to `cutlass.op.Gemm`, provided that the passed-in tensors have the same data type and layout as those passed in on intialization."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6a44d35b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:49.148548Z",
"iopub.status.busy": "2023-04-18T17:59:49.148042Z",
"iopub.status.idle": "2023-04-18T17:59:54.365792Z",
"shell.execute_reply": "2023-04-18T17:59:54.364734Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1\n",
"using cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassSimt,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<128, 128, 8>,\n",
" cutlass::gemm::GemmShape<32, 64, 8>,\n",
" cutlass::gemm::GemmShape<1, 1, 1>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 2,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_type : \n",
" public cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f7b3075abe0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor_D_simt = np.zeros(tensor_C.shape).astype(type_D)\n",
"plan.opclass = cutlass.OpcodeClass.Simt\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D_simt, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "639dcb59",
"metadata": {},
"source": [
"If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9b480853",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:54.369977Z",
"iopub.status.busy": "2023-04-18T17:59:54.369302Z",
"iopub.status.idle": "2023-04-18T17:59:54.375239Z",
"shell.execute_reply": "2023-04-18T17:59:54.374405Z"
}
},
"outputs": [],
"source": [
"np.testing.assert_array_equal(tensor_D, tensor_D_simt)"
]
},
{
"cell_type": "markdown",
"id": "0cce1eae",
"metadata": {},
"source": [
"## Running cached kernels\n",
"You may have noticed that the `plan.run()` calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.\n",
"\n",
"CUTLASS caches compiled binaries so that recompilation isn't necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call `plan.run()` again (with a different set of tensor parameters), you'll find the call to return much faster."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f8051e5e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:54.378373Z",
"iopub.status.busy": "2023-04-18T17:59:54.378060Z",
"iopub.status.idle": "2023-04-18T17:59:55.220086Z",
"shell.execute_reply": "2023-04-18T17:59:55.219198Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f7b30fb9880>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = 2400\n",
"n = 3232\n",
"k = 4096\n",
"\n",
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)\n",
"\n",
"alpha = np.float16(1.)\n",
"beta = np.float16(2.)\n",
"\n",
"plan.opclass = cutlass.OpcodeClass.TensorOp\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "52a4e318",
"metadata": {},
"source": [
"## Running non-default GEMMs\n",
"The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?\n",
"\n",
"Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c593be1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:55.223812Z",
"iopub.status.busy": "2023-04-18T17:59:55.223651Z",
"iopub.status.idle": "2023-04-18T17:59:55.228769Z",
"shell.execute_reply": "2023-04-18T17:59:55.228101Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"132 tile descriptions returned\n",
"First 10 tile descriptions are:\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 128, 64]\n",
" WarpCount: [4, 2, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 256, 64]\n",
" WarpCount: [2, 4, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 128, 64]\n",
" WarpCount: [4, 2, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 256, 64]\n",
" WarpCount: [2, 4, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 128, 32]\n",
" WarpCount: [4, 2, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 256, 32]\n",
" WarpCount: [2, 4, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 64, 64]\n",
" WarpCount: [4, 1, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [64, 256, 64]\n",
" WarpCount: [1, 4, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 128, 64]\n",
" WarpCount: [2, 2, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 64, 64]\n",
" WarpCount: [4, 1, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n"
]
}
],
"source": [
"tiles = plan.tile_descriptions()\n",
"print('{} tile descriptions returned'.format(len(tiles)))\n",
"num_print = 10\n",
"print('First {} tile descriptions are:'.format(num_print))\n",
"for td in tiles[:num_print]:\n",
" print(td)"
]
},
{
"cell_type": "markdown",
"id": "dc3ad875",
"metadata": {},
"source": [
"Next, we'll pick one of these configurations at random and compile and run it."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a8dc5287",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:55.231498Z",
"iopub.status.busy": "2023-04-18T17:59:55.230924Z",
"iopub.status.idle": "2023-04-18T18:00:00.340161Z",
"shell.execute_reply": "2023-04-18T18:00:00.339603Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tile description 112 is: \n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 128, 32]\n",
" WarpCount: [2, 2, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<128, 128, 32>,\n",
" cutlass::gemm::GemmShape<64, 64, 32>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 4,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f79cc58de20>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idx = random.randint(0, len(tiles)-1)\n",
"td = tiles[idx]\n",
"print('Tile description {} is: {}'.format(idx, td))\n",
"plan.compile(td)\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "c5a8b534",
"metadata": {},
"source": [
"One can also change the swizzling function used by the kernel. For example, one can modify the kernel to use the stream K feature of CUTLASS via:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e5e88d17",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:00.343772Z",
"iopub.status.busy": "2023-04-18T18:00:00.343582Z",
"iopub.status.idle": "2023-04-18T18:00:06.192256Z",
"shell.execute_reply": "2023-04-18T18:00:06.191286Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<128, 128, 32>,\n",
" cutlass::gemm::GemmShape<64, 64, 32>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,\n",
" 4,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base { };\n",
"\n"
]
}
],
"source": [
"# Stream K is only supported pre-SM90 (at least when this example was written)\n",
"if plan.cc != 90:\n",
" plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n",
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "5a8ba2ba",
"metadata": {},
"source": [
"## Handling errors\n",
"The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n",
"\n",
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fe7d0e42",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:06.196345Z",
"iopub.status.busy": "2023-04-18T18:00:06.195784Z",
"iopub.status.idle": "2023-04-18T18:00:06.199248Z",
"shell.execute_reply": "2023-04-18T18:00:06.198438Z"
}
},
"outputs": [],
"source": [
"# td = tiles[0]\n",
"# td.stages = 8\n",
"# plan.compile(td)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "0466d96796c9cd8f7a1cad264ff326ececc950ba2420e0256d5105fc1a3c6e70"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

764
python/docs/externals/01_epilogue.html vendored Normal file
View File

@ -0,0 +1,764 @@
<!doctype html>
<html class="no-js" lang="en">
<head><meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<meta name="color-scheme" content="light dark"><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" />
<link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" /><link rel="next" title="Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension" href="02_pytorch_extension_grouped_gemm.html" /><link rel="prev" title="Examples" href="../examples.html" />
<link rel="canonical" href="docs/externals/01_epilogue.html" />
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
<title>Example of using elementwise activation functions in the CUTLASS Python interface - CUTLASS Python</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" />
<link rel="stylesheet" type="text/css" href="../_static/nbsphinx-code-cells.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
<style>
body {
--color-code-background: #eeffcc;
--color-code-foreground: black;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media not print {
body[data-theme="dark"] {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media (prefers-color-scheme: dark) {
body:not([data-theme="light"]) {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
}
}
</style></head>
<body>
<script>
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
</script>
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
<symbol id="svg-toc" viewBox="0 0 24 24">
<title>Contents</title>
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
</svg>
</symbol>
<symbol id="svg-menu" viewBox="0 0 24 24">
<title>Menu</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
<line x1="3" y1="12" x2="21" y2="12"></line>
<line x1="3" y1="6" x2="21" y2="6"></line>
<line x1="3" y1="18" x2="21" y2="18"></line>
</svg>
</symbol>
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
<title>Expand</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
<polyline points="9 18 15 12 9 6"></polyline>
</svg>
</symbol>
<symbol id="svg-sun" viewBox="0 0 24 24">
<title>Light mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
<circle cx="12" cy="12" r="5"></circle>
<line x1="12" y1="1" x2="12" y2="3"></line>
<line x1="12" y1="21" x2="12" y2="23"></line>
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
<line x1="1" y1="12" x2="3" y2="12"></line>
<line x1="21" y1="12" x2="23" y2="12"></line>
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
</svg>
</symbol>
<symbol id="svg-moon" viewBox="0 0 24 24">
<title>Dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
</svg>
</symbol>
<symbol id="svg-sun-half" viewBox="0 0 24 24">
<title>Auto light/dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<circle cx="12" cy="12" r="9" />
<path d="M13 12h5" />
<path d="M13 15h4" />
<path d="M13 18h1" />
<path d="M13 9h4" />
<path d="M13 6h1" />
</svg>
</symbol>
</svg>
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
<label class="overlay sidebar-overlay" for="__navigation">
<div class="visually-hidden">Hide navigation sidebar</div>
</label>
<label class="overlay toc-overlay" for="__toc">
<div class="visually-hidden">Hide table of contents sidebar</div>
</label>
<div class="page">
<header class="mobile-header">
<div class="header-left">
<label class="nav-overlay-icon" for="__navigation">
<div class="visually-hidden">Toggle site navigation sidebar</div>
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
</label>
</div>
<div class="header-center">
<a href="../index.html"><div class="brand">CUTLASS Python</div></a>
</div>
<div class="header-right">
<div class="theme-toggle-container theme-toggle-header">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-header-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
</header>
<aside class="sidebar-drawer">
<div class="sidebar-container">
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">
<div class="sidebar-logo-container">
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt="Light Logo"/>
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt="Dark Logo"/>
</div>
<span class="sidebar-brand-text">CUTLASS Python</span>
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
<input type="hidden" name="check_keywords" value="yes">
<input type="hidden" name="area" value="default">
</form>
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="00_basic_gemm.html">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
<ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li>
</ul>
</li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
<ul class="current">
<li class="toctree-l1 current has-children"><a class="reference internal" href="../examples.html">Examples</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="00_basic_gemm.html">Basic GEMM</a></li>
<li class="toctree-l2 current current-page"><a class="current reference internal" href="#">Epilogue</a></li>
<li class="toctree-l2"><a class="reference internal" href="02_pytorch_extension_grouped_gemm.html">PyTorch Extension</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
</ul>
</div>
</div>
</div>
</div>
</aside>
<div class="main">
<div class="content">
<div class="article-container">
<a href="#" class="back-to-top muted-link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
</svg>
<span>Back to top</span>
</a>
<div class="content-icon-container">
<div class="theme-toggle-container theme-toggle-content">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-content-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
<article role="main">
<section id="Example-of-using-elementwise-activation-functions-in-the-CUTLASS-Python-interface">
<h1>Example of using elementwise activation functions in the CUTLASS Python interface<a class="headerlink" href="#Example-of-using-elementwise-activation-functions-in-the-CUTLASS-Python-interface" title="Permalink to this heading">#</a></h1>
<p>This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.</p>
<p><a class="reference external" href="https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p>
<p>We first import various packages needed for the example and construct the input and output tensors that will be used in our example.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">cutlass</span>
<span class="c1"># This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to</span>
<span class="c1"># omit this information.</span>
<span class="n">print_module</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">m</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">n</span> <span class="o">=</span> <span class="n">m</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">m</span>
<span class="n">type_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">type_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">1234</span><span class="p">)</span>
<span class="n">scope_min</span> <span class="o">=</span> <span class="o">-</span><span class="mi">4</span>
<span class="n">scope_max</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">tensor_A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_A</span><span class="p">))</span>
<span class="n">tensor_B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_B</span><span class="p">))</span>
<span class="n">tensor_C</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="n">scope_min</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scope_max</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_C</span><span class="p">))</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span>
<span class="n">beta</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">(</span><span class="mf">0.</span><span class="p">)</span>
<span class="n">tensor_D</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area stderr docutils container">
<div class="highlight"><pre>
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
</pre></div></div>
</div>
<section id="Run-a-GEMM-with-an-identity-activation-function">
<h2>Run a GEMM with an identity activation function<a class="headerlink" href="#Run-a-GEMM-with-an-identity-activation-function" title="Permalink to this heading">#</a></h2>
<p>To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation <code class="docutils literal notranslate"><span class="pre">D</span> <span class="pre">=</span> <span class="pre">alpha</span> <span class="pre">*</span> <span class="pre">(A</span> <span class="pre">&#64;</span> <span class="pre">B)</span> <span class="pre">+</span> <span class="pre">beta</span> <span class="pre">*</span> <span class="pre">C</span></code>. This is the default activation function used, and does not need to be specified.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">plan</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">Gemm</span><span class="p">(</span><span class="n">element</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0&gt;
</pre></div></div>
</div>
</section>
<section id="Run-a-GEMM-with-a-ReLU-element-wise-activation-function">
<h2>Run a GEMM with a ReLU element-wise activation function<a class="headerlink" href="#Run-a-GEMM-with-a-ReLU-element-wise-activation-function" title="Permalink to this heading">#</a></h2>
<p>CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise after the generic linear combination performed in a GEMM. If we call such an activation function <code class="docutils literal notranslate"><span class="pre">act</span></code>, the resulting formulation is:</p>
<div class="highlight-none notranslate"><div class="highlight"><pre><span></span>D = alpha * (A @ B) + beta * C
D = act(D)
</pre></div>
</div>
<p>Here, we will add a ReLU activation function. Given an input <code class="docutils literal notranslate"><span class="pre">x</span></code>, ReLU returns <code class="docutils literal notranslate"><span class="pre">max(x,</span> <span class="pre">0)</span></code>.</p>
<p>This is easy to do in CUTLASS. One only needs to set the plans <code class="docutils literal notranslate"><span class="pre">activation</span></code> field.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">tensor_D_relu</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor_C</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">epilogue</span><span class="o">.</span><span class="n">relu</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D_relu</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460&gt;
</pre></div></div>
</div>
<p>We can now verify that the result of the GEMM that used a ReLU activation function:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">relu_ref</span> <span class="o">=</span> <span class="p">(</span><span class="n">tensor_D</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">type_D</span><span class="p">)</span> <span class="o">*</span> <span class="n">tensor_D</span>
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_array_equal</span><span class="p">(</span><span class="n">relu_ref</span><span class="p">,</span> <span class="n">tensor_D_relu</span><span class="p">)</span>
</pre></div>
</div>
</div>
</section>
<section id="Other-element-wise-activation-functions">
<h2>Other element-wise activation functions<a class="headerlink" href="#Other-element-wise-activation-functions" title="Permalink to this heading">#</a></h2>
<p>CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the <code class="docutils literal notranslate"><span class="pre">get_activations()</span></code> method.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">activations</span> <span class="o">=</span> <span class="n">plan</span><span class="o">.</span><span class="n">activations</span><span class="p">()</span>
<span class="k">for</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">activations</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="n">activation</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
&lt;class &#39;cutlass.backend.epilogue.gelu&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.hardswish&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.identity&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.leaky_relu&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.relu&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.sigmoid&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.silu&#39;&gt;
&lt;class &#39;cutlass.backend.epilogue.tanh&#39;&gt;
</pre></div></div>
</div>
<p>We can then run each of them:</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[6]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">for</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">activations</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;=============================================================================================&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Compiling and running activation </span><span class="si">{</span><span class="n">activation</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;=============================================================================================&#39;</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">tensor_A</span><span class="p">,</span> <span class="n">tensor_B</span><span class="p">,</span> <span class="n">tensor_C</span><span class="p">,</span> <span class="n">tensor_D</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="n">print_module</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.gelu&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.hardswish&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.identity&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.leaky_relu&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.relu&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.sigmoid&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.silu&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation &lt;class &#39;cutlass.backend.epilogue.tanh&#39;&gt;
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombinationGeneric&lt;cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[ ]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>
</pre></div>
</div>
</div>
</section>
</section>
</article>
</div>
<footer>
<div class="related-pages">
<a class="next-page" href="02_pytorch_extension_grouped_gemm.html">
<div class="page-info">
<div class="context">
<span>Next</span>
</div>
<div class="title">Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension</div>
</div>
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
</a>
<a class="prev-page" href="../examples.html">
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
<div class="page-info">
<div class="context">
<span>Previous</span>
</div>
<div class="title">Examples</div>
</div>
</a>
</div>
<div class="bottom-of-page">
<div class="left-details">
<div class="copyright">
Copyright &#169; 2023, NVIDIA
</div>
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
<a href="https://github.com/pradyunsg/furo">Furo</a>
</div>
<div class="right-details">
<div class="icons">
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<aside class="toc-drawer">
<div class="toc-sticky toc-scroll">
<div class="toc-title-container">
<span class="toc-title">
On this page
</span>
</div>
<div class="toc-tree-container">
<div class="toc-tree">
<ul>
<li><a class="reference internal" href="#">Example of using elementwise activation functions in the CUTLASS Python interface</a><ul>
<li><a class="reference internal" href="#Run-a-GEMM-with-an-identity-activation-function">Run a GEMM with an identity activation function</a></li>
<li><a class="reference internal" href="#Run-a-GEMM-with-a-ReLU-element-wise-activation-function">Run a GEMM with a ReLU element-wise activation function</a></li>
<li><a class="reference internal" href="#Other-element-wise-activation-functions">Other element-wise activation functions</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</aside>
</div>
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/sphinx_highlight.js"></script>
<script src="../_static/scripts/furo.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script src="../_static/tabs.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>window.MathJax = {"tex": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true}, "options": {"ignoreHtmlClass": "tex2jax_ignore|mathjax_ignore|document", "processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
<script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>

593
python/docs/externals/01_epilogue.ipynb vendored Normal file
View File

@ -0,0 +1,593 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "5d24a692",
"metadata": {},
"source": [
"# Example of using elementwise activation functions in the CUTLASS Python interface\n",
"This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "3ca993fe",
"metadata": {},
"source": [
"We first import various packages needed for the example and construct the input and output tensors that will be used in our example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "63a70a3c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:09.148380Z",
"iopub.status.busy": "2023-04-18T18:00:09.148011Z",
"iopub.status.idle": "2023-04-18T18:00:13.281937Z",
"shell.execute_reply": "2023-04-18T18:00:13.281256Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"import cutlass\n",
"\n",
"# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n",
"# omit this information.\n",
"print_module = True\n",
"\n",
"m = 256\n",
"n = m\n",
"k = m\n",
"\n",
"type_A = np.float16\n",
"type_B = np.float16\n",
"type_C = np.float16\n",
"type_D = np.float16\n",
"\n",
"np.random.seed(1234)\n",
"scope_min = -4\n",
"scope_max = 4\n",
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
"\n",
"alpha = np.float16(1.)\n",
"beta = np.float16(0.)\n",
"\n",
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
]
},
{
"cell_type": "markdown",
"id": "1eb0d95b",
"metadata": {},
"source": [
"## Run a GEMM with an identity activation function\n",
"To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation `D = alpha * (A @ B) + beta * C`. This is the default activation function used, and does not need to be specified."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d257833",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:13.284650Z",
"iopub.status.busy": "2023-04-18T18:00:13.284425Z",
"iopub.status.idle": "2023-04-18T18:00:18.333867Z",
"shell.execute_reply": "2023-04-18T18:00:18.333187Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "54961694",
"metadata": {},
"source": [
"## Run a GEMM with a ReLU element-wise activation function\n",
"CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise after the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:\n",
"```\n",
"D = alpha * (A @ B) + beta * C\n",
"D = act(D)\n",
"```\n",
"\n",
"Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n",
"\n",
"This is easy to do in CUTLASS. One only needs to set the plan's `activation` field."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5fe49443",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:18.337036Z",
"iopub.status.busy": "2023-04-18T18:00:18.336833Z",
"iopub.status.idle": "2023-04-18T18:00:23.482072Z",
"shell.execute_reply": "2023-04-18T18:00:23.481125Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n",
"plan.activation = cutlass.epilogue.relu\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "455d0a37",
"metadata": {},
"source": [
"We can now verify that the result of the GEMM that used a ReLU activation function:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e32e7798",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:23.486042Z",
"iopub.status.busy": "2023-04-18T18:00:23.485342Z",
"iopub.status.idle": "2023-04-18T18:00:23.497444Z",
"shell.execute_reply": "2023-04-18T18:00:23.496668Z"
}
},
"outputs": [],
"source": [
"relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n",
"np.testing.assert_array_equal(relu_ref, tensor_D_relu)"
]
},
{
"cell_type": "markdown",
"id": "cf959171",
"metadata": {},
"source": [
"## Other element-wise activation functions\n",
"CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the `get_activations()` method."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9e17d730",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:23.500102Z",
"iopub.status.busy": "2023-04-18T18:00:23.499944Z",
"iopub.status.idle": "2023-04-18T18:00:23.504562Z",
"shell.execute_reply": "2023-04-18T18:00:23.503793Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'cutlass.backend.epilogue.gelu'>\n",
"<class 'cutlass.backend.epilogue.hardswish'>\n",
"<class 'cutlass.backend.epilogue.identity'>\n",
"<class 'cutlass.backend.epilogue.leaky_relu'>\n",
"<class 'cutlass.backend.epilogue.relu'>\n",
"<class 'cutlass.backend.epilogue.sigmoid'>\n",
"<class 'cutlass.backend.epilogue.silu'>\n",
"<class 'cutlass.backend.epilogue.tanh'>\n"
]
}
],
"source": [
"activations = plan.activations()\n",
"for activation in activations:\n",
" print(activation)"
]
},
{
"cell_type": "markdown",
"id": "0e4599fa",
"metadata": {},
"source": [
"We can then run each of them:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9c3598c9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:23.507538Z",
"iopub.status.busy": "2023-04-18T18:00:23.507257Z",
"iopub.status.idle": "2023-04-18T18:00:59.414765Z",
"shell.execute_reply": "2023-04-18T18:00:59.414116Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.gelu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.hardswish'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.identity'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n",
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.leaky_relu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.relu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n",
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.sigmoid'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.silu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.tanh'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
}
],
"source": [
"for activation in activations:\n",
" print('=============================================================================================')\n",
" print(f'Compiling and running activation {activation}')\n",
" print('=============================================================================================')\n",
" plan.activation = activation\n",
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "751f8d92",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,537 @@
<!doctype html>
<html class="no-js" lang="en">
<head><meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<meta name="color-scheme" content="light dark"><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" />
<link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" /><link rel="prev" title="Example of using elementwise activation functions in the CUTLASS Python interface" href="01_epilogue.html" />
<link rel="canonical" href="docs/externals/02_pytorch_extension_grouped_gemm.html" />
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
<title>Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension - CUTLASS Python</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" />
<link rel="stylesheet" type="text/css" href="../_static/nbsphinx-code-cells.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
<style>
body {
--color-code-background: #eeffcc;
--color-code-foreground: black;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media not print {
body[data-theme="dark"] {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media (prefers-color-scheme: dark) {
body:not([data-theme="light"]) {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
}
}
</style></head>
<body>
<script>
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
</script>
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
<symbol id="svg-toc" viewBox="0 0 24 24">
<title>Contents</title>
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
</svg>
</symbol>
<symbol id="svg-menu" viewBox="0 0 24 24">
<title>Menu</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
<line x1="3" y1="12" x2="21" y2="12"></line>
<line x1="3" y1="6" x2="21" y2="6"></line>
<line x1="3" y1="18" x2="21" y2="18"></line>
</svg>
</symbol>
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
<title>Expand</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
<polyline points="9 18 15 12 9 6"></polyline>
</svg>
</symbol>
<symbol id="svg-sun" viewBox="0 0 24 24">
<title>Light mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
<circle cx="12" cy="12" r="5"></circle>
<line x1="12" y1="1" x2="12" y2="3"></line>
<line x1="12" y1="21" x2="12" y2="23"></line>
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
<line x1="1" y1="12" x2="3" y2="12"></line>
<line x1="21" y1="12" x2="23" y2="12"></line>
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
</svg>
</symbol>
<symbol id="svg-moon" viewBox="0 0 24 24">
<title>Dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
</svg>
</symbol>
<symbol id="svg-sun-half" viewBox="0 0 24 24">
<title>Auto light/dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<circle cx="12" cy="12" r="9" />
<path d="M13 12h5" />
<path d="M13 15h4" />
<path d="M13 18h1" />
<path d="M13 9h4" />
<path d="M13 6h1" />
</svg>
</symbol>
</svg>
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
<label class="overlay sidebar-overlay" for="__navigation">
<div class="visually-hidden">Hide navigation sidebar</div>
</label>
<label class="overlay toc-overlay" for="__toc">
<div class="visually-hidden">Hide table of contents sidebar</div>
</label>
<div class="page">
<header class="mobile-header">
<div class="header-left">
<label class="nav-overlay-icon" for="__navigation">
<div class="visually-hidden">Toggle site navigation sidebar</div>
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
</label>
</div>
<div class="header-center">
<a href="../index.html"><div class="brand">CUTLASS Python</div></a>
</div>
<div class="header-right">
<div class="theme-toggle-container theme-toggle-header">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-header-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
</header>
<aside class="sidebar-drawer">
<div class="sidebar-container">
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">
<div class="sidebar-logo-container">
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt="Light Logo"/>
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt="Dark Logo"/>
</div>
<span class="sidebar-brand-text">CUTLASS Python</span>
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
<input type="hidden" name="check_keywords" value="yes">
<input type="hidden" name="area" value="default">
</form>
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="00_basic_gemm.html">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
<ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li>
</ul>
</li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
<ul class="current">
<li class="toctree-l1 current has-children"><a class="reference internal" href="../examples.html">Examples</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="00_basic_gemm.html">Basic GEMM</a></li>
<li class="toctree-l2"><a class="reference internal" href="01_epilogue.html">Epilogue</a></li>
<li class="toctree-l2 current current-page"><a class="current reference internal" href="#">PyTorch Extension</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
</ul>
</div>
</div>
</div>
</div>
</aside>
<div class="main">
<div class="content">
<div class="article-container">
<a href="#" class="back-to-top muted-link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
</svg>
<span>Back to top</span>
</a>
<div class="content-icon-container">
<div class="theme-toggle-container theme-toggle-content">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-content-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
<article role="main">
<section id="Exporting-a-CUTLASS-grouped-GEMM-kernel-to-a-PyTorch-CUDA-extension">
<h1>Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension<a class="headerlink" href="#Exporting-a-CUTLASS-grouped-GEMM-kernel-to-a-PyTorch-CUDA-extension" title="Permalink to this heading">#</a></h1>
<p>This notebook walks through a basic example of using the CUTLASS Python interface to declare a grouped GEMM kernel and export it as a PyTorch CUDA extension.</p>
<p><a class="reference external" href="https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p>
<section id="Background-on-grouped-GEMM">
<h2>Background on grouped GEMM<a class="headerlink" href="#Background-on-grouped-GEMM" title="Permalink to this heading">#</a></h2>
<p>Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides) in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM, without the requirement that the sizes and strides of each GEMM be the same.</p>
<p>For example, if one has <code class="docutils literal notranslate"><span class="pre">p</span></code> GEMMs with sizes:</p>
<div class="highlight-text notranslate"><div class="highlight"><pre><span></span>M_1 x N_1 x K_1
M_2 x N_2 x K_2
...
M_p x N_p x K_p
</pre></div>
</div>
<p>CUTLASSs grouped GEMM will execute these in a single CUDA kernel.</p>
<p>Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would insufficiently utilize the device in isolation.</p>
</section>
<section id="Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface">
<h2>Declaring a grouped GEMM via the CUTLASS Python interface<a class="headerlink" href="#Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface" title="Permalink to this heading">#</a></h2>
<p>A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one simply calls <code class="docutils literal notranslate"><span class="pre">cutlass.op.GroupedGemm</span></code>.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">cutlass</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span>
<span class="n">plan</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">GroupedGemm</span><span class="p">(</span><span class="n">element</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area stderr docutils container">
<div class="highlight"><pre>
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
</pre></div></div>
</div>
<p>We can then compile and run this operation on a group of GEMMs. Well first set up some utility functions to initialize GEMMs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">random</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">2023</span><span class="p">)</span>
<span class="c1"># Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K</span>
<span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">):</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[(</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">]</span>
<span class="c1"># Utility function to generate `problems` GEMMs of random sizes</span>
<span class="k">def</span> <span class="nf">generate_problems</span><span class="p">(</span><span class="n">problems</span><span class="p">):</span>
<span class="n">valid_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span>
<span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">problems</span><span class="p">):</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="p">[</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">valid_sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)]</span>
<span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">initialize</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="n">As</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
<span class="n">Bs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
<span class="n">Cs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">C</span><span class="p">)</span>
<span class="n">Ds</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">D</span><span class="p">)</span>
<span class="k">return</span> <span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span>
</pre></div>
</div>
</div>
<p>Well next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span><span class="p">,</span> <span class="o">=</span> <span class="n">generate_problems</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">Ds</span><span class="p">,</span> <span class="n">Ds_torch</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmGrouped&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
</section>
<section id="Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension">
<h2>Exporting the CUTLASS kernel to a PyTorch CUDA extension<a class="headerlink" href="#Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension" title="Permalink to this heading">#</a></h2>
<p>The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">PyTorch CUDA extension</a>. This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.</p>
<p>The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later “ahead-of-time” compilation, or be just-in-time compiled and returned to the user.</p>
<p>To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">op</span> <span class="o">=</span> <span class="n">plan</span><span class="o">.</span><span class="n">construct</span><span class="p">()</span>
<span class="n">grouped_gemm</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">emit</span><span class="o">.</span><span class="n">pytorch</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;grouped_gemm&#39;</span><span class="p">,</span> <span class="n">cc</span><span class="o">=</span><span class="n">plan</span><span class="o">.</span><span class="n">cc</span><span class="p">,</span> <span class="n">sourcedir</span><span class="o">=</span><span class="s1">&#39;out&#39;</span><span class="p">,</span> <span class="n">jit</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">cutlass.emit.pytorch</span></code> function emits: * <code class="docutils literal notranslate"><span class="pre">out/grouped_gemm_kernel.cu</span></code>: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors * <code class="docutils literal notranslate"><span class="pre">out/grouped_gemm.cpp</span></code>: This file contains a C++ wrapper around the aforementioned CUTLASS kernel * <code class="docutils literal notranslate"><span class="pre">setup.py</span></code>: This file contains the <code class="docutils literal notranslate"><span class="pre">setuptools</span></code> script for building and installing the generated extension</p>
<p>The extension can be build from within the <code class="docutils literal notranslate"><span class="pre">module_output</span></code> directory by running:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nv">TORCH_CUDA_ARCH_LIST</span><span class="o">=</span><span class="s2">&quot;8.0&quot;</span><span class="w"> </span>python<span class="w"> </span>setup.py<span class="w"> </span>install
</pre></div>
</div>
<p>Where <code class="docutils literal notranslate"><span class="pre">TORCH_ARCH_LIST</span></code> is set to the compute capability of the device on which the kernel will be run.</p>
<p>See the PyTorch <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">“Custom C++ and CUDA Extensions”</a> tutorial for more details on this.</p>
<p>The PyTorch CUDA extension could be built for this module by running:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span><span class="w"> </span>out
<span class="nv">TORCH_CUDA_ARCH_LIST</span><span class="o">=</span><span class="s2">&quot;8.0&quot;</span><span class="w"> </span>python<span class="w"> </span>setup.py
</pre></div>
</div>
<p>(assuming that one is building for SM80)</p>
<p>One could then use the kernel in a later PyTorch module by running:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">grouped_gemm</span>
<span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
</pre></div>
</div>
<p>In this case, however, we set <code class="docutils literal notranslate"><span class="pre">jit=True</span></code>, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly. Under the hood, this leverages the <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">torch.utils.cpp_extension.load</a> method and returns back the loaded extension.</p>
<p>We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">Ds</span><span class="p">,</span> <span class="n">Ds_torch</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>Finally, we can profile our grouped GEMM extension:</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[6]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">num_warmup</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">num_profile</span> <span class="o">=</span> <span class="mi">100</span>
<span class="c1"># Warmup iterations</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_warmup</span><span class="p">):</span>
<span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="c1"># Timing iterations</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="n">grouped</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">nongrouped</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_profile</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">grouped</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">nongrouped</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Grouped: </span><span class="si">{:.3f}</span><span class="s1"> us&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">grouped</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="n">num_profile</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Non-Grouped: </span><span class="si">{:.3f}</span><span class="s1"> us&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nongrouped</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="n">num_profile</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Speedup: </span><span class="si">{:.3f}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nongrouped</span> <span class="o">/</span> <span class="n">grouped</span><span class="p">))</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
Grouped: 400.696 us
Non-Grouped: 646.670 us
Speedup: 1.614
</pre></div></div>
</div>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[ ]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>
</pre></div>
</div>
</div>
</section>
</section>
</article>
</div>
<footer>
<div class="related-pages">
<a class="prev-page" href="01_epilogue.html">
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
<div class="page-info">
<div class="context">
<span>Previous</span>
</div>
<div class="title">Example of using elementwise activation functions in the CUTLASS Python interface</div>
</div>
</a>
</div>
<div class="bottom-of-page">
<div class="left-details">
<div class="copyright">
Copyright &#169; 2023, NVIDIA
</div>
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
<a href="https://github.com/pradyunsg/furo">Furo</a>
</div>
<div class="right-details">
<div class="icons">
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<aside class="toc-drawer">
<div class="toc-sticky toc-scroll">
<div class="toc-title-container">
<span class="toc-title">
On this page
</span>
</div>
<div class="toc-tree-container">
<div class="toc-tree">
<ul>
<li><a class="reference internal" href="#">Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension</a><ul>
<li><a class="reference internal" href="#Background-on-grouped-GEMM">Background on grouped GEMM</a></li>
<li><a class="reference internal" href="#Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface">Declaring a grouped GEMM via the CUTLASS Python interface</a></li>
<li><a class="reference internal" href="#Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension">Exporting the CUTLASS kernel to a PyTorch CUDA extension</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</aside>
</div>
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/sphinx_highlight.js"></script>
<script src="../_static/scripts/furo.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script src="../_static/tabs.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>window.MathJax = {"tex": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true}, "options": {"ignoreHtmlClass": "tex2jax_ignore|mathjax_ignore|document", "processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
<script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>

View File

@ -0,0 +1,356 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "6acbea5d",
"metadata": {},
"source": [
"# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n",
"This notebook walks through a basic example of using the CUTLASS Python interface to declare\n",
"a grouped GEMM kernel and export it as a PyTorch CUDA extension.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n",
"\n",
"## Background on grouped GEMM\n",
"Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n",
"in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n",
"without the requirement that the sizes and strides of each GEMM be the same.\n",
"\n",
"For example, if one has `p` GEMMs with sizes:\n",
"```text\n",
"M_1 x N_1 x K_1\n",
"M_2 x N_2 x K_2\n",
"...\n",
"M_p x N_p x K_p\n",
"```\n",
"CUTLASS's grouped GEMM will execute these in a single CUDA kernel.\n",
"\n",
"Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would\n",
"insufficiently utilize the device in isolation.\n",
"\n",
"## Declaring a grouped GEMM via the CUTLASS Python interface\n",
"A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n",
"simply calls `cutlass.op.GroupedGemm`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "fdcf21d8",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:01:01.888030Z",
"iopub.status.busy": "2023-04-18T18:01:01.887634Z",
"iopub.status.idle": "2023-04-18T18:01:06.069481Z",
"shell.execute_reply": "2023-04-18T18:01:06.068513Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import cutlass\n",
"import torch\n",
"\n",
"dtype = torch.float16\n",
"plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)"
]
},
{
"cell_type": "markdown",
"id": "514f40a4",
"metadata": {},
"source": [
"We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c2a7371e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:01:06.073326Z",
"iopub.status.busy": "2023-04-18T18:01:06.073092Z",
"iopub.status.idle": "2023-04-18T18:01:06.080337Z",
"shell.execute_reply": "2023-04-18T18:01:06.079517Z"
}
},
"outputs": [],
"source": [
"import random\n",
"random.seed(2023)\n",
"\n",
"# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n",
"def initialize(dtype, M, N, K):\n",
" sizes = [(M, K), (K, N), (M, N), (M, N)]\n",
" return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n",
"\n",
"# Utility function to generate `problems` GEMMs of random sizes\n",
"def generate_problems(problems):\n",
" valid_sizes = [128, 256, 512, 1024]\n",
" As, Bs, Cs, Ds = [], [], [], []\n",
" for _ in range(problems):\n",
" M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n",
" A, B, C, D = initialize(dtype, M, N, K)\n",
" As.append(A)\n",
" Bs.append(B)\n",
" Cs.append(C)\n",
" Ds.append(D)\n",
" return As, Bs, Cs, Ds"
]
},
{
"cell_type": "markdown",
"id": "590a3bc5",
"metadata": {},
"source": [
"We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "776c9233",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:01:06.083288Z",
"iopub.status.busy": "2023-04-18T18:01:06.083082Z",
"iopub.status.idle": "2023-04-18T18:01:10.783577Z",
"shell.execute_reply": "2023-04-18T18:01:10.782798Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmGrouped<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_type :\n",
" public cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
}
],
"source": [
"As, Bs, Cs, Ds, = generate_problems(50)\n",
"\n",
"plan.run(As, Bs, Cs, Ds, print_module=True)\n",
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
"\n",
"for d, d_torch in zip(Ds, Ds_torch):\n",
" assert torch.allclose(d, d_torch)"
]
},
{
"cell_type": "markdown",
"id": "766e4f03",
"metadata": {},
"source": [
"## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n",
"The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n",
"\n",
"The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n",
"\n",
"To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3a98dee6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:01:10.787020Z",
"iopub.status.busy": "2023-04-18T18:01:10.786862Z",
"iopub.status.idle": "2023-04-18T18:02:08.445210Z",
"shell.execute_reply": "2023-04-18T18:02:08.443997Z"
}
},
"outputs": [],
"source": [
"op = plan.construct()\n",
"grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)"
]
},
{
"cell_type": "markdown",
"id": "c8ca3991",
"metadata": {},
"source": [
"The `cutlass.emit.pytorch` function emits:\n",
"* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n",
"* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n",
"* `setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n",
"\n",
"The extension can be build from within the `module_output` directory by running:\n",
"```bash\n",
"TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n",
"```\n",
"Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n",
"\n",
"See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n",
"\n",
"The PyTorch CUDA extension could be built for this module by running:\n",
"```bash\n",
"cd out\n",
"TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n",
"```\n",
"(assuming that one is building for SM80)\n",
"\n",
"One could then use the kernel in a later PyTorch module by running:\n",
"\n",
"```python\n",
"import torch\n",
"import grouped_gemm\n",
"\n",
"grouped_gemm.run(As, Bs)\n",
"```\n",
"\n",
"In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n",
"Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n",
"and returns back the loaded extension.\n",
"\n",
"We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cecb26a4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:02:08.449530Z",
"iopub.status.busy": "2023-04-18T18:02:08.449077Z",
"iopub.status.idle": "2023-04-18T18:02:08.464755Z",
"shell.execute_reply": "2023-04-18T18:02:08.464200Z"
}
},
"outputs": [],
"source": [
"Ds = grouped_gemm.run(As, Bs)\n",
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
"for d, d_torch in zip(Ds, Ds_torch):\n",
" assert torch.allclose(d, d_torch)"
]
},
{
"cell_type": "markdown",
"id": "50db80e4",
"metadata": {},
"source": [
"Finally, we can profile our grouped GEMM extension:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b76805d3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:02:08.467087Z",
"iopub.status.busy": "2023-04-18T18:02:08.466879Z",
"iopub.status.idle": "2023-04-18T18:02:08.603689Z",
"shell.execute_reply": "2023-04-18T18:02:08.603085Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Grouped: 400.696 us\n",
"Non-Grouped: 646.670 us\n",
"Speedup: 1.614\n"
]
}
],
"source": [
"num_warmup = 20\n",
"num_profile = 100\n",
"\n",
"# Warmup iterations\n",
"for _ in range(num_warmup):\n",
" Ds = grouped_gemm.run(As, Bs)\n",
" Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
" torch.cuda.synchronize()\n",
"\n",
"# Timing iterations\n",
"import time\n",
"grouped = 0\n",
"nongrouped = 0\n",
"for _ in range(num_profile):\n",
" start = time.time()\n",
" Ds = grouped_gemm.run(As, Bs)\n",
" torch.cuda.synchronize()\n",
" grouped += time.time() - start\n",
"\n",
" start = time.time()\n",
" Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
" torch.cuda.synchronize()\n",
" nongrouped += time.time() - start\n",
"\n",
"print('Grouped: {:.3f} us'.format(grouped * 1e6/num_profile))\n",
"print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n",
"print('Speedup: {:.3f}'.format(nongrouped / grouped))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f22fc696",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}